基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask

基于RNN和Transformer的词级语言建模 代码分析 _generate_square_subsequent_mask

flyfish

Word-level Language Modeling using RNN and Transformer

word_language_model

PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transformer 模型进行词级语言建模 。默认情况下,训练使用Wikitext-2数据集,generate.py可以使用训练好的模型来生成新文本。

源码地址
https://github.com/pytorch/examples/tree/main/word_language_model

文件:model.py

import torch
import matplotlib.pyplot as plt
import numpy as np

def _generate_square_subsequent_mask(sz):
    return torch.log(torch.tril(torch.ones(sz, sz)))

# 设置矩阵大小
sz = 5
mask = _generate_square_subsequent_mask(sz)

# 将 mask 转换为 numpy 数组,方便可视化
mask_np = mask.numpy()

# 可视化
plt.imshow(mask_np, cmap='viridis')
plt.colorbar()
plt.title("Square Subsequent Mask")
plt.show()

可视化图示
在可视化结果中,你会看到一个下三角矩阵,其值为 0 的部分为下三角部分,值为负无穷的部分为上三角部分。图像中通常负无穷会被显示为一种不同的颜色。

这样,你可以直观地理解生成的掩码矩阵的结构和作用。这个掩码矩阵主要用于 Transformer 模型中,以确保模型在预测时只能看到当前时刻及之前的时刻信息,而不能看到未来的信息。
在这里插入图片描述
结果
运行这段代码,你会看到一个 5x5 的矩阵,其中下三角部分是 0(因为 log(1) = 0),上三角部分是负无穷(由于 log(0) 是负无穷)。

def _generate_square_subsequent_mask(sz):
    return torch.log(torch.tril(torch.ones(sz, sz)))
# 设置矩阵大小
sz = 5
mask = _generate_square_subsequent_mask(sz)

# 打印矩阵
print(mask)

输出

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

在数学上,定义对数函数时,log(0) 是未定义的,但在计算中,我们处理这种情况的方式是认为 log(0) 的极限值是负无穷。因此,计算机通常会返回负无穷来表示这种情况。

在 PyTorch 中,torch.log(0) 的结果是 -inf(负无穷)。这是因为对数函数是单调递增的,并且在接近0时值会急剧下降到负无穷。

最近更新

  1. TCP协议是安全的吗?

    2024-06-07 00:40:04       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-07 00:40:04       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-07 00:40:04       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-07 00:40:04       20 阅读

热门阅读

  1. 求二叉树第k层结点的个数--c++【做题记录】

    2024-06-07 00:40:04       10 阅读
  2. npm:Node.js包管理器的使用指南

    2024-06-07 00:40:04       7 阅读
  3. 【机器学习】之 kmean算法原理及实现

    2024-06-07 00:40:04       10 阅读
  4. DVWA-CSRF

    DVWA-CSRF

    2024-06-07 00:40:04      8 阅读
  5. 算法学习笔记——对数器

    2024-06-07 00:40:04       8 阅读
  6. 递推7-2 sdut-C语言实验-养兔子分数

    2024-06-07 00:40:04       5 阅读
  7. MacBook M系列芯片安装php8.2

    2024-06-07 00:40:04       11 阅读