Day 20 of the MindSpore 25-Day Learning Camp | GPT2 Text Summarization with MindSpore

GPT2 Text Summarization with MindSpore

Introduction to GPT2 Text Summarization

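GPT-2 (Generative Pre-trained Transformer 2) is a pre-trained language model based on the Transformer architecture, developed by OpenAI. As the successor to GPT-1, it improves performance on natural language processing tasks mainly through scale: a much larger training corpus and a bigger model (up to 1.5 billion parameters in its largest version).

One of GPT-2's main uses is text generation, and text summarization can be treated as a generation task. Text summarization means extracting the key information from an input text and producing a concise summary or overview. GPT-2 gets this ability from the language model it learns during pre-training: it captures the structure and content of natural language and can generate a plausible summary conditioned on the context of the input text. (Note that in this tutorial the model is not loaded from a pre-trained checkpoint; it is initialized from a GPT2Config and trained from scratch on the summarization data.)

When GPT-2 is used for summarization, the model typically receives a long text or article and generates a shorter summary that captures the main content and key points of the original. This is valuable for information retrieval, automated writing, and content generation, because it helps users quickly understand and process large volumes of text.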
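The idea of summarization as generation can be illustrated with a toy example. GPT-2 has no dedicated summarization head: the article is given as a prompt and the summary is produced token by token until an end marker appears. The sketch below is a simplification under stated assumptions: toy_next_token is a hypothetical stand-in for the model (a hard-coded lookup on the last token, whereas GPT-2 conditions on the whole prefix), decoding is greedy rather than the beam search used later, and [SEP] is assumed to be the end marker.

```python
from typing import Callable, List

# Hypothetical stand-in for a trained language model: it looks only at the
# most recent token. A real GPT-2 conditions on the entire prefix.
TOY_TABLE = {"[SEP]": "A", "A": "B", "B": "[SEP]"}


def toy_next_token(prefix: List[str]) -> str:
    return TOY_TABLE.get(prefix[-1], "[SEP]")


def greedy_summarize(article: List[str],
                     next_token: Callable[[List[str]], str],
                     end_token: str = "[SEP]",
                     max_new_tokens: int = 10) -> List[str]:
    """Feed the article as a prompt, then append tokens until end_token."""
    sequence = ["[CLS]"] + article + ["[SEP]"]  # prompt: [CLS] article [SEP]
    summary: List[str] = []
    for _ in range(max_new_tokens):
        token = next_token(sequence)  # greedy: take the single predicted token
        if token == end_token:
            break
        sequence.append(token)
        summary.append(token)
    return summary


print(greedy_summarize(["x", "y", "z"], toy_next_token))
```

The same loop structure (prompt in, one token at a time out, stop at an end marker) is what model.generate performs at inference time, with beam search replacing the greedy choice.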

GPT2 in Practice

Environment

Python: 3.9.19

Installing dependencies

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

pip install tokenizers==0.15.0 -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install mindnlp==0.3.1

Full list of installed Python packages (pip list output)

pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
addict                         2.4.0
aiofiles                       22.1.0
aiohttp                        3.9.5
aiosignal                      1.3.1
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
async-timeout                  4.0.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
datasets                       2.20.0
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
evaluate                       0.4.2
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
frozenlist                     1.4.1
fsspec                         2024.5.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
hypothesis                     6.105.1
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
jieba                          0.42.1
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindnlp                        0.3.1
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
ml-dtypes                      0.4.0
mpmath                         1.3.0
msadvisor                      1.0.0
multidict                      6.0.5
multiprocess                   0.70.16
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pyarrow                        16.1.0
pyarrow-hotfix                 0.6
pycodestyle                    2.6.0
pycparser                      2.22
pyctcdecode                    0.5.0
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pygtrie                        2.5.0
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         7.2.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
regex                          2024.5.15
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
safetensors                    0.4.3
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
sentencepiece                  0.2.0
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
sortedcontainers               2.4.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
tokenizers                     0.15.0
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
xxhash                         3.4.1
y-py                           0.6.2
yapf                           0.40.2
yarl                           1.9.4
ypy-websocket                  0.8.4
zipp                           3.17.0

Code

1. Dataset loading

from mindnlp.utils import http_get

# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')

from mindspore.dataset import TextFileDataset

# load dataset
# dataset = TextFileDataset(str(path), shuffle=False)  # the full dataset runs too slowly, so only a 5000-sample subset is used here
dataset = TextFileDataset(str(path), shuffle=True, num_samples=5000)
# dataset.get_dataset_size()

# split into training and testing dataset
train_dataset, test_dataset = dataset.split([0.9, 0.1], randomize=False)

2. Data preprocessing

    Original data format:

    article: [CLS] article_context [SEP]
    summary: [CLS] summary_context [SEP]

    Format after preprocessing (article and summary merged into one sequence):

    [CLS] article_context [SEP] summary_context [SEP]
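The merge can be sketched on plain token-ID lists. This is a simplified re-implementation for illustration, not the tokenizer's code: it imitates what padding='max_length' and truncation='only_first' do in the tokenizer(...) call inside process_dataset, and the special-token IDs are assumptions (they follow the usual BERT vocabulary layout).

```python
from typing import List

# Assumed special-token IDs (usual BERT vocabulary layout); illustrative only.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def merge_and_pad(article_ids: List[int], summary_ids: List[int], max_len: int) -> List[int]:
    """Build [CLS] article [SEP] summary [SEP], truncating only the article,
    then right-pad with PAD_ID up to max_len."""
    special = 3  # one [CLS] and two [SEP]
    room_for_article = max_len - special - len(summary_ids)
    if room_for_article < 0:
        raise ValueError("the summary alone does not fit into max_len")
    article_ids = article_ids[:room_for_article]  # 'only_first': cut the article's tail
    ids = [CLS_ID] + article_ids + [SEP_ID] + summary_ids + [SEP_ID]
    return ids + [PAD_ID] * (max_len - len(ids))


print(merge_and_pad([11, 12, 13, 14, 15], [21, 22], max_len=8))
```

Only the article is shortened, so the summary, which is the part the model has to learn to produce, always survives truncation.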
 

import json
import numpy as np

# preprocess dataset
def process_dataset(dataset, tokenizer, batch_size=6, max_seq_len=1024, shuffle=False):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

    def merge_and_pad(article, summary):
        # tokenization
        # pad to max_seq_length, only truncate the article
        tokenized = tokenizer(text=article, text_pair=summary,
                              padding='max_length', truncation='only_first', max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['input_ids']
    
    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    # change column names to input_ids and labels for the following training
    dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])

    dataset = dataset.batch(batch_size)
    if shuffle:
        dataset = dataset.shuffle(batch_size)

    return dataset
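read_map relies on two details that are easy to miss: each line of train_with_summ.txt is expected to be a JSON object with article and summarization keys, and the text column is decoded with json.loads(text.tobytes()). The sketch below reproduces that step outside the pipeline, assuming the column arrives as a NumPy string array; the sample record is made up. NumPy stores such strings as fixed-width UTF-32, and json.loads accepts bytes and detects UTF-8, UTF-16 or UTF-32 automatically, which is why the round trip works.

```python
import json

import numpy as np

# Made-up record in the assumed file format: one JSON object per line.
line = json.dumps({"article": "今天天气很好,我们去公园散步。", "summarization": "去公园散步"},
                  ensure_ascii=False)

# Assumption: the 'text' column reaches read_map as a NumPy unicode array.
text = np.array(line)
raw = text.tobytes()  # NumPy unicode arrays are stored as UTF-32 code units

# json.loads detects the UTF-32 encoding of the bytes and decodes them.
data = json.loads(raw)
article, summary = np.array(data["article"]), np.array(data["summarization"])
print(article, summary)
```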


# GPT2 has no Chinese tokenizer, so we use BertTokenizer instead.
from mindnlp.transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
len(tokenizer)  # vocabulary size (21128 for bert-base-chinese), reused as GPT2Config.vocab_size

train_dataset = process_dataset(train_dataset, tokenizer, batch_size=4)

next(train_dataset.create_tuple_iterator())  # inspect one batch: (input_ids, labels)

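Model Construction

  1. Build the GPT2ForSummarization model. Note the shift-right handling: the logits at position t are scored against the label at position t+1.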

from mindspore import ops
from mindnlp.transformers import GPT2LMHeadModel

class GPT2ForSummarization(GPT2LMHeadModel):
    def construct(
        self,
        input_ids = None,
        attention_mask = None,
        labels = None,
    ):
        outputs = super().construct(input_ids=input_ids, attention_mask=attention_mask)
        # Shift so that the logits at position t are compared with the token at t+1
        shift_logits = outputs.logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        # Flatten the tokens; padding positions are excluded from the loss via ignore_index
        loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)
        return loss
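The shifted loss in construct can be checked in NumPy. This sketch mirrors the slicing and the padding mask, assuming ops.cross_entropy with its default mean reduction averages over the non-ignored tokens only; it is an illustration, not MindSpore's implementation.

```python
import numpy as np


def shifted_lm_loss(logits: np.ndarray, labels: np.ndarray, pad_id: int) -> float:
    """Causal LM loss: logits at position t are scored against the token at t+1.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len) integer IDs.
    """
    shift_logits = logits[:, :-1, :]  # the last position has nothing to predict
    shift_labels = labels[:, 1:]      # the first token is never predicted
    flat_logits = shift_logits.reshape(-1, shift_logits.shape[-1])
    flat_labels = shift_labels.reshape(-1)
    # Numerically stable log-softmax over the vocabulary
    m = flat_logits.max(axis=-1, keepdims=True)
    log_probs = flat_logits - m - np.log(np.exp(flat_logits - m).sum(axis=-1, keepdims=True))
    token_nll = -log_probs[np.arange(flat_labels.size), flat_labels]
    mask = flat_labels != pad_id      # like ignore_index: padding adds nothing
    return float(token_nll[mask].mean())


# Uniform logits over a 6-token vocabulary give a loss of log(6)
print(shifted_lm_loss(np.zeros((1, 4, 6)), np.array([[5, 1, 2, 0]]), pad_id=0))
```

If every label in a batch were padding, the mean would be taken over an empty array (NaN with a warning); with real batches, which always contain a summary, this does not arise.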

# Dynamic learning rate: linear warmup, then linear decay to 0
from mindspore import ops
from mindspore.nn.learning_rate_schedule import LearningRateSchedule

class LinearWithWarmUp(LearningRateSchedule):
    """
    Warmup-decay learning rate.
    """
    def __init__(self, learning_rate, num_warmup_steps, num_training_steps):
        super().__init__()
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps

    def construct(self, global_step):
        if global_step < self.num_warmup_steps:
            return global_step / float(max(1, self.num_warmup_steps)) * self.learning_rate
        return ops.maximum(
            0.0, (self.num_training_steps - global_step) / (max(1, self.num_training_steps - self.num_warmup_steps))
        ) * self.learning_rate
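A plain-Python copy of the schedule makes it easy to inspect the learning-rate curve, and with this tutorial's settings the numbers are worth checking. Assuming the split yields 4,500 training samples (90% of 5,000), batches of 4, and a global step that starts at 0, one epoch is 1,125 steps, which is fewer than warmup_steps = 2000. The schedule then never leaves the warmup phase, and the learning rate peaks at about 1124/2000 × 1.5e-4 ≈ 8.4e-5 rather than 1.5e-4.

```python
def linear_with_warmup(step: int, learning_rate: float, num_warmup_steps: int,
                       num_training_steps: int) -> float:
    """Plain-Python mirror of LinearWithWarmUp.construct."""
    if step < num_warmup_steps:
        # Linear ramp from 0 up to learning_rate
        return step / float(max(1, num_warmup_steps)) * learning_rate
    # Linear decay from learning_rate down to 0, clamped at 0
    decay = (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, decay) * learning_rate


# Tutorial values; 1125 steps assumes 4500 training samples in batches of 4.
lr, warmup, total = 1.5e-4, 2000, 1125
peak = max(linear_with_warmup(s, lr, warmup, total) for s in range(total))
print(f"peak learning rate: {peak:.3e}")
```

To reach the intended peak within one epoch, warmup_steps would need to be smaller than the number of training steps (or num_epochs larger).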

Model Training

num_epochs = 1
warmup_steps = 2000
learning_rate = 1.5e-4

num_training_steps = num_epochs * train_dataset.get_dataset_size()

from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)

lr_scheduler = LinearWithWarmUp(learning_rate=learning_rate, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps)
optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler)

# Print the number of model parameters
print('number of model parameters: {}'.format(model.num_parameters()))
# Output: number of model parameters: 102068736
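The printed count can be reproduced by hand, which confirms how the model is configured. The sketch assumes GPT2Config defaults to the GPT-2 small sizes (12 layers, hidden size 768, 1024 positions, MLP width 4 × 768), a vocabulary of 21,128 tokens from bert-base-chinese, and an LM head that shares its weight with the token embedding (so it adds no parameters).

```python
def gpt2_param_count(vocab_size: int, n_positions: int = 1024, n_embd: int = 768,
                     n_layer: int = 12) -> int:
    """Count GPT-2 parameters, assuming the LM head is tied to the token embedding."""
    d = n_embd
    embeddings = vocab_size * d + n_positions * d   # wte + wpe
    layer_norm = 2 * d                              # gain + bias
    attention = (d * 3 * d + 3 * d) + (d * d + d)   # c_attn (q, k, v) + c_proj
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)     # c_fc + c_proj
    per_block = 2 * layer_norm + attention + mlp    # ln_1 + ln_2 + attention + MLP
    return embeddings + n_layer * per_block + layer_norm  # + final ln_f


print(gpt2_param_count(vocab_size=21128))
```

With the original GPT-2 vocabulary (50,257 tokens) the same formula gives 124,439,808, the commonly quoted size of GPT-2 small; the smaller Chinese vocabulary accounts for the difference.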
from mindnlp._legacy.engine import Trainer
from mindnlp._legacy.engine.callbacks import CheckpointCallback

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt2_summarization',
                                epochs=1, keep_checkpoint_max=2)

trainer = Trainer(network=model, train_dataset=train_dataset,
                  epochs=1, optimizer=optimizer, callbacks=ckpoint_cb)
trainer.set_amp(level='O1')  # enable mixed precision

Note: training takes a long time, so a higher-spec compute instance is recommended.

trainer.run(tgt_columns="labels")

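Model Inference

Process the test data: tokenize the Chinese articles into token IDs (the generated IDs are decoded back into Chinese text afterwards).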

def process_test_dataset(dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):
    def read_map(text):
        data = json.loads(text.tobytes())
        return np.array(data['article']), np.array(data['summarization'])

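    def pad(article):
        # Tokenize only the article ([CLS] article [SEP]) and truncate it so that
        # max_summary_len positions stay free for the generated summary
        tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)
        return tokenized['input_ids']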

    dataset = dataset.map(read_map, 'text', ['article', 'summary'])
    dataset = dataset.map(pad, 'article', ['input_ids'])
    
    dataset = dataset.batch(batch_size)

    return dataset
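The truncation length in process_test_dataset is a token budget. Assuming max_seq_len=1024 matches GPT-2's default number of position embeddings (n_positions), the article is cut to 1024 - 100 = 924 tokens so that the prompt plus the generated summary fits in the context window. A small sketch of the arithmetic, with defaults copied from this tutorial (max_new_tokens=50 is the value passed to generate during inference); generation_budget is a hypothetical helper.

```python
def generation_budget(max_seq_len: int = 1024, max_summary_len: int = 100,
                      max_new_tokens: int = 50) -> dict:
    """Token budget at inference time (defaults mirror this tutorial's values)."""
    max_prompt_len = max_seq_len - max_summary_len  # article is truncated to this length
    worst_case_total = max_prompt_len + max_new_tokens
    return {"max_prompt_len": max_prompt_len,
            "worst_case_total": worst_case_total,
            "fits": worst_case_total <= max_seq_len}


print(generation_budget())
```

Since max_new_tokens (50) is below the reserved max_summary_len (100), there is slack; raising max_new_tokens above 100 would require a larger reservation.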


test_dataset = process_test_dataset(test_dataset, tokenizer, batch_size=1)

print(next(test_dataset.create_tuple_iterator(output_numpy=True)))

model = GPT2LMHeadModel.from_pretrained('./checkpoint/gpt2_summarization_epoch_0.ckpt', config=config)

model.set_train(False)
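# Stop at [SEP], which ends the summary in the training sequences. GPT2Config leaves
# sep_token_id unset (None) by default, so take the ID from the tokenizer instead.
model.config.eos_token_id = tokenizer.sep_token_id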
# Generate a summary for the first test sample only
for input_ids, raw_summary in test_dataset.create_tuple_iterator():
    output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)
    # output_ids[0] holds the article prompt followed by the generated summary
    output_text = tokenizer.decode(output_ids[0].tolist())
    print(output_text)
    break
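Because GPT-2 is decoder-only, generate returns the prompt followed by the new tokens (assuming mindnlp follows the Hugging Face convention here), so the printed text includes the article. A hedged sketch of keeping only the summary part: slice off the prompt, stop at the first [SEP], and drop padding. The special-token IDs are assumed values and extract_summary_ids is a hypothetical helper; its result could then be passed to tokenizer.decode.

```python
from typing import List

# Assumed special-token IDs (usual BERT vocabulary layout); illustrative only.
SEP_ID, PAD_ID = 102, 0


def extract_summary_ids(output_ids: List[int], prompt_len: int) -> List[int]:
    """Keep the continuation after the prompt, up to the first [SEP], without padding."""
    generated = output_ids[prompt_len:]
    if SEP_ID in generated:
        generated = generated[:generated.index(SEP_ID)]
    return [t for t in generated if t != PAD_ID]


# Prompt: [CLS] 11 12 [SEP]; continuation: 21 22 [SEP] [PAD]
print(extract_summary_ids([101, 11, 12, 102, 21, 22, 102, 0], prompt_len=4))
```

In the loop above, prompt_len would be input_ids.shape[1], and the list would come from output_ids[0].tolist().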
