Transformer learning notes

import torch
from torch import nn
import math


# params
batch_size = 3
d_model = 8
max_len = max_feat_len = 7
max_label_len = 5
n_head = 2
vocab_size = 100
drop_prob = 0.12


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)



# ------------------------------------------------------
# Positional encoding

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # we don't need to compute gradient
        
        pos = torch.arange(0, max_len, device=device)
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D unsqueeze to represent word's position
        
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i' indexes the even embedding dimensions (0, 2, ..., d_model - 2)
        # step=2 builds 2i directly, so each sin/cos pair shares the same frequency
        
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # compute positional encoding to consider positional information of words
        
    def forward(self, x):
        batch_size, seq_len = x.size()
        return self.encoding[:seq_len, :]


# example

_x = torch.randint(low = 0, high = vocab_size, size = (batch_size, max_len), device=device)
print(_x)


_PositionalEncoding = PositionalEncoding(d_model, max_len, device)
pos_output = _PositionalEncoding(_x)
print(pos_output)
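# pos_output shape: [seq_len, d_model]; when added to a token embedding of shape
# [batch_size, seq_len, d_model] it broadcasts over the batch dimension.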


# ------------------------------------------------------

# Token embedding

class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, d_model):
        super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=1)


# example

_TokenEmbedding = TokenEmbedding(vocab_size, d_model).to(device)
token_output = _TokenEmbedding(_x)
print(token_output)
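# token_output shape: [batch_size, max_len, d_model]; padding_idx=1 pins the embedding of
# token id 1 to a zero vector (it is treated as the padding token here).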


# ------------------------------------------------------

# Input embedding: token_embedding + pos_embedding

class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super(TransformerEmbedding, self).__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model).to(device)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.drop_out(tok_emb + pos_emb)
    

# example

_TransformerEmbedding = TransformerEmbedding(vocab_size, d_model, max_len, drop_prob, device)
emb_output = _TransformerEmbedding(_x)
print(emb_output)


# ------------------------------------------------------

# Build a padding mask whose upper-left block is 1. After the final softmax:
#     the valid scores sit in the upper-left block, and the upper-right block is ≈ 0;
#     every fully masked row below becomes uniform, with each entry = 1/max_length.
# Multiplying by v: the elements of the first 2 rows (for the first sample) are valid;
# the elements of the last 5 rows are non-zero but carry no useful information.

# This "upper-left block of ones" padding mask has the same effect as the "left block of ones"
# construction described in the blog post referenced below.
# For the first sample the effective length is 2, so the last 5 rows (max_length - 2 = 7 - 2)
# of the mask are 0. After masking q @ k_t, those 5 rows are all -10000, so after softmax
# every entry in them equals 1/max_length.
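
# A quick numeric check of the claim above (added for illustration): a row whose scores are all
# -10000 turns into a uniform distribution under softmax.
fully_masked_row = torch.full((max_len,), -10000.0)
print(torch.softmax(fully_masked_row, dim=-1))  # every entry ≈ 1/7 = 1/max_len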

# Some blog posts explain that, when computing the loss, these padded positions are excluded by
# setting torch.nn.functional.cross_entropy's ignore_index argument to the padding id, so that
# part of the loss is dropped and never enters the computation.
# To be honest I have not fully understood this yet (see the short sketch below).
# https://zhuanlan.zhihu.com/p/648127076
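
# A minimal sketch of that ignore_index mechanism (illustration only; the padding id 0 and the
# shapes below are assumptions made for this demo, not taken from the rest of this note).
# Targets equal to ignore_index contribute nothing to the loss, so whatever the model produced
# at padded positions never affects training.
import torch.nn.functional as F

demo_logits = torch.randn(batch_size, max_len, vocab_size)            # [b, seq, vocab]
demo_targets = torch.randint(1, vocab_size, (batch_size, max_len))    # [b, seq], no padding yet
demo_targets[0, 2:] = 0                                               # pretend sample 0 is padded after position 2
demo_loss = F.cross_entropy(demo_logits.view(-1, vocab_size), demo_targets.view(-1), ignore_index=0)
print(demo_loss)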

# (1) encoder mask
# mask example
# example tensor
one_tensor = torch.ones((batch_size, max_len, max_len), device=device)

# effective lengths
effective_lengths = [2, 5, max_len]

# create a boolean mask: True marks padded positions
# (built on `device` so the masked_fill below also works when the tensors live on the GPU)
mask = torch.arange(one_tensor.size(1), device=device).unsqueeze(0) >= torch.tensor(effective_lengths, device=device).unsqueeze(1)

# apply the mask
masked_tensor_1 = one_tensor.masked_fill(mask.unsqueeze(-1), 0) # zero out padded rows
masked_tensor_2 = torch.transpose(masked_tensor_1, 1, 2)        # zero out padded columns
masked_tensor = (masked_tensor_1 + masked_tensor_2 > 1).to(torch.int)

print(masked_tensor)

# expand the mask over the head dimension so it matches the shape after the head split
masked_expanded_tensor = masked_tensor.unsqueeze(1).expand(-1, n_head, -1, -1)
print(masked_expanded_tensor)



# (2) decoder masks: there are two
# (2.1) the mask used in the decoder's first (masked) multi-head attention
# Note: this mask combines the padding mask with the sequence mask (look-ahead / triangular mask)

# sequence (look-ahead) mask: keep the lower triangle, mask out the upper triangle
tril_mask = torch.tril(one_tensor, diagonal=0)
tril_mask = (tril_mask + masked_tensor_1 > 1).to(torch.int)
print(tril_mask)

# expand the mask over the head dimension so it matches the shape after the head split
tril_mask_expanded = tril_mask.unsqueeze(1).expand(-1, n_head, -1, -1)
print(tril_mask_expanded)


# ---- (2.1-2) an alternative construction of the same decoder self-attention mask:
#      padding mask & lower-triangular (look-ahead) mask
trg = torch.tensor([[1, 2, 3, 0, 0], [4, 5, 6, 7, 0]])
trg_pad_idx=0

trg_pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(3)
print(trg_pad_mask)

trg_len = trg.shape[1]
print(trg_len)

trg_sub_mask = torch.tril(torch.ones(trg_len, trg_len)).type(torch.ByteTensor)
print(trg_sub_mask)

trg_mask = trg_pad_mask & trg_sub_mask
print(trg_mask)
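
# shapes: trg_pad_mask [2, 1, 5, 1] & trg_sub_mask [5, 5] broadcast to trg_mask [2, 1, 5, 5];
# the singleton head dimension then broadcasts against the per-head attention scores.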




# (2.2) encoder_decoder_mask (cross-attention mask): used by the decoder's second multi-head attention
# It scores the decoder's query against the encoder's key,
# then uses those scores to aggregate the encoder's value.

'''
query: [batch_size, max_label_len=5, dim]
key: [batch_size, max_feat_len=7, dim]

q * k_t = [batch_size, max_label_len=5, max_feat_len=7]

-> mask(q * k_t)
-> softmax(mask(q * k_t))
-> * v = [batch_size, max_label_len=5, dim]
'''
# loop version
def get_enc_dec_mask(batch_size, max_feat_len, feat_lens, max_label_len, device):
    attn_mask_tensor = torch.ones((batch_size, max_label_len, max_feat_len), device=device)       # (b, seq_q, seq_k)
    for i in range(batch_size):
        attn_mask_tensor[i, :, feat_lens[i]:] = 0
    return attn_mask_tensor

# vectorized version (note: this redefines the function above; it is the one used below)
def get_enc_dec_mask(batch_size, max_feat_len, feat_lens, max_label_len, device):
    attn_mask_tensor = torch.ones((batch_size, max_label_len, max_feat_len), device=device)       # (b, seq_q, seq_k)
    mask = torch.arange(attn_mask_tensor.size(2)).unsqueeze(0) >= torch.tensor(feat_lens).unsqueeze(1)
    mask = mask.to(device)
    masked_tensor = attn_mask_tensor.masked_fill(mask.unsqueeze(1), 0)
    return masked_tensor

# Here max_feat_len is the encoder's maximum sequence length and max_label_len is the decoder's maximum sequence length.

enc_dec_mask = get_enc_dec_mask(batch_size, max_feat_len, effective_lengths, max_label_len, device)
print(enc_dec_mask)
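# enc_dec_mask shape: [batch_size, max_label_len, max_feat_len]; expand it over the head
# dimension before feeding it to the cross-attention (see the decoder example further below).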


# ------------------------------------------------------

# Each head independently computes the following steps:
# dot_value = q @ k_t / sqrt(d_tensor)
# -> mask_dot = mask(dot_value)
# -> score = softmax(mask_dot); output = score @ v

class ScaleDotProductAttention(nn.Module):

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        # input is 4 dimension tensor
        # [batch_size, head, length, d_tensor]
        batch_size, head, length, d_tensor = k.size()

        # 1. dot product Query with Key^T to compute similarity
        k_t = k.transpose(2, 3)  # transpose
        score = (q @ k_t) / math.sqrt(d_tensor)  # scaled dot product
        print(score)
        print(score.shape)

        # 2. apply masking (opt)
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)
            
        print(score)

        # 3. pass them softmax to make [0, 1] range
        score = self.softmax(score)

        # 4. multiply with Value
        v = score @ v

        return v, score
    

def split_tensor_function(n_head, input_tensor):
    """
    split tensor by number of head

    :param input_tensor: [batch_size, length, d_model]
    :return: [batch_size, head, length, d_tensor]
    """
    batch_size, length, d_model = input_tensor.size()

    d_tensor = d_model // n_head
    output_tensor = input_tensor.view(batch_size, length, n_head, d_tensor).transpose(1, 2)
    # it is similar with group convolution (split by number of heads)

    return output_tensor


# example

split_tensor = split_tensor_function(n_head, emb_output)
print(split_tensor)
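# split_tensor shape: [batch_size, n_head, max_len, d_model // n_head] = [3, 2, 7, 4]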

masked_expanded_tensor = masked_expanded_tensor.to(split_tensor.device)
print(masked_expanded_tensor)
# shape: [batch_size, n_head, max_length, max_length]


_ScaleDotProductAttention = ScaleDotProductAttention()
_ScaleDotProductAttention.to(device)

scale_output = _ScaleDotProductAttention(split_tensor, split_tensor, split_tensor, masked_expanded_tensor)
concat_scale_output, weight_score = scale_output
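
# concat_scale_output keeps the per-head shape [batch_size, n_head, max_len, d_tensor];
# weight_score holds the attention weights [batch_size, n_head, max_len, max_len].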


# ------------------------------------------------------

# Compute attention in each head separately, then concatenate the heads

class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        # 1. dot product with weight matrices
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        # 2. split tensor by number of heads
        q, k, v = self.split(q), self.split(k), self.split(v)

        # 3. do scale dot product to compute similarity
        out, attention = self.attention(q, k, v, mask=mask)

        # 4. concat and pass to linear layer
        out = self.concat(out)
        out = self.w_concat(out)

        # 5. visualize attention map
        # TODO : we should implement visualization

        return out

    def split(self, tensor):
        """
        split tensor by number of head

        :param tensor: [batch_size, length, d_model]
        :return: [batch_size, head, length, d_tensor]
        """
        batch_size, length, d_model = tensor.size()

        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)
        # it is similar with group convolution (split by number of heads)

        return tensor

    def concat(self, tensor):
        """
        inverse function of self.split(tensor : torch.Tensor)

        :param tensor: [batch_size, head, length, d_tensor]
        :return: [batch_size, length, d_model]
        """
        batch_size, head, length, d_tensor = tensor.size()
        d_model = head * d_tensor

        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor
    
    
_MultiHeadAttention = MultiHeadAttention(d_model, n_head)
_MultiHeadAttention.to(device)

output_MultiHeadAttention = _MultiHeadAttention(emb_output, emb_output, emb_output, masked_expanded_tensor)
print(output_MultiHeadAttention)
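# output_MultiHeadAttention shape: [batch_size, max_len, d_model]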


# ------------------------------------------------------

class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_model, hidden, drop_prob=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.linear1 = nn.Linear(d_model, hidden)
        self.linear2 = nn.Linear(hidden, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x
    

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        # '-1' means last dimension. 

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out
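

# example (a minimal sketch; ffn_hidden = 16 is an assumed demo value, not from the original note)

ffn_hidden = 16
_PositionwiseFeedForward = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob).to(device)
_LayerNorm = LayerNorm(d_model).to(device)

# add & norm around the feed-forward block, mirroring what the encoder/decoder layers below do
ffn_output = _LayerNorm(output_MultiHeadAttention + _PositionwiseFeedForward(output_MultiHeadAttention))
print(ffn_output)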
    

# ------------------------------------------------------

# The encoder has a single multi-head attention.
# Its mask is src_mask, a padding mask (it handles the different sequence lengths within a batch).

class EncoderLayer(nn.Module):

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(EncoderLayer, self).__init__()
        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, src_mask):
        # 1. compute self attention
        _x = x
        x = self.attention(q=x, k=x, v=x, mask=src_mask)
        
        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)
        
        # 3. positionwise feed forward network
        _x = x
        x = self.ffn(x)
      
        # 4. add and norm
        x = self.dropout2(x)
        x = self.norm2(x + _x)
        return x
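

# example (a minimal sketch; reuses ffn_hidden from the feed-forward demo above and the
# encoder padding mask masked_expanded_tensor built earlier)

_EncoderLayer = EncoderLayer(d_model, ffn_hidden, n_head, drop_prob).to(device)
enc_layer_output = _EncoderLayer(emb_output, masked_expanded_tensor)
print(enc_layer_output.shape)  # [batch_size, max_len, d_model]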


# The decoder has two multi-head attentions.
# The first has the same structure as the encoder's;
#        only the mask differs: the padding mask combined with the sequence (look-ahead) mask.
# The second: q comes from the decoder, while k/v come from the encoder.


class DecoderLayer(nn.Module):

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask):
        # 1. compute self attention
        _x = dec
        x = self.self_attention(q=dec, k=dec, v=dec, mask=trg_mask)
        
        # 2. add and norm
        x = self.dropout1(x)
        x = self.norm1(x + _x)

        if enc is not None:
            # 3. compute encoder - decoder attention
            _x = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=src_mask)
            
            # 4. add and norm
            x = self.dropout2(x)
            x = self.norm2(x + _x)

        # 5. positionwise feed forward network
        _x = x
        x = self.ffn(x)
        
        # 6. add and norm
        x = self.dropout3(x)
        x = self.norm3(x + _x)
        return x
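

# example (a minimal sketch; the decoder-side tensors and masks below are built here only for
# this demo and reuse enc_layer_output / enc_dec_mask / ffn_hidden from the examples above)

# random target token ids and their embedding
_y = torch.randint(low=0, high=vocab_size, size=(batch_size, max_label_len), device=device)
_DecoderEmbedding = TransformerEmbedding(vocab_size, d_model, max_label_len, drop_prob, device)
dec_emb = _DecoderEmbedding(_y)

# decoder self-attention mask: for simplicity only the look-ahead part, broadcast over batch and heads
dec_self_mask = torch.tril(torch.ones(max_label_len, max_label_len, device=device)).unsqueeze(0).unsqueeze(0)

# cross-attention mask from get_enc_dec_mask, expanded over the head dimension
cross_mask = enc_dec_mask.unsqueeze(1).expand(-1, n_head, -1, -1)

_DecoderLayer = DecoderLayer(d_model, ffn_hidden, n_head, drop_prob).to(device)
dec_layer_output = _DecoderLayer(dec_emb, enc_layer_output, dec_self_mask, cross_mask)
print(dec_layer_output.shape)  # [batch_size, max_label_len, d_model]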
    

1. Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
2. Code reference: https://github.com/hyunwoongko/transformer/blob/master/models/blocks/encoder_layer.py
3. Masking approach (a full from-scratch Transformer walkthrough): https://zhuanlan.zhihu.com/p/648127076
4. Masking in PyTorch: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
5. Understanding how the decoder is parallelized during training: https://zhuanlan.zhihu.com/p/368592551
