pytorch诗词生成3--创建模型

先上代码:


import tensorflow as tf
from dataset import tokenizer

# 构建模型
model = tf.keras.Sequential([
    # 不定长度的输入
    tf.keras.layers.Input((None,)),
    # 词嵌入层
    tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),
    # 第一个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 第二个LSTM层,返回序列作为下一层的输入
    tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),
    # 对每一个时间点的输出都做softmax,预测下一个词的概率
    tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),
])

# 查看模型结构
model.summary()
# 配置优化器和损失函数
model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

下面我们进行分析:

首先我们导入所需要的库,我们经常用到tensorflow中的keras库。

我们直接创建模型。首先看第一个参数:

tf.keras.layers.Input((None,)),

这是Keras中用于定义模型输入的一种方式,在这段代码中tf.keras.Input层被用于定义模型的输入,(None,)是一个元组,用于定义输入的形状,具体来说,(None,)表示输入是一个一维向量,其长度可以是可变的。

None表示在这个维度上可以接受任意长度的输入数据,在模型训练和推理时,可以将不同长度的输入传递给模型,而不需要固定长度的限制。

,表示这个元组只有一个元素,在这个例子中,这个元组只有一个维度,即序列的长度。
总的来说,tf.keras.layers.Input((None,))表示创建了一个输入层,该层接受一个不定长度的一维输入序列,这种方式可以灵活的处理不同长度的序列数据。

接着我们来看第二个参数:

tf.keras.layers.Embedding(input_dim=tokenizer.vocab_size, output_dim=128),

这是keras中的嵌入层,用于将离散的符号表示(例如整数序列)转化为连续的向量表示。

input_dim参数指定了嵌入层的输入维度,即词汇表的大小,在这个例子中,tokenizer.vocab_size表示词汇表的大小,即有多少个不同的符号需要进行嵌入。
output_dim参数指定了嵌入向量的维度,在这个例子中,output_dim=128表示嵌入层将输入的每个符号映射为一个128维的向量。

总的来说,这段代码用于创建一个嵌入层,将离散的符号序列转化为包含更多权重的向量,每个符号都用更多的权重表示,从而更好的表示符号之间的语义关系。我们的输出形状是input_dim.shape*output_dim.shape。

接下来我们看神经网络:

 tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),

首先我们需要明确一个概念,LSTM(长短期记忆)是一种循环网络(RNN)的变体,用于处理序列数据。LSTM层解决了传统的循环神经网络(RNN)在处理长期依赖关系时出现的梯度消失和梯度爆炸的问题。LSTM层的核心是LSTM单元,它包含一个记忆单元和一些门控机制。记忆单元可以存储和访问过去的信息并通过门控机制来控制信息的流动。

LSTM单元的主要组成部分如下:

输入门:决定是否将新的记忆合并到记忆单元中。
遗忘门:决定是否从记忆单元中忘记一些信息。
输出门:决定记忆单元中的信息如何输出到下一层或输出层。

LSTM层的工作原理如下:

对于每个时间步,LSTM层接受一个输入和前一时间步的隐藏状态,并通过一系列的操作产生一个新的隐藏状态作为输出。
输入和前一时间步的隐藏状态被用于计算输入门,遗忘门和输出门的激活值。
输入门决定哪些信息应该被添加到记忆单元中,遗忘门决定哪些信息应该从记忆单元中丢弃,输出门决定哪些信息应该输出到下一层或者输出层。
记忆单元根据输入门和遗忘门的激活值来更新其内容。
最后,根据输出门的激活值,记忆单元的内容被输入到下一层或者输出层,并成为当前时间步的隐藏状态。

通过使用输入门,遗忘门和输出门的门控机制,LSTM层可以有效的处理长期以来关系,同时减轻梯度消失和梯度爆炸问题,这使得LSTM在处理各种序列数据时,如自然语言处理任务中的文本,时间序列预测等方面非常有效。

有点偏题了,回到我们对代码的解析。

tf.keras.layer.LSTM函数是TensorFlow中keras API提供的LSTM层函数,用于构建LSTM(长短期记忆)模型。

我们的第一个参数表示的是LSTM层的输出维度,即隐藏状态维度,在这个例子中,LSTM层输出一个大小为128的隐藏状态。
dropout=0.5表示在训练过程中,每个时间步的输入都有50%的概率被设置为0,用于减少过拟合。
return_sequences=True表示将该层返回的完整序列作为下一层的输入,默认情况下,LSTM层只返回最后一个时间步的输入,但是在某些情况下,需要将LSTM层的每个时间步的输入都传递给下一层。

下面来详细解释一下:return_sequences=True 

默认情况下该值是False,表示只返回LSTM最后一个时间步的输出,某些情况下是有用的,比如序列分类任务,其中只需要一个固定长度的输出来预测。
然而,在一些任务中,需要用到LSTM层的完整输出序列,而不仅仅是最后一个时间步的输出,这些任务包括:

序列到序列的任务,例如机器翻译,其中输入和输出都是序列。
序列标注任务,如命名实体识别或语音识别,需要对序列的每个位置进行标注。
情感分析或情绪识别任务,需要对序列中每个时间步进行情感或情绪分析。

我们在写诗的时候要考虑到包含完整语句的信息,所以这里我们使用True,保留整个序列的信息。

综上,上面代码的作用是:
创建一个具有128个隐藏单元的LSTM层。该层将接受一个序列作为输入,并返回一个具有相同长度的序列作为输出。每个时间步的输入都会经过LSTM单元的计算,并产生一个隐藏状态作为输出,同时,通过dropout机制可以减少过拟合问题。

那么问题来了,为什么需要使用连续两个

tf.keras.layers.LSTM(128, dropout=0.5, return_sequences=True),

:连续使用两个该函数适用于构建一个深层的LSTM网络,深层的LSTM网络可以更好的捕捉到数据中的复杂模式和长期依赖关系,通过堆叠多个LSTM层,每一层都可以对输入序列进行进一步的抽象和表示学习,以提取高层次的信息。在连续使用两个LSTM层的情况下,第一个LSTM层的输出序列将成为第二个LSTM层的输入序列。这样第二个LSTM层可以进一步处理第一个LSTM层的输出,从而进行更深层次的特征学习。(这要求保留完整的序列输出,也就是return_sequences要求设置为True)。

通过堆叠多个LSTM层,可以逐渐提高模型的表示能力,并且在处理复杂序列数据时能提供更好的性能,然后,需要根据具体任务和数据集的特点进行实验和调整,以确定最佳的网络结构和层数。

看下一个:

tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')),

这是keras中的一个包装器层(wrapper)。用于将指定的层应用于输入序列的每个时间步。

在这段代码中,tf.keras.layers.TimeDistributed被用于包装一个tf.keras.layers.Dense层,并将其应用于输入序列的每个时间步。

  • tf.keras.layers.Dense是一个全链接层(密集)层,用于将输入数据映射到指定数量的输出单元,在这个例子中,Dense层的输出单元数量被设置为tokenizer.vocab_size,即词汇表的大小。
  • activation='softmax'指定了激活函数是softmax函数,用于将输出转化为概率分布,以便进行多分类问题的预测。
  • 通过使用tf.keras.layers.TimeDistributed包装Dense层,可以确保Dense层被用于输入序列的每个时间步,并生成相同时间步数的输出序列。

这在处理序列数据时非常有用,特别是在序列到序列的任务中,其中需要对每个时间步的输入进行独立的处理和预测,而不仅仅是最后一个时间步输出。

值得注意的是,tf.keras.layers.Dense(tokenizer.vocab_size, activation='softmax')创建了一个全连接层,作用是将输入数据映射到由tokenizer.vocab_size指定的输出单元数量。并指定该层的激活函数使用softmax函数。(用于处理多分类问题)。

tf.keras.layers.TimeDistributed是一个包装器,表示将层用于每一个时间步。
来看一个例子:如果我们有一个输入形状为(batch_size,timesteps,intput_dim)的3D张量,其中timesteps表示时间步的数量,input_dim是每个时间步的输入维度,将一个层包装在TimeDistributed中,就可以对输出序列的每个时间步应用该层,并生成一个形状相同的输出序列。

之后我们查看模型结构:

model.summary()

这段代码用于打印模型的结构摘要,显示每个层的名称,输出形状和参数数量。

如下图:

最后:

model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.categorical_crossentropy)

最后我们使用keras中的model.compile方法来配置模型的训练过程。

  • optimizer=tf.keras.optimizers.Adam()指定了优化器为Adam优化器。Adam是一种常用的优化算法,用于根据损失函数的梯度来更新模型的权重,它具有自适应学习率的特性,能够在训练过程中自动调整学习率,从而提高模型的收敛性和性能。
  • loss=tf.keras.losses.categorical_crossentropy指定了损失函数为分类交叉熵损失。

通过调用model.compile方法并传递优化器和损失函数,可以将它们与模型相关联,并在训练过程中使用它们来计算梯度和更新模型的权重。

一旦模型经过编译,就可以使用model.fit方法来训练模型,并根据指定的优化器和损失函数来更新模型的权重,以最小化损失并提高模型的性能。

相关推荐

  1. pytorch创建模型方式

    2024-03-14 09:54:03       29 阅读
  2. 3.创建模式--创建者模式

    2024-03-14 09:54:03       14 阅读
  3. 从零实现诗词GPT大模型:了解Transformer架构

    2024-03-14 09:54:03       14 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-03-14 09:54:03       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-14 09:54:03       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-14 09:54:03       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-14 09:54:03       20 阅读

热门阅读

  1. Redis实现全局唯一id

    2024-03-14 09:54:03       23 阅读
  2. Redisson

    2024-03-14 09:54:03       19 阅读
  3. Http 请求状态码

    2024-03-14 09:54:03       18 阅读
  4. 前端框架的发展史

    2024-03-14 09:54:03       19 阅读
  5. git命令行提交——github

    2024-03-14 09:54:03       23 阅读
  6. react diff 原理

    2024-03-14 09:54:03       21 阅读
  7. C语言下使用SQL语言

    2024-03-14 09:54:03       21 阅读
  8. 探索大语言模型(LLM):部分数据集介绍

    2024-03-14 09:54:03       22 阅读
  9. 同程旅行前端面试汇总

    2024-03-14 09:54:03       21 阅读
  10. 数据结构导航 -- 38篇

    2024-03-14 09:54:03       19 阅读
  11. gen_arrow_contour_xld

    2024-03-14 09:54:03       19 阅读
  12. wayland(xdg_wm_base) + egl + opengles 光照模型实例(十五)

    2024-03-14 09:54:03       23 阅读
  13. OMP实现MATLAB压缩感知实例

    2024-03-14 09:54:03       25 阅读
  14. vue中使用video.js,且可以截图、录制和下载视频

    2024-03-14 09:54:03       41 阅读