从零开始学ChatGLM2-6B 模型基于 P-Tuning v2 的微调

ChatGLM2-6B-PT

本项目实现了对于 ChatGLM2-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

下面以 ADGEN (广告生成) 数据集为例介绍代码的使用方法。

In [11]:

!pip install -r  /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt -i  https://pypi.tuna.tsinghua.edu.cn/simple/
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Requirement already satisfied: protobuf in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 1)) (5.26.1)
Requirement already satisfied: transformers==4.30.2 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (4.30.2)
Requirement already satisfied: cpm_kernels in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 3)) (1.0.11)
Requirement already satisfied: torch>=2.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (2.2.2)
Requirement already satisfied: gradio in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (3.40.0)
Requirement already satisfied: mdtex2html in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 6)) (1.3.0)
Requirement already satisfied: sentencepiece in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 7)) (0.2.0)
Requirement already satisfied: accelerate in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 8)) (0.28.0)
Requirement already satisfied: sse-starlette in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from -r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 9)) (2.0.0)
Requirement already satisfied: filelock in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (3.13.3)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.22.2)
Requirement already satisfied: numpy>=1.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (1.26.4)
Requirement already satisfied: packaging>=20.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (24.0)
Requirement already satisfied: pyyaml>=5.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (6.0.1)
Requirement already satisfied: regex!=2019.12.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (2023.12.25)
Requirement already satisfied: requests in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (2.31.0)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (0.4.2)
Requirement already satisfied: tqdm>=4.27 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from transformers==4.30.2->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 2)) (4.66.2)
Requirement already satisfied: typing-extensions>=4.8.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (4.10.0)
Requirement already satisfied: sympy in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (1.12)
Requirement already satisfied: networkx in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (3.2.1)
Requirement already satisfied: jinja2 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (3.1.3)
Requirement already satisfied: fsspec in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from torch>=2.0->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 4)) (2024.2.0)
...
Requirement already satisfied: referencing>=0.28.4 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (0.34.0)
Requirement already satisfied: rpds-py>=0.7.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jsonschema>=3.0->altair<6.0,>=4.2.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (0.18.0)
Requirement already satisfied: uc-micro-py in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (1.0.3)
Requirement already satisfied: six>=1.5 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from python-dateutil>=2.7->matplotlib~=3.0->gradio->-r /mnt/e/AI-lab/ChatGLM2-6B/requirements.txt (line 5)) (1.16.0)
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

In [13]:

# 运行微调除 ChatGLM2-6B 的依赖之外,还需要安装以下依赖
!pip install rouge_chinese nltk jieba datasets transformers[torch] -i https://pypi.douban.com/simple/
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Requirement already satisfied: rouge_chinese in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (1.0.3)
Requirement already satisfied: nltk in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (3.8.1)
Requirement already satisfied: jieba in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (0.42.1)
Requirement already satisfied: datasets in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (2.18.0)
Requirement already satisfied: transformers[torch] in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (4.30.2)
Requirement already satisfied: six in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from rouge_chinese) (1.16.0)
Requirement already satisfied: click in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (8.1.7)
Requirement already satisfied: joblib in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (1.3.2)
Requirement already satisfied: regex>=2021.8.3 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (2023.12.25)
Requirement already satisfied: tqdm in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from nltk) (4.66.2)
Requirement already satisfied: filelock in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.13.3)
Requirement already satisfied: numpy>=1.17 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (1.26.4)
Requirement already satisfied: pyarrow>=12.0.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (15.0.2)
Requirement already satisfied: pyarrow-hotfix in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.6)
Requirement already satisfied: dill<0.3.9,>=0.3.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.3.8)
Requirement already satisfied: pandas in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (2.2.1)
Requirement already satisfied: requests>=2.19.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (2.31.0)
Requirement already satisfied: xxhash in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.4.1)
Requirement already satisfied: multiprocess in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.70.16)
Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) (2024.2.0)
Requirement already satisfied: aiohttp in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (3.9.3)
Requirement already satisfied: huggingface-hub>=0.19.4 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (0.22.2)
Requirement already satisfied: packaging in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (24.0)
Requirement already satisfied: pyyaml>=5.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from datasets) (6.0.1)
...
Requirement already satisfied: pytz>=2020.1 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: tzdata>=2022.7 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from pandas->datasets) (2024.1)
Requirement already satisfied: MarkupSafe>=2.0 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from jinja2->torch!=1.12.0,>=1.9->transformers[torch]) (2.1.5)
Requirement already satisfied: mpmath>=0.19 in /home/ai001/anaconda3/envs/chatglm2-6b/lib/python3.9/site-packages (from sympy->torch!=1.12.0,>=1.9->transformers[torch]) (1.3.0)
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...

使用方法

下载数据集

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{    
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",    
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"    
}

从 Google Drive 或者 Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到 ptuning 目录下。

本项目中默认已经挂载了 ADGEN 数据集。

In [2]:

# 微调生成的 Checkpoint 文件较大,为避免占用 project 目录空间,我们将工作目录移到 temp 目录中进行后续工作

!cp -r /mnt/e/AI-lab/ChatGLM2-6B/ptuning /mnt/e/AI-lab/ChatGLM2-6B/temp

In [3]:

import os
 
# 设置你想要切换到的目录路径
new_dir = '/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning'
 
# 切换当前工作目录
os.chdir(new_dir)
 
# 打印当前工作目录以确认切换成功
print(os.getcwd())

/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning 

In [4]:

# 拷贝 ADGEN 数据集到工作目录
!cp -r /home/mw/input/adgen9371 AdvertiseGen

In [5]:

# 检查数据集
!ls -alh AdvertiseGen
total 52M
drwxrwxrwx 1 ai001 ai001 4.0K Apr  3 21:16 .
drwxrwxrwx 1 ai001 ai001 4.0K Apr  4 17:34 ..
-rwxrwxrwx 1 ai001 ai001 487K Apr  4 17:34 dev.json
-rwxrwxrwx 1 ai001 ai001  52M Apr  4 17:34 train.json

训练

P-Tuning v2

PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4per_device_train_batch_size=1gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

Finetune

如果需要进行全参数的 Finetune,需要安装 Deepspeed,然后运行以下指令:

bash ds_train_finetune.sh

我们以 P-tuning v2 方法为例,采取参数 quantization_bit=4per_device_train_batch_size=1gradient_accumulation_steps=16 进行微调训练

In [14]:

# P-tuning v2
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
PRE_SEQ_LEN=128
LR=2e-2
NUM_GPUS=1



!torchrun --standalone --nnodes=1 --nproc-per-node=1 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --preprocessing_num_workers 10 \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /mnt/e/AI-lab/ChatGLM2-6B/ \
    --output_dir output/adgen-chatglm2-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 128 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 2e-2 \
    --pre_seq_len 128 \
    --ddp_find_unused_parameters False
    #--quantization_bit 4
  

In [17]:

# 加载模型
model_path = "/mnt/e/ai-lab/ChatGLM2-6B"
from transformers import AutoTokenizer, AutoModel
from utils import load_model_on_gpus
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = load_model_on_gpus("/mnt/e/ai-lab/ChatGLM2-6B", num_gpus=2)
model = model.eval()
# 使用 Markdown 格式打印模型输出
from IPython.display import display, Markdown, clear_output

def display_answer(model, query, history=[]):
    for response, history in model.stream_chat(
            tokenizer, query, history=history):
        clear_output(wait=True)
        display(Markdown(response))
    return history

In [18]:

# 微调前
#model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
#model = model.eval()

display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")

上衣材质为牛仔布,颜色为白色,风格为简约,图案为刺绣,衣款式为外套,衣样式为破洞。

Out[18]:

[('类型#上衣\\*材质#牛仔布\\*颜色#白色\\*风格#简约\\*图案#刺绣\\*衣样式#外套\\*衣款式#破洞',
  '上衣材质为牛仔布,颜色为白色,风格为简约,图案为刺绣,衣款式为外套,衣样式为破洞。')]

In [22]:

# 微调后
import os
import torch
from transformers import AutoConfig
from transformers import AutoModel
from transformers import AutoTokenizer, AutoModel

os.environ['CUDA_VISIBLE_DEVICES'] = '1'

model_path = "/mnt/e/ai-lab/ChatGLM2-6B"
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join("/mnt/e/AI-lab/ChatGLM2-6B/temp/ptuning/output/adgen-chatglm2-6b-pt-128-0.02/checkpoint-1000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

# 使用 Markdown 格式打印模型输出
from IPython.display import display, Markdown, clear_output


# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
response, history = model.chat(tokenizer, "类型#上衣*颜色#黑白*风格#简约*风格#休闲*图案#条纹*衣样式#风衣*衣样式#外套", history=[])
print(response)
#display_answer(model, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞")

response, history = model.chat(tokenizer, "风衣有什么特征呢", history=[])
print(response)

response, history = model.chat(tokenizer, "日常休闲一般穿什么风格的衣服好呢?", history=[])
print(response)

Loading checkpoint shards: 100%

 7/7 [02:39<00:00, 23.82s/it]

Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /mnt/e/ai-lab/ChatGLM2-6B and are newly initialized: ['transformer.prefix_encoder.embedding.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Some weights of ChatGLMForConditionalGeneration were not initialized from the model checkpoint at /mnt/e/ai-lab/ChatGLM2-6B and are newly initialized: ['transformer.prefix_encoder.embedding.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
简约的条纹风衣,在黑白两色的搭配下,看起来非常的干练利落。经典的条纹元素,带来一种简约休闲的时尚感,将女性优雅的气质完美展现出来。
这款风衣是经典的风衣款式,采用优质的面料制作,质感舒适。在设计上,风衣采用经典的翻领设计,修饰颈部曲线,让你看起来更加优雅。风衣前襟采用斜线处理,整体看起来更加有设计感。
休闲风格是生活中不可或缺的,无论是在职场还是日常休闲,它都是一种很受欢迎的时尚元素。对于休闲的衣装来说,它一般都具有很亲和的气质,可以搭配出各种不同的风格。像这款休闲的连衣裙,它采用柔软的面料,穿着舒适亲肤,而且可以轻松的搭配出各种不同的风格。

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-10 01:06:02       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-10 01:06:02       106 阅读
  3. 在Django里面运行非项目文件

    2024-04-10 01:06:02       87 阅读
  4. Python语言-面向对象

    2024-04-10 01:06:02       96 阅读

热门阅读

  1. QT及C++中引用的用法和意义

    2024-04-10 01:06:02       31 阅读
  2. [ LeetCode ] 题刷刷(Python)-第70题:爬楼梯

    2024-04-10 01:06:02       37 阅读
  3. 大数据在医疗信息化中的应用

    2024-04-10 01:06:02       32 阅读
  4. 前端小白学习Vue2框架(一)

    2024-04-10 01:06:02       36 阅读
  5. 驾驭前端未来

    2024-04-10 01:06:02       33 阅读
  6. 大唐杯历届省赛押题训练(5)

    2024-04-10 01:06:02       31 阅读
  7. 【后端】OFD学习笔记

    2024-04-10 01:06:02       41 阅读
  8. 吴军《格局》对我的3点帮助

    2024-04-10 01:06:02       37 阅读
  9. 深入了解Linux: dbus-daemon系统总线的作用与管理

    2024-04-10 01:06:02       32 阅读
  10. Leetcode 165. 比较版本号

    2024-04-10 01:06:02       37 阅读