python-pytorch 下批量seq2seq+Bahdanau Attention实现问答1.0.000

前言

前面实现了 luong的dot 、general、concat注意力实现简单问答,这里参考官方文档,实现了python-pytorch 下批量seq2seq+Bahdanau Attention实现问答

原理看图

在这里插入图片描述
这里模型选择和官方不一样,官方选择的是GRU,我更喜欢使用LSTM,解码器和编码器都是如此。
意思大致思路是:

  1. 计算encoder的encoder_outputs、encoder_hn、encoder_cn
  2. 使用encoder_outputs、encoder_hn计算新的向量和注意力
  3. 在deconder中,以SOS单字开始,循环句子最大长度,在循环中,使用新的向量和单字SOS做cat计算得到decoder的LSTM输入数据,将该LSTM存起来,最后做cat计算得到decoder的输出

数据准备

结果类似还是采用前面的结构和数据

seq_example = [“你认识我吗”, “你住在哪里”, “你知道我的名字吗”, “你是谁”, “你会唱歌吗”, “谁是张学友”]
seq_answer = [“当然认识”, “我住在成都”, “我不知道”, “我是机器人”, “我不会”, “她旁边那个就是”]

分词、index2word、word2index、vocab_size

分词然后做基础准备,包括数据:index2word、word2index、vocab_size、最长的句子长度seq_length,和一些超参数的设置

输入模型的数据构造

  1. 长度要统一
  2. 问答的句子以EOS结尾,不足补0,如

tensor([[ 3, 4, 5, 6, 2, 0, 0],
[ 3, 7, 8, 9, 2, 0, 0],
[ 3, 10, 5, 11, 12, 6, 2],
[ 3, 13, 14, 2, 0, 0, 0],
[ 3, 15, 16, 6, 2, 0, 0],
[14, 13, 17, 2, 0, 0, 0]])

注意力模型

可以复用,用官方的即可

# Bahdanau
# query=hidden [layer_num,batch_size,hidden_size] keys=encoder_outputs  [seq_len,batch_size,hidden_size]
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys))) #[seq_len,batch_size,1]
        scores = scores.permute(1,0,2).squeeze(2).unsqueeze(1)#[batch_size,1,seq_len]

        weights = nn.functional.softmax(scores, dim=-1)#[batch_size,1,seq_len]
        context = torch.bmm(weights, keys.permute(1,0,2))#[batch_size,1,hidden_size]

        return context, weights

decoder的编写

思路是,获得encoder的输出和hn后,计算得到向量,然后使用向量和目标的每一个字做cat计算,输入decoder的模型中,然后得出一个字的预测,循环完了以后,就会得到最大句子长度,最后做cat和softmax计算得到输出。另外,这里要区分训练和测试,训练的时候有target,测试的没有target数据。

关于损失函数和优化器

NLLLoss+Adam的组合优于CrossEntropyLoss+SGD的组合

在预测时

获取到模型输出,size是[batch_size,seq_len,vocab_size]后,对结果做topk计算,会得到每一字在vocab_size的概率,连接起来就是一句话

完整代码

# def getAQ():
#     ask=[]
#     answer=[]
#     with open("./data/flink.txt","r",encoding="utf-8") as f:
#         lines=f.readlines()
#         for line in lines:
#             ask.append(line.split("----")[0])
#             answer.append(line.split("----")[1].replace("\n",""))
#     return answer,ask

# seq_answer,seq_example=getAQ()



import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import os
from tqdm import tqdm
 
seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "谁是张学友"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "她旁

相关推荐

  1. python-pytorch seq2seq+attention笔记0.5.00

    2024-05-26 03:34:37       11 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-26 03:34:37       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-26 03:34:37       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-26 03:34:37       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-26 03:34:37       18 阅读

热门阅读

  1. 黄金价格创新高,交易风险提示

    2024-05-26 03:34:37       12 阅读
  2. gateway基本配置

    2024-05-26 03:34:37       10 阅读
  3. 时政|杂粮产业

    2024-05-26 03:34:37       13 阅读
  4. MYSQL--多表查询

    2024-05-26 03:34:37       9 阅读
  5. Gopeed的高级用法

    2024-05-26 03:34:37       13 阅读
  6. GitLab的原理及应用详解(四)

    2024-05-26 03:34:37       11 阅读
  7. 揭秘软件测试工程师:事业前景和成功秘诀

    2024-05-26 03:34:37       10 阅读
  8. 前端面试题日常练-day33 【面试题】

    2024-05-26 03:34:37       13 阅读