昇思25天学习打卡营第11天|文本解码原理-以MindNLP为例

文本解码原理-以MindNLP为例

这篇主要讲讲文本生成的几个方法,首先介绍一下什么是自回归语言模型。

自回归语言模型

autoregressive language model,根据前面的词或上下文,生成后续的词或句子的语言模型。
有几种典型的自回归语言模型:

1. 马尔科夫链
最简单的模型,假设当前词只依赖前一个或几个词。

2. ngram 模型
是马尔可夫的拓展,假设当前词依赖固定个数的前n个词。

3. 循环神经网络RNN
能捕捉序列中的长期依赖关系

4. 长短期记忆网络LSTM和门控循环单元GRU
RNN改进版本,解决了RNN的梯度消失和梯度爆炸的问题

在自回归模型中,一个文本序列的概率分布可以分解为每个词基于其上文的条件概率的乘积。

贪心法

每一步都选择概率最高的词作为当前的输出词。

from mindnlp.transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("iiBcai/gpt2", mirror='modelscope')
model = GPT2LMHeadModel.from_pretrained("iiBcai/gpt2", pad_token_id=tokenizer.eos_token_id, mirror='modelscope')

input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='ms')

greedy_output = model.generate(input_ids, max_length=50)

print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

解释一下代码构成:
首先import了两个模块

  • GPT2Tokenizer
    用于将文本转换为模型可处理的输入格式

  • GPT2LMHeadModel
    具体的GPT-2模型,用于生成文本。

接着加载tokenizer模型

使用tokenizer.encode()对指定文本进行编码,return_tensors指定了返回的张量格式。

使用model.generate()生成文本,max_length=50指输入和输出的最大token数为50.

最后使用tokenizer.decode()解码生成的文本。

Beam search

束搜索,这种方法会保留当前最可能的n个候选词,再根据这些词的得分选择概率最高的最佳序列。它可以一定程度的保留最优路径。


beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    early_stopping=True
)

beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    early_stopping=True
)

beam_outputs = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    num_return_sequences=5, 
    early_stopping=True
)

第一段,指定了束的宽度是5,即,在每步中保留5个候选序列。early_stopping=True表示当所有候选序列生成结束标记是,提前停止生成。

第二段,增加了 no_repeat_ngram_size=2 参数,指在生成过程中避免重复长度为2的n-gram(即连续两个token的组合)。
这里指的不是不包含2token的序列,而是在生成的序列内部不会包含重复的长度为2的n-gram。
如果在生成序列中已经包含了某个2token的序列,下面的生成不会再出现了。这可以增强生成文本的多样性。

第三段,增加了 num_return_sequences=5 参数,指生成并返回5个不同的候选序列。

缺点

1. 无法解决重复问题
beam search本质上是在每一步选择得分最高的候选项。因此在生成时容易重复使用得分比较高的词或者短语,尤其是没有明确停止条件或生成较长文本时。例如,可能会生成“this is a great great great idea“。no_repeat_ngram_size参数的引入就是为了减轻这一问题。

2. 开放域生成效果差
beam search在特定任务下表现比较好,因为目标的序列上下文更明确,歧义更少。但开放域时模型需要生成内容丰富的文本,而束搜索的过程更加保守。

Sample

根据当前的条件概率分布 随机选择输出词。这种方式的优点是文本生成的多样性高,而生成的文本更可能不连续。

mindspore.set_seed(0)
# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0
)

top_k=0指禁用Top-k采样。Top-k采样是指在每一步生成时仅从前k个最有可能的候选词中进行采样。设置top_k=0表示不限制候选词的数量,即从所有可能的词中进行采样。

Temperature

在文本生成任务中,temperature是一个控制生成文本时随机性的参数。具体来说,temperature调整softmax函数的输出概率分布以改变模型生成下一个词时的概率分布,从而影响生成文本的多样性和确定性。
temperature大于1时,概率分布会更平滑,也就是每个词之间的概率差距会减小,文本也会更随机多样。temperature小于1时,概率分布会变得更加陡峭,高概率的词变得更有优势,生成的文本更加确定和集中。

sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0,
    temperature=0.7
)

top k smaple

在生成下一个词时,仅从前 K 个概率最高的候选词中进行随机采样,而不是从整个词汇表中采样。也就是选出概率最大的k个词,重新归一化(使概率总和为1),最后在归一化后的k个词中进行随机采样。

mindspore.set_seed(0)
# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=50
)

这里的k值过小会导致文本缺乏多样性,k值过大可能引入不合理的词汇。

top-P sample

这种方式又称为核采样Nucleus Sampling,它会采样动态调整候选词的集合,使得生成的文本更加灵活和自然。具体的,top p sample基于累积概率的概念,在每一步生成下一个词时,它会选择那些累积概率达到或超过预定义阈值p的词,并从这些词中进行随机采样。

sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_p=0.92, 
    top_k=0
)

这个值需要根据具体任务来设定,一般在0.9-0.95之间。

总结

本篇介绍了在自回归语言模型下的几类文本生成算法,包括最简单的贪心法、有多种变体的sample、以及beam search束搜索。

打卡凭证

在这里插入图片描述

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 01:34:08       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-10 01:34:08       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 01:34:08       58 阅读
  4. Python语言-面向对象

    2024-07-10 01:34:08       69 阅读

热门阅读

  1. linux指令学习

    2024-07-10 01:34:08       23 阅读
  2. 钉钉消息异常通知

    2024-07-10 01:34:08       19 阅读
  3. python 学习

    2024-07-10 01:34:08       20 阅读
  4. 【Unix/Linux】Unix/Linux如何查看系统版本

    2024-07-10 01:34:08       19 阅读
  5. 【Unix/Linux】$bash-3.2是什么

    2024-07-10 01:34:08       21 阅读
  6. Win11系统vscode配置C语言环境

    2024-07-10 01:34:08       24 阅读
  7. Mojo有哪些优势和劣势

    2024-07-10 01:34:08       19 阅读
  8. 论文阅读:Large Language Models for Education: A Survey

    2024-07-10 01:34:08       25 阅读
  9. ARM汇编的基础语法

    2024-07-10 01:34:08       24 阅读
  10. postman

    postman

    2024-07-10 01:34:08      20 阅读
  11. Redis

    Redis

    2024-07-10 01:34:08      20 阅读
  12. [Linux安全运维] Linux命令相关

    2024-07-10 01:34:08       26 阅读
  13. PCL 点云最小外接球形包围盒

    2024-07-10 01:34:08       20 阅读
  14. Pytest单元测试系列[v1.0.0][高级技巧]

    2024-07-10 01:34:08       19 阅读