【大模型】微调实战—使用 ORPO 微调 Llama 3

ORPO 是一种新颖微调(fine-tuning)技术,它将传统的监督微调(supervised fine-tuning)和偏好对齐(preference alignment)阶段合并为一个过程。这减少了训练所需的计算资源和时间。此外,实证结果表明,ORPO 在各种模型规模和基准测试(benchmarks)上优于其他对齐方法。
在本文中,我们将使用 ORPO 和 TRL 库对新的 Llama 3 8B 模型进行微调。

ORPO

指令微调(instruction tuning)和偏好对齐(preference alignment)是使LLM适应特定任务的基本技术。传统上,这涉及一个多阶段的过程:1/ 在指令上进行监督微调(Supervised Fine-Tuning, SFT),以使模型适应目标领域,然后 2/ 使用偏好对齐方法,如基于人类反馈的强化学习(Reinforcement Learning with Human Feedback, RLHF)或直接偏好优化(Direct Preference Optimization, DPO),以增加生成首选响应而非被拒绝响应的可能性。
在这里插入图片描述

然而,研究人员发现了这种方法的局限性。虽然 SFT 有效地使模型适应所需的领域,但它无意中增加了在首选答案的同时生成不需要的答案的可能性。这就是为什么偏好调整阶段对于扩大首选输出和拒绝输出的可能性之间的差距是必要的。
ORPO 由 Hong 和 Lee (2024) 提出,通过将指令调整和偏好对齐结合到一个单一的整体训练过程中,为这个问题提供了一个优雅的解决方案。 ORPO 修改了标准语言建模目标,将负对数似然损失与优势比 (OR) 项相结合。这种 OR 损失对被拒绝的响应进行弱惩罚,同时对首选响应进行强烈奖励,从而使模型能够同时学习目标任务并与人类偏好保持一致。
在这里插入图片描述
ORPO 已在主要微调库中实现,如 TRL、Axolotl 和 LLaMA-Factory。在下一节中,我们将了解如何与 TRL 一起使用。

使用 ORPO 微调 Llama 3

Llama 3 是Meta开发的最新大型语言模型(LLM)家族。该模型在一个包含15万亿个标记的数据集上进行了训练(相比之下,Llama 2 的训练数据集为2万亿个标记)。目前已经发布了两种模型尺寸:一个是拥有70B参数的模型,另一个是较小的8B参数模型。70B参数的模型已经展示了令人印象深刻的性能,在MMLU基准测试中得分为82,在HumanEval基准测试中得分为81.7。
Llama 3 模型还将上下文长度增加到了8,192个标记(相比之下,Llama 2 为4,096个标记),并且有可能通过RoPE扩展到32k。此外,这些模型使用了一种新的分词器,具有128K标记的词汇量,从而减少了编码文本所需的标记数量15%。这种词汇量的增加也解释了参数从70亿增加到80亿。
ORPO 需要一个偏好数据集,包括提示、选择的答案和拒绝的答案。在此示例中,我们将使用 mlabonne/orpo-dpo-mix-40k ,它是以下高质量 DPO 数据集的组合:

首先安装所需的库:

pip install -U transformers datasets accelerate peft trl bitsandbytes wandb

安装完成后,我们可以导入必要的库并登录W&B(可选)

import gc
import os

import torch
import wandb
from datasets import load_dataset
# from google.colab import userdata
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
)
from trl import ORPOConfig, ORPOTrainer, setup_chat_format

# wb_token = userdata.get('wandb')
# wandb.login(key=wb_token)

如果您有最新的 GPU,还应该能够使用 Flash Attention 库将默认的 eager Attention 实现替换为更高效的实现。

if torch.cuda.get_device_capability()[0] >= 8:
    #!pip install -qqq flash-attn
    attn_implementation = "flash_attention_2"
    torch_dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    torch_dtype = torch.float16

接下来,我们将借助bitsandbytes 以 4 位精度加载 Llama 3 8B 模型。然后,我们使用 QLoRA 的 PEFT 设置 LoRA 配置。我还使用方便的 setup_chat_format() 函数来修改模型和标记生成器以支持 ChatML。它会自动应用此聊天模板,添加特殊标记,并调整模型嵌入层的大小以匹配新的词汇表大小。
请注意,您需要提交访问 meta-llama/Meta-Llama-3-8B 的请求并登录您的 Hugging Face 帐户。或者,您可以加载模型的非门控副本,例如 NousResearch/Meta–Llama-3-8B。(我选择手动从NousResearch/Meta–Llama-3-8B下载)

# Model
base_model = "meta-llama/Meta-Llama-3-8B"
new_model = "OrpoLlama-3-8B"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = prepare_model_for_kbit_training(model)

现在模型已准备好进行训练,我们可以处理数据集了。我们加载 mlabonne/orpo-dpo-mix-40k 并使用 apply_chat_template() 函数将“chosen”和“rejected”列转换为 ChatML 格式。请注意,我仅使用 1,00 个样本,而不是整个数据集,因为运行时间太长。(我选择手动下载)

dataset_name = "mlabonne/orpo-dpo-mix-40k"
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=42).select(range(100))

def format_chat_template(row):
    row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc= os.cpu_count(),
)
dataset = dataset.train_test_split(test_size=0.01)

首先,我们需要设置一些超参数: * learning_rate :与传统的 SFT 甚至 DPO 相比,ORPO 使用非常低的学习率。 8e-6这个值来自原始论文,大致对应于SFT学习率1e-5和DPO学习率5e-6。我建议将其增加到 1e-6 左右以进行真正的微调。 * beta :即论文中的 𝜆 参数,默认值为0.1。原始论文的附录显示了如何通过消融研究选择它。 * 其他参数,如 max_length 和批量大小设置为使用尽可能多的可用 VRAM(此配置中约为 20 GB)。理想情况下,我们会训练模型 3-5 个 epoch,但这里我们坚持使用 1 个 epoch。
最后,我们可以使用 ORPOTrainer 来训练模型,它充当包装器。

orpo_args = ORPOConfig(
    learning_rate=8e-6,
    beta=0.1,
    lr_scheduler_type="linear",
    max_length=1024,
    max_prompt_length=512,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    report_to="wandb",
    output_dir="./results/",
)

trainer = ORPOTrainer(
    model=model,
    args=orpo_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    tokenizer=tokenizer,
)
trainer.train()
trainer.save_model(new_model)

中间需要选择是否使用W&B,不会使用,我选择不使用
在这里插入图片描述
完成了 Llama 3 的快速微调:mlabonne/OrpoLlama-3-8B
在这里插入图片描述

生成目录:
在这里插入图片描述

合并完整模型到本地:

# Flush memory
del trainer, model
gc.collect()
torch.cuda.empty_cache()

# Reload tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
model, tokenizer = setup_chat_format(model, tokenizer)

# Merge adapter with base model
model = PeftModel.from_pretrained(model, new_model)
model = model.merge_and_unload()

# Save the merged model and tokenizer to local directory
local_save_directory = "new_model"
model.save_pretrained(local_save_directory)
tokenizer.save_pretrained(local_save_directory)

得到和初始模型一样结构的微调模型;
在这里插入图片描述
完整教程:https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html
本文使用代码对原代码改了一部分。

相关推荐

  1. 使用 LLaMA Factory 微调 Llama-3 中文对话模型

    2024-07-10 12:50:04       21 阅读
  2. 使用 torchtune 微调 Llama3

    2024-07-10 12:50:04       19 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-10 12:50:04       4 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-10 12:50:04       5 阅读
  3. 在Django里面运行非项目文件

    2024-07-10 12:50:04       4 阅读
  4. Python语言-面向对象

    2024-07-10 12:50:04       5 阅读

热门阅读

  1. 小程序的制作费用很贵么

    2024-07-10 12:50:04       8 阅读
  2. c#实现23种常见的设计模式--动态更新

    2024-07-10 12:50:04       6 阅读
  3. 银河麒麟(V10SP1)-arm版交叉编译-qt-5.12.12源码

    2024-07-10 12:50:04       7 阅读
  4. 华为机考真题 -- 游戏分组

    2024-07-10 12:50:04       10 阅读
  5. Linux 期末速成(知识点+例题)

    2024-07-10 12:50:04       10 阅读
  6. 【基础篇】1.8 C语言基础(二)

    2024-07-10 12:50:04       8 阅读
  7. element ui form添加校验规则

    2024-07-10 12:50:04       8 阅读
  8. splice方法的使用#Vue3

    2024-07-10 12:50:04       9 阅读
  9. 使用Dockerfile和ENTRYPOINT运行Python 3脚本

    2024-07-10 12:50:04       9 阅读
  10. 黑龙江等保测评对中小企业成本效益分析

    2024-07-10 12:50:04       9 阅读