【机器学习300问】121、RNN是如何生成文本的?

        当RNN模型训练好后,如何让他生成一个句子?其实就是一个RNN前向传播的过程。通常遵循以下的步骤。

(1)初始化

        文本生成可以什么都不给,让他生成一首诗。首先,你需要确定采样的起始点。这可以是一个特殊的开始标记<START>或者是一个随机选取的词汇索引作为第一个时间步的输入。如果是基于字符的模型,则可能从一个特殊字符或空格开始。

如果什么都不输入,那么a^{<0>}=0x^{<1>}=0

(2)前向传播

        将起始输入送入RNN模型,得到第一个时间步的隐藏状态。对于之后的每个时间步,使用上一时间步的隐藏状态和当前输入(上一时间步模型预测的词或字符的索引)来计算新的隐藏状态,并得到下一个词的概率分布。

第一个时间步得到的输出(吴恩达老师手写)
第一个时间步(吴恩达老师手写)

        模型会得到一个概率分布,在这个分布上采样以预测下一个token。通常会使用softmax函数输出每个可能token的概率。例如,有10000个token的词典,那么得到的就是每一个token的概率。

(3)采样

        根据当前时间步的词概率分布进行采样,以决定下一个词。贪婪采样为例,每一步都选择概率最高的词作为下一个词,也就是y^{<1>}

举例说明一下,比如我们的RNN模型在一个给定时间步产生了以下5个token及其对应概率

Token Probability
the 0.4
cat 0.25
sat 0.15
on 0.1
mat 0.1

        'the'具有最高的概率0.4。因此,根据贪婪采样策略,我们会选择'the'作为下一个词。 

(4)更新序列

        将采样出的token加入到输入序列的末端。如果模型使用固定长度的序列,则需要将序列的第一个token去掉,以确保长度保持不变。

(5)重复采样

        反复执行步骤2至步骤4,逐步生成新的tokens,将它们加入到序列中。继续这个过程直到达到句子结束标记或达到预定最大序列长度。

(6)终止采样

        设定一个终止条件,比如:达到预定的最大序列长度;遇到结束标记(如<EOS>);基于某种规则判断生成完成(如遇到句号、问号等)。

最近更新

  1. TCP协议是安全的吗?

    2024-06-18 03:18:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-18 03:18:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-18 03:18:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-18 03:18:02       18 阅读

热门阅读

  1. NDS域名解析服务

    2024-06-18 03:18:02       4 阅读
  2. 国际化项目开发中关于时间的问题二

    2024-06-18 03:18:02       6 阅读
  3. Linux知识汇总

    2024-06-18 03:18:02       6 阅读
  4. Flink集群运行模式

    2024-06-18 03:18:02       8 阅读
  5. 617作业

    617作业

    2024-06-18 03:18:02      7 阅读
  6. k8s_DaemonSet和Deployment区别

    2024-06-18 03:18:02       10 阅读
  7. 细说MCU定时器中断的实现方法

    2024-06-18 03:18:02       8 阅读
  8. webpack之HMR

    2024-06-18 03:18:02       6 阅读