LLaMA-Factory添加adalora

感谢https://github.com/tsingcoo/LLaMA-Efficient-Tuning/commit/f3a532f56b4aa7d4200f24d93fade4b2c9042736https://github.com/huggingface/peft/issues/432的帮助。

在LLaMA-Factory中添加adalora

1. 修改src/llmtuner/hparams/finetuning_args.py代码
在FinetuningArguments中修改finetuning_type,添加target_r和init_r
在这里插入图片描述
修改__post_init__函数
在这里插入图片描述

2. 修改src/llmtuner/tuner/core/adapter.py代码
添加AdaLoraConfig
在这里插入图片描述
在init_adapter函数中添加一个if判断,添加位置在如红框所示:
在这里插入图片描述

    if finetuning_args.finetuning_type == "adalora":
        logger.info("Fine-tuning method: AdaLoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
                checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            if len(checkpoints_to_merge) > 0:
                logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None: # resume lora training or quantized inference
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)


        if is_trainable and latest_checkpoint is None: # create new lora weights while training
            if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
                target_modules = find_all_linear_modules(model, model_args.quantization_bit)
            else:
                target_modules = finetuning_args.lora_target
                
            lora_config = AdaLoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                target_r=finetuning_args.target_r,
                init_r=finetuning_args.init_r,
                r=finetuning_args.lora_rank,
                target_modules=target_modules,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
            )

            model = get_peft_model(model, lora_config)
            if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
                model.base_model.peft_config = model.peft_config

3. 修改src/llmtuner/tuner/core/parser.py的代码
这边建议所有有关finetuning_args.finetuning_type==/!= "lora"的都改成图片所示
在这里插入图片描述

修改transformer源码

按照上面的改完之后虽然可以训练,但是其实并没有实现adalora的秩的调整。

我是通过在update_and_allocate函数中设置断点发现模型训练没有调用update_and_allocate函数,update_and_allocate函数位于python3.10/site-packages/peft/tuners/adalora.py中。

1. 修改python3.10/site-packages/transformers/trainer.py代码

                    from peft import PeftModel
                    if isinstance(model, PeftModel):
                            if getattr(model.base_model, "update_and_allocate", None) is not None:
                                model.base_model.update_and_allocate(total_batched_samples)

把上面的代码复制到train函数中,具体的位置应该是整个文件的第二个model.zero_grad()上面,不同transformers的位置可能不一样
在这里插入图片描述
2. 设置adalora的总迭代次数
两个方法一个是在adaloraconfig定义的时候设定(我没试),另外一个就是一样修改train.py,如下:
在for epoch in range(epochs_trained, num_train_epochs):上面一行设置

        # 设置总迭代数
        model.base_model.peft_config[model.base_model.trainable_adapter_name].total_step = len(train_dataloader)

在这里插入图片描述

训练启动

在这里插入图片描述

相关推荐

  1. llama-factory简介

    2024-01-17 09:30:07       35 阅读
  2. LLaMA-Factory 微调训练

    2024-01-17 09:30:07       29 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-17 09:30:07       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-17 09:30:07       106 阅读
  3. 在Django里面运行非项目文件

    2024-01-17 09:30:07       87 阅读
  4. Python语言-面向对象

    2024-01-17 09:30:07       96 阅读

热门阅读

  1. iOS和安卓端个人踩坑史

    2024-01-17 09:30:07       57 阅读
  2. RabbitMQ

    RabbitMQ

    2024-01-17 09:30:07      47 阅读
  3. 【征服redis4】一文征服redis的Lettuce客户端

    2024-01-17 09:30:07       46 阅读
  4. Python3 如何做数据类型转换

    2024-01-17 09:30:07       58 阅读
  5. uniapp 实现tabBar-switchTab之间的传参

    2024-01-17 09:30:07       61 阅读
  6. webpack打包机制,构建过程和配置

    2024-01-17 09:30:07       64 阅读
  7. .NET gRPC

    2024-01-17 09:30:07       61 阅读
  8. 设计模式——模板方法模式

    2024-01-17 09:30:07       62 阅读
  9. FPGA的电平标准

    2024-01-17 09:30:07       61 阅读
  10. Hive数据导出的四种方法

    2024-01-17 09:30:07       60 阅读
  11. 贪心算法part04 算法

    2024-01-17 09:30:07       59 阅读