import torch
from torch import nn
import math
# Hyper-parameters shared by every example in this file
batch_size = 3
d_model = 8
max_feat_len = 7  # longest encoder (feature) sequence
max_len = max_feat_len
max_label_len = 5  # longest decoder (label) sequence
n_head = 2
vocab_size = 100
drop_prob = 0.12
# Run on the first CUDA device when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
# ------------------------------------------------------
# 位置编码
class PositionalEncoding(nn.Module):
    """Fixed (non-trainable) sinusoidal positional encoding table.

    The table has shape ``(max_len, d_model)``: even columns hold
    ``sin(pos / 10000 ** (2i / d_model))`` and odd columns hold the matching
    cosine.

    Args:
        d_model: width of each position vector.
        max_len: number of positions to precompute.
        device: device on which the table is created.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, d_model, device=device)

        # 1D => 2D so that every position broadcasts against every frequency
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        # even column indices 0, 2, 4, ... (the "2i" of the formula)
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angles = pos / (10000 ** (_2i / d_model))

        encoding[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one odd column fewer than even columns;
        # slicing keeps the assignment shapes aligned instead of raising.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so Module.to()/cuda()/cpu() move it together
        # with the rest of the model. It carries no gradient, and
        # persistent=False keeps it out of state_dict().
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``seq_len`` positions.

        Args:
            x: 2-D tensor ``(batch_size, seq_len)``; only its shape is read.

        Returns:
            Tensor of shape ``(seq_len, d_model)`` sliced from the table.
        """
        _, seq_len = x.size()
        return self.encoding[:seq_len, :]
# example: random token ids and the positional encodings for their positions
_x = torch.randint(0, vocab_size, (batch_size, max_len), device=device)
print(_x)
_PositionalEncoding = PositionalEncoding(d_model=d_model, max_len=max_len, device=device)
pos_output = _PositionalEncoding(_x)
print(pos_output)
# ------------------------------------------------------
# embedding 化
class TokenEmbedding(nn.Embedding):
    """Token-id to dense-vector lookup table built on ``nn.Embedding``.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
        padding_idx: token id treated as padding by ``nn.Embedding``.
            Defaults to 1, the value this class previously hard-coded.
    """

    def __init__(self, vocab_size, d_model, padding_idx=1):
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)
# example: embed the random token ids created above
_TokenEmbedding = TokenEmbedding(vocab_size, d_model)
_TokenEmbedding.to(device)  # Module.to() moves the parameters in place
token_output = _TokenEmbedding(_x)
print(token_output)
# ------------------------------------------------------
# 输入的embedding: token_embedding + pos_embedding
class TransformerEmbedding(nn.Module):
    """Model input embedding: token embedding + positional encoding, then dropout.

    Args:
        vocab_size: vocabulary size forwarded to ``TokenEmbedding``.
        d_model: embedding width.
        max_len: number of positions forwarded to ``PositionalEncoding``.
        drop_prob: dropout probability applied to the summed embedding.
        device: device for the token table and the positional table.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model).to(device)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed the token ids ``x`` and add the positional encoding."""
        # The positional table broadcasts over the batch dimension of the
        # token embedding.
        summed = self.tok_emb(x) + self.pos_emb(x)
        return self.drop_out(summed)
# example: full input embedding for the random token ids
_TransformerEmbedding = TransformerEmbedding(
    vocab_size=vocab_size,
    d_model=d_model,
    max_len=max_len,
    drop_prob=drop_prob,
    device=device,
)
emb_output = _TransformerEmbedding(_x)
print(emb_output)
# ------------------------------------------------------
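# Build a padding mask whose top-left block is 1. After the final softmax:
#   - the valid attention weights sit in the top-left block and the top-right block is 0;
#   - every remaining (padded) row holds the uniform value 1/max_length.
# After multiplying by v: the first 2 rows are valid; the last 5 rows are non-zero but meaningless.
# A top-left-ones padding mask has the same effect as the left-columns-ones mask described in the blog post below.
# Sample 1 has effective length 2, so the last 5 rows (max_length - 2 = 7 - 2) of its mask are 0;
# after mask(q*k) every entry in those rows equals -10000, so softmax gives each of them the same value (= 1/max_length).
# The blog post says these positions are dropped from the loss by passing the id of the "<ignore>" label
# as the ignore_index argument of torch.nn.functional.cross_entropy, so they never contribute to training.
# NOTE: I have not fully understood this part yet.
# https://zhuanlan.zhihu.com/p/648127076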
# (1) encoder mask
# Example: start from an all-ones (batch, seq, seq) tensor and zero out padding.
one_tensor = torch.ones((batch_size, max_len, max_len), device=device)
# effective (un-padded) length of each sample in the batch
effective_lengths = [2, 5, max_len]
# True where the position index is >= the sample's effective length (padding).
# Both operands are created on `device`: masked_fill below requires the mask
# and the filled tensor to live on the same device, which fails on CUDA when
# the mask is left on the CPU.
positions = torch.arange(one_tensor.size(1), device=device).unsqueeze(0)
lengths = torch.tensor(effective_lengths, device=device).unsqueeze(1)
mask = positions >= lengths
# apply the mask
masked_tensor_1 = one_tensor.masked_fill(mask.unsqueeze(-1), 0)  # padded rows -> 0
masked_tensor_2 = torch.transpose(masked_tensor_1, 1, 2)  # padded columns -> 0
# keep only the entries that are valid in both the row and the column view
masked_tensor = (masked_tensor_1 + masked_tensor_2 > 1).to(torch.int)
print(masked_tensor)
# add a head dimension so the mask lines up with the per-head split
masked_expanded_tensor = masked_tensor.unsqueeze(1).expand(-1, n_head, -1, -1)
print(masked_expanded_tensor)
# (2) decoder masks: there are two of them
# (2.1) mask used by the decoder's first (masked) multi-head self-attention
# It combines the padding mask with the sequence (causal) mask.
# Causal part: torch.tril keeps the lower triangle, diagonal included.
causal_part = torch.tril(one_tensor)
# an entry survives only when the causal part and the padding mask are both 1
tril_mask = ((causal_part + masked_tensor_1) > 1).int()
print(tril_mask)
# add a head dimension so the mask lines up with the per-head split
tril_mask_expanded = tril_mask.unsqueeze(1).expand(-1, n_head, -1, -1)
print(tril_mask_expanded)
# ---- (2.1 - 2) the same kind of mask, built from target token ids
trg = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]])
trg_pad_idx = 0
# True at non-padding positions, reshaped to (batch, 1, trg_len, 1)
trg_pad_mask = (trg != trg_pad_idx)[:, None, :, None]
print(trg_pad_mask)
trg_len = trg.size(1)
print(trg_len)
# lower-triangular (causal) mask stored as uint8
trg_sub_mask = torch.ones(trg_len, trg_len).tril().to(torch.uint8)
print(trg_sub_mask)
# broadcast AND of the padding mask and the causal mask -> (batch, 1, trg_len, trg_len)
trg_mask = trg_pad_mask & trg_sub_mask
print(trg_mask)
# (2.2) encoder_decoder_mask: decoder中的第2个multi-head-attention需要使用到的mask (cross-mask)
# 计算decoder的query 与 encoder的key 的相似度
# 然后使用相似度获取 encoder的value 的综合信息
'''
query: [batch_size, max_label_len=5, dim]
key: [batch_size, max_feat_len=7, dim]
q * k_t = [batch_size, max_label_len=5, max_feat_len=7]
-> mask(q * k_t)
-> softmax(mask(q * k_t))
-> * v = [batch_size, max_label_len=5, dim]
'''
def get_enc_dec_mask(batch_size, max_feat_len, feat_lens, max_label_len, device):
    """Build the encoder-decoder (cross-attention) padding mask.

    Every decoder query position may attend to the first ``feat_lens[i]``
    encoder positions of sample ``i``; the remaining (padded) encoder
    positions are masked out.

    Args:
        batch_size: number of samples in the batch.
        max_feat_len: padded encoder length (key dimension, ``seq_k``).
        feat_lens: per-sample effective encoder lengths, one per sample.
        max_label_len: padded decoder length (query dimension, ``seq_q``).
        device: device on which the mask is created.

    Returns:
        Tensor of shape ``(batch_size, max_label_len, max_feat_len)`` holding
        1 where attention is allowed and 0 at padded encoder positions.
    """
    attn_mask_tensor = torch.ones((batch_size, max_label_len, max_feat_len), device=device)  # (b, seq_q, seq_k)
    # as_tensor accepts a list or an existing tensor without the copy warning
    # that torch.tensor() emits for tensor input.
    lengths = torch.as_tensor(feat_lens, device=device).unsqueeze(1)  # (b, 1)
    key_positions = torch.arange(max_feat_len, device=device).unsqueeze(0)  # (1, seq_k)
    padded_keys = key_positions >= lengths  # (b, seq_k), True at padding
    # unsqueeze(1) broadcasts the same key mask over every query position
    return attn_mask_tensor.masked_fill(padded_keys.unsqueeze(1), 0)
# max_feat_len is the longest encoder sequence; max_label_len is the longest decoder sequence
enc_dec_mask = get_enc_dec_mask(
    batch_size=batch_size,
    max_feat_len=max_feat_len,
    feat_lens=effective_lengths,
    max_label_len=max_label_len,
    device=device,
)
print(enc_dec_mask)
# ------------------------------------------------------
# 每个独立的head, 分别计算如下步骤:
# dot_value = q*k_t / sqrt(dim)
# -> mask_dot = mask(dot_value)
# -> score = softmax(mask_dot) * v
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention for one batch of already-split heads.

    Computes ``softmax(mask(q @ k^T / sqrt(d_tensor))) @ v``.
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Attend from ``q`` to ``k`` / ``v``.

        Args:
            q: queries, ``[batch_size, head, q_length, d_tensor]``.
            k: keys, ``[batch_size, head, k_length, d_tensor]`` (must be 4-D).
            v: values, same leading dimensions as ``k``.
            mask: optional tensor broadcastable to the score matrix; entries
                equal to 0 are masked out.
            e: unused; kept so existing call sites remain valid.

        Returns:
            Tuple ``(out, score)``: the attended values and the attention
            weights after softmax.
        """
        # Unpacking enforces a 4-D key tensor; only the per-head width is used.
        _, _, _, d_tensor = k.size()

        # 1. similarity of every query with every key, scaled by sqrt(d_tensor)
        k_t = k.transpose(2, 3)
        score = (q @ k_t) / math.sqrt(d_tensor)

        # 2. masked positions get a large negative score so softmax gives them
        #    (almost) zero weight; a fully masked row becomes uniform.
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)

        # 3. normalise the scores over the key dimension
        score = self.softmax(score)

        # 4. weighted sum of the values
        return score @ v, score
def split_tensor_function(n_head, input_tensor):
    """Split the model dimension of a tensor into attention heads.

    :param n_head: number of heads to split into
    :param input_tensor: [batch_size, length, d_model]
    :return: [batch_size, head, length, d_tensor], d_tensor = d_model // n_head
    """
    n_batch, seq_len, width = input_tensor.size()
    per_head = width // n_head
    # Carve the last dimension into (head, d_tensor), then move the head axis
    # in front of the sequence axis.
    by_head = input_tensor.view(n_batch, seq_len, n_head, per_head)
    return by_head.transpose(1, 2)
# example: split the embedding into heads and run masked self-attention on it
split_tensor = split_tensor_function(n_head=n_head, input_tensor=emb_output)
print(split_tensor)
masked_expanded_tensor = masked_expanded_tensor.to(split_tensor.device)
print(masked_expanded_tensor)
# shape: [batch_size, n_head, max_length, max_length]
_ScaleDotProductAttention = ScaleDotProductAttention().to(device)
scale_output = _ScaleDotProductAttention(
    q=split_tensor,
    k=split_tensor,
    v=split_tensor,
    mask=masked_expanded_tensor,
)
concat_scale_output, weight_score = scale_output
# ------------------------------------------------------
# 分头计算, 然后合并
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Args:
        d_model: model width; input and output width of every projection.
        n_head: number of attention heads.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        # input projections for queries, keys and values, then the output projection
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Return the attended output, ``[batch_size, length, d_model]``."""
        # 1-2. linear projections, each immediately split into heads
        heads_q = self.split(self.w_q(q))
        heads_k = self.split(self.w_k(k))
        heads_v = self.split(self.w_v(v))

        # 3. scaled dot-product attention per head; the weights are not used here
        context, _weights = self.attention(heads_q, heads_k, heads_v, mask=mask)

        # 4. merge the heads and apply the output projection
        # TODO : we should implement visualization
        return self.w_concat(self.concat(context))

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        n_batch, seq_len, width = tensor.size()
        per_head = width // self.n_head
        # similar to a group convolution: the channels are split by head
        by_head = tensor.view(n_batch, seq_len, self.n_head, per_head)
        return by_head.transpose(1, 2)

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        n_batch, n_heads, seq_len, per_head = tensor.size()
        # transpose() returns a non-contiguous view, so make it contiguous
        # before flattening the head axis back into the model dimension
        merged = tensor.transpose(1, 2).contiguous()
        return merged.view(n_batch, seq_len, n_heads * per_head)
# example: masked multi-head self-attention over the input embedding
_MultiHeadAttention = MultiHeadAttention(d_model, n_head).to(device)
output_MultiHeadAttention = _MultiHeadAttention(
    q=emb_output,
    k=emb_output,
    v=emb_output,
    mask=masked_expanded_tensor,
)
print(output_MultiHeadAttention)
# ------------------------------------------------------
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: input and output width.
        hidden: width of the inner layer.
        drop_prob: dropout probability applied after the activation.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Map ``x`` from d_model to hidden and back to d_model."""
        expanded = self.relu(self.linear1(x))
        return self.linear2(self.dropout(expanded))
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    Args:
        d_model: size of the last (normalised) dimension.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))  # scale
        self.beta = nn.Parameter(torch.zeros(d_model))  # shift
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` along its last dimension, then scale and shift."""
        # statistics over the last dimension; the variance is the biased one
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        normalized = (x - mu) / (sigma_sq + self.eps).sqrt()
        return self.gamma * normalized + self.beta
# ------------------------------------------------------
# encoder只有一个multi-head-attention
# 其中mask是src_mask, 属于padding_mask (解决batch中训练数据长度不一样)
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each with add & norm.

    The encoder has a single multi-head attention; its mask (``src_mask``) is
    a padding mask that handles samples of different lengths in one batch.

    Args:
        d_model: model width.
        ffn_hidden: inner width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # self-attention sub-layer
        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # feed-forward sub-layer
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, src_mask):
        """Run the block on ``x`` using ``src_mask`` as the attention mask."""
        # 1-2. self-attention, then dropout + residual + norm
        attended = self.attention(q=x, k=x, v=x, mask=src_mask)
        x = self.norm1(self.dropout1(attended) + x)

        # 3-4. feed-forward, then dropout + residual + norm
        transformed = self.ffn(x)
        return self.norm2(self.dropout2(transformed) + x)
# decoder有2个multi-head-attention
# 第1个: 结构和encoder一样
# 只是mask类型不一样: 是padding_mask u 上三角_mask(sequence_mask)
# 第2个: q来自decoder, k/v 来自 encoder
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    The first attention has the same structure as the encoder's but takes a
    different mask (padding mask combined with the sequence mask). In the
    second attention the queries come from the decoder and the keys/values
    from the encoder output.

    Args:
        d_model: model width.
        ffn_hidden: inner width of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability used throughout the block.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # masked self-attention sub-layer
        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # encoder-decoder (cross) attention sub-layer
        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        # feed-forward sub-layer
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask):
        """Run the block.

        Args:
            dec: decoder input.
            enc: encoder output, or ``None`` to skip the cross-attention.
            trg_mask: mask for the decoder self-attention.
            src_mask: mask for the encoder-decoder attention.
        """
        # 1-2. masked self-attention, then dropout + residual + norm
        attended = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)
        x = self.norm1(self.dropout1(attended) + dec)

        # 3-4. cross-attention over the encoder output (only when it is given)
        if enc is not None:
            crossed = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
            x = self.norm2(self.dropout2(crossed) + x)

        # 5-6. feed-forward, then dropout + residual + norm
        transformed = self.ffn(x)
        return self.norm3(self.dropout3(transformed) + x)
# References
# 1. Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
# 2. Reference implementation: https://github.com/hyunwoongko/transformer/blob/master/models/blocks/encoder_layer.py
# 3. Masking approach ("A complete 30,000-word walkthrough: implementing Transformer from scratch"): https://zhuanlan.zhihu.com/p/648127076
# 4. Mask semantics in PyTorch: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# 5. Understanding how the decoder is parallelised: https://zhuanlan.zhihu.com/p/368592551