AI多模态教程:从0到1搭建VisualGLM图文大模型案例

在这里插入图片描述

大家好,我是刘一手,专注于CV算法和多模态大模型在类教育场景的实际应用。

今天给大家带来《AI多模态教程:从0到1搭建VisualGLM图文大模型案例》
——————————————————————————————————————

一、模型介绍

开源多模态模型:VisualGLM-6B 是一个开源的对话模型,具备处理中英文对话和图像的能力**。
参数规模:模型拥有高达78亿参数,提供强大的语言和视觉处理能力。
语言支持:专门设计用于中英文对话,通过BLIP2-Qformer技术实现视觉与语言的有效整合。
预训练数据集:在CogView数据集上进行预训练,包含30M中文和300M英文图文对,确保语言处理的均衡。
微调优化:经过微调,模型能够生成更加贴合人类偏好的对话回答。
训练工具库:利用SwissArmyTransformer库进行训练,支持模型的灵活修改和高效微调。
部署便捷性:通过模型量化技术,可以在消费级显卡上部署,最低仅需8GB显存。
应用场景:模型适用于图像描述和知识问答任务,展示其在视觉和语言理解的综合应用能力。

在这里插入图片描述

二、仓库结构

2.1 克隆Github仓库

克隆命令:

git clone https://github.com/THUDM/VisualGLM-6B.git

如果读者访问不了外网,也可以从下面的云盘下载(见文末)。

2.2 仓库结构

在这里插入图片描述

三、环境安装

我使用的环境:(推荐租用AutoDL或者恒源云的云服务器,显卡显存在24G及以上就可以)

系统:Ubuntu22.04
CUDA驱动版本:11.7
显卡显存:RTX 3090 Ti 24GB
Python版本:3.8

VisualGLM模型的环境中有几个非常重要的依赖包,对版本有要求,版本不同可能会有各种报错,经过反复测试,下面的各版本可以正常运行最新代码(代码更新为2024年3月)

SwissArmyTransformer0.4.5
bitsandbytes
0.39.0
transformers4.33.1
torch
1.13.1
torchvision==0.14.1

这里一手也给大家准备了完整环境的压缩包(本站蜘蛛链接见文末),使用方法:先在miniconda或者conda的envs目录新建一个文件夹,比如visualglm_env,然后进入这个文件夹内,把压缩包复制进去,直接解压就可以使用,免安装:

cd /home/miniconda3/envs
mkdir visualglm_env
tar -xzf visualglm_env.tar.gz -C visualglm_env

四、预训练权重下载

预训练模型是依靠来自于 CogView 数据集的30M高质量中文图文对,与300M经过筛选的英文图文对进行预训练,中英文权重相同。该训练方式较好地将视觉信息对齐到ChatGLM的语义空间;之后的微调阶段,模型在长视觉问答数据上训练,以生成符合人类偏好的答案。

官方提供了两种预训练模型:基于Huggingface的权重和基于 SwissArmyTransformer(简称sat) 的权重。
huggingface版预训练权重

区别:Huggingface可以基于命令行和网页进行推理,但不可以用于训练;sat权重可以基于命令行和网页进行推理,可以当前预训练权重进行微调。这里推荐下载sat权重

下载方法:
(1)官方推荐的方法
如果使用Huggingface transformers库调用模型,可以通过如下代码(其中图像路径为本地路径)自动下载权重文件。
PS:实际上的下载地址:https://huggingface.co/THUDM/visualglm-6b

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
image_path = "your image path"
response, history = model.chat(tokenizer, image_path, "描述这张图片。", history=[])
print(response)
response, history = model.chat(tokenizer, image_path, "这张图片可能是在什么场所拍摄的?", history=history)
print(response)

如果使用SwissArmyTransformer库调用模型,方法类似,可以使用环境变量SAT_HOME决定模型下载位置。在本仓库目录下:

import argparse
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
from model import chat, VisualGLMModel
model, model_args = VisualGLMModel.from_pretrained('visualglm-6b', args=argparse.Namespace(fp16=True, skip_init=True))
from sat.model.mixins import CachedAutoregressiveMixin
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
image_path = "your image path or URL"
response, history, cache_image = chat(image_path, model, tokenizer, "描述这张图片。", history=[])
print(response)
response, history, cache_image = chat(None, model, tokenizer, "这张图片可能是在什么场所拍摄的?", history=history, image=cache_image)
print(response)

但是以上方法都不推荐!!因为你会因为网速慢、连接不上官网、文件大(~16G)下载速度慢等原因非常崩溃。因为一手推荐用第二种方法。

(2)通过网盘下载
下载链接见文末

除了预训练权重之外,我们还需要依赖ChatGLM来进行图文对话,这里不需要下载整个ChatGLM模型,只需要5个跟tokenizer相关的文件:
在这里插入图片描述

上面两个文件下载完成后解压到项目根目录:
在这里插入图片描述

五、预训练权重推理

我们先使用下载好的SAT权重进行命令行和网页端的推理,测试整体的环境安装是否正确,以及整个数据加载–>推理的流程能否跑通。

5.1 命令行推理

示例代码cli_demo.py:

代码1~2行:增加设置指定的GPU编号,可以根据自己显卡情况修改,如果只有一张卡,可以将1改为0;
代码22行:在quant参数里面设置量化大小,可以选择量化为4bit或者8bit;
代码23行:在from_pretrained参数里面修改为自己本地的visualglm-6b预训练权重路径;
代码50行:在from_pretrained函数里面修改为自己本地的chatglm tokenizer目录路径;

import os  # 导入操作系统接口模块
os.environ['CUDA_VISIBLE_DEVICES'] = "1"  # 设置环境变量,指定使用的第2个CUDA设备,从0开始编号

import sys  # 导入系统模块,用于访问与Python解释器相关的变量和函数
import torch  # 导入PyTorch深度学习框架
import argparse  # 导入命令行参数解析模块
from transformers import AutoTokenizer  # 从transformers库导入自动分词器
from sat.model.mixins import CachedAutoregressiveMixin  # 从sat库导入自动回归混合类
from sat.quantization.kernels import quantize  # 从sat库导入量化函数
from model import chat  # 从model模块导入chat函数
from sat.model import AutoModel  # 从sat.model模块导入AutoModel


def main():
    parser = argparse.ArgumentParser()  # 创建命令行参数解析器
    # 添加命令行参数
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')  # 最大序列长度
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')  # 核采样的top p值
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')  # top k采样的top k值
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')  # 采样温度
    parser.add_argument("--english", action='store_true', help='only output English')  # 是否只输出英文
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')  # 量化位数
    parser.add_argument("--from_pretrained", type=str, default="./visualglm-6b",
                        help='pretrained ckpt')  # 预训练模型路径
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。",
                        help='Chinese prompt for the first round')  # 中文提示语
    parser.add_argument("--prompt_en", type=str, default="Describe the image.",
                        help='English prompt for the first round')  # 英文提示语
    args = parser.parse_args()  # 解析命令行参数

    # 加载模型
    model, model_args = AutoModel.from_pretrained(  # 使用from_pretrained方法加载预训练模型
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,  # 是否使用半精度浮点数
            skip_init=True,  # 是否跳过初始化
            use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,  # 是否使用GPU初始化
            device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',  # 设备选择
        )
    )
    model = model.eval()  # 设置模型为评估模式

    if args.quant:  # 如果指定了量化位数
        quantize(model, args.quant)  # 对模型进行量化
        if torch.cuda.is_available():  # 如果CUDA可用
            model = model.cuda()  # 将模型移动到GPU

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())  # 给模型添加自动回归混合类

    tokenizer = AutoTokenizer.from_pretrained("./chatglm/", trust_remote_code=True)  # 加载分词器
    if not args.english:  # 如果不是英文模式
        print(
            '欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')  # 打印使用说明
    else:  # 英文模式
        print(
            'Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue '
            'inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
        # 打印英文使用说明
    with torch.no_grad():  # 禁用梯度计算
        while True:  # 进入主循环
            history = None  # 初始化历史对话记录
            cache_image = None  # 初始化缓存的图像
            if not args.english:  # 如果不是英文模式
                image_path = input("请输入图像路径或URL(回车进入纯文本对话): ")  # 输入图像路径或URL
            else:  # 英文模式
                image_path = input(
                    "Please enter the image path or URL (press Enter for plain text conversation): ")  # 输入图像路径或URL

            if image_path == 'stop':  # 如果输入stop
                break  # 退出循环
            if len(image_path) > 0:  # 如果输入了图像路径
                query = args.prompt_en if args.english else args.prompt_zh  # 设置查询提示语
            else:  # 如果没有输入图像路径,进入纯文本对话
                if not args.english:
                    query = input("用户:")  # 输入中文用户对话
                else:
                    query = input("User: ")  # 输入英文用户对话
            while True:  # 进入对话循环
                if query == "clear":  # 如果用户输入clear
                    break  # 重置对话
                if query == "stop":  # 如果用户输入stop
                    sys.exit(0)  # 退出程序
                try:  # 尝试执行对话
                    response, history, cache_image = chat(  # 调用chat函数进行对话
                        image_path,
                        model,
                        tokenizer,
                        query,
                        history=history,
                        image=cache_image,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        english=args.english,
                        invalid_slices=[slice(63823, 130000)] if args.english else []
                    )
                except Exception as e:  # 如果发生异常
                    print(e)  # 打印异常信息
                    break  # 退出循环
                sep = 'A:' if args.english else '答:'  # 设置分隔符
                print("VisualGLM-6B:" + response.split(sep)[-1].strip())  # 打印模型的回复
                image_path = None  # 重置图像路径
                if not args.english:  # 如果不是英文模式
                    query = input("用户:")  # 输入中文用户对话
                else:  # 英文模式
                    query = input("User: ")  # 输入英文用户对话


if __name__ == "__main__":  # 如果是主模块
    main()  # 调用主函数

运行之后可以进行文字对话,也可以输入图片路径进行图像理解和对话。
文字对话:
在这里插入图片描述

图文理解:
在这里插入图片描述

经过测试在显卡型号为3090TI的设备下,不同量化后的模型所占显存大小如下,单位GB:

无量化 INT8 INT4
19.8 14.3 8.8

5.2 网页版推理

示例代码web_demo.py:

代码1~2行:增加设置指定的GPU编号,可以根据自己显卡情况修改,如果只有一张卡,可以将1改为0;
代码131行和159行:修改为自己本地的chatglm tokenizer目录路径;
代码205行:在quant参数里面设置量化大小,可以选择量化为4bit或者8bit;
代码206行:是否共享应用,若为True,则会生成一个公网链接,可分享给其他人使用;
代码207行:在from_pretrained参数里面修改为自己本地的visualglm-6b预训练权重路径;

import os  # 导入操作系统接口模块
os.environ['CUDA_VISIBLE_DEVICES'] = "1"  # 设置环境变量,指定使用的第一个CUDA设备

import argparse  # 导入命令行参数解析模块
import gradio as gr  # 导入Gradio库,用于创建Web应用界面
from PIL import Image  # 从PIL库导入Image类,用于图像处理
from model import is_chinese, generate_input  # 从model模块导入is_chinese和generate_input函数
import torch  # 导入PyTorch深度学习框架
from transformers import AutoTokenizer  # 从transformers库导入自动分词器
from finetune_visualglm import FineTuneVisualGLMModel
from model import chat  # 从model模块导入chat函数
from sat.model import AutoModel  # 从sat.model模块导入AutoModel
from sat.model.mixins import CachedAutoregressiveMixin  # 从sat.model.mixins模块导入CachedAutoregressiveMixin
from sat.quantization.kernels import quantize  # 从sat.quantization.kernels模块导入quantize函数


# 定义一个函数,用于根据输入文本和图像生成文本
def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True):
    # 设置默认的输入参数
    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    # 更新输入参数
    input_para.update(request_data)

    # 生成输入数据
    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    # 在不计算梯度的情况下执行模型
    with torch.no_grad():
        # 使用chat函数生成回答
        answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
                                  max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
                                  top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
    return answer


# 定义一个函数,用于处理模型请求
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    # 处理历史记录
    result_text = [(ele[0], ele[1]) for ele in result_previous]
    # 清理历史记录
    for i in range(len(result_text) - 1, -1, -1):
        if result_text[i][0] == "" or result_text[i][1] == "":
            del result_text[i]
    print(f"history {result_text}")

    # 判断输入文本是否为中文
    is_zh = is_chinese(input_text)
    # 如果没有图像提示,给出错误信息
    if image_prompt is None:
        if is_zh:
            result_text.append((input_text, '图片为空!请上传图片并重试。'))
        else:
            result_text.append((input_text, 'Image empty! Please upload a image and retry.'))
        return input_text, result_text
    # 如果输入文本为空,给出错误信息
    elif input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    # 设置请求参数
    request_para = {"temperature": temperature, "top_p": top_p}
    # 打开图像文件
    image = Image.open(image_prompt)
    try:
        # 生成文本
        answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        print(f"error: {e}")
        if is_zh:
            result_text.append((input_text, '超时!请稍等几分钟再重试。'))
        else:
            result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
        return "", result_text

    # 添加新的对话到历史记录
    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text


# 设置应用描述
DESCRIPTION = '''# <a href="">Visual-GLM</a>'''

# 设置维护通知
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'

# 设置注释
NOTES = 'This app is adapted from <a href="https://github.com/THUDM/VisualGLM-6B">https://github.com/THUDM/VisualGLM-6B</a>. It would be recommended to check out the repo if you want to see the detail of our model and training process.'

# 定义清除函数
def clear_fn(value):
    return "", [("", "Hi, What do you want to know about this image?")], None


# 定义清除函数2
def clear_fn2(value):
    return [("", "Hi, What do you want to know about this image?")]


# 获取模型的函数
def get_model(args):
    global model, tokenizer
    # 加载模型
    model, model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
            device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
        ))
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)

    if torch.cuda.is_available():
        model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    tokenizer = AutoTokenizer.from_pretrained(
        "./chatglm",
        trust_remote_code=True)
    # tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    return model, tokenizer


# 主函数
def main(args):
    global model, tokenizer
    # 加载模型
    model, model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
            device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu',
        ))
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)

    if torch.cuda.is_available():
        model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    tokenizer = AutoTokenizer.from_pretrained("./chatglm", trust_remote_code=True)

    # 使用Gradio创建界面
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text',
                                            placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[
                    ("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        # 设置Gradio版本信息
        print(gr.__version__)
        # 设置按钮点击事件
        run_button.click(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model, inputs=[input_text, temperature, top_p, image_prompt, result_text],
                          outputs=[input_text, result_text])
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

    # 启动Gradio应用
    demo.queue(concurrency_count=10)  # 设置并发请求数
    demo.launch(share=args.share)  # 启动应用,允许共享


# 主程序入口
if __name__ == '__main__':
    parser = argparse.ArgumentParser()  # 设置命令行参数解析器
    parser.add_argument("--quant", choices=[8, 4], type=int, default=4)  # 设置量化位数参数
    parser.add_argument("--share", action="store_true", default=True)  # 是否共享应用
    parser.add_argument("--from_pretrained", type=str, default="./visualglm-6b", help='pretrained ckpt')  # 预训练模型路径
    args = parser.parse_args()  # 解析命令行参数
    main(args)  # 调用主函数

运行之后进入网址,可以输入图片路径和问题进行图像理解对话:
在这里插入图片描述
在这里插入图片描述

六、模型微调训练方法

模型微调步骤:按照官方样例准备好数据集、设置微调参数、启动微调训练、使用微调权重推理

6.1 数据准备

官方提供了一个微调数据的格式,解压fewshot-data.zip得到如下结构:
在这里插入图片描述
也就是我们需要准备好需要微调的图片+一个dataset.json文件,后者里面存放的是图文对:
在这里插入图片描述
可以看到dataset.json的内容是一个大列表,列表里面的每一个元素是字典格式:img表示图片的路径、prompt表示提问的文本,label表示回答。这个示例数据主要关注图片的背景内容,所以回复也是背景是xxx的形式。实际上在创建自己的数据集时,prompt和label可以修改为自己想要实现的问题-答案对。

6.2 配置微调脚本

目前支持三种方式的微调:

LoRA:样例中为ChatGLM模型的第0层和第14层加入了rank=10的LoRA微调,可以根据具体情景和数据量调整–layer_range和–lora_rank参数。
QLoRA:如果资源有限,可以考虑使用bash
finetune/finetune_visualglm_qlora.sh,QLoRA将ChatGLM的线性层进行了4-bit量化,只需要9.8GB显存即可微调。
P-tuning:可以将–use_lora替换为–use_ptuning,不过不推荐使用,除非模型应用场景非常固定。

注意微调需要安装deepspeed库,目前本流程仅支持linux系统。

下面是在单卡上使用LoRA在示例数据集fewshot-data的微调脚本:finetune_visualglm.sh,batch-size=2时训练需要18GB显存。

#! /bin/bash
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1

script_path=$(realpath $0)
script_dir=$(dirname $script_path)
main_dir=$(dirname $script_dir)
MODEL_TYPE="visualglm-6b"
MODEL_ARGS="--max_source_length 64 \
    --max_target_length 256 \
    --lora_rank 10 \
    --layer_range 0 14 \
    --pre_seq_len 4"

# OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models"
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
HOST_FILE_PATH="hostfile"
HOST_FILE_PATH="hostfile_single"

train_data="./fewshot-data/dataset.json"
eval_data="./fewshot-data/dataset.json"


gpt_options=" \
       --experiment-name finetune-$MODEL_TYPE \
       --model-parallel-size ${MP_SIZE} \
       --mode finetune \
       --train-iters 1000 \
       --resume-dataloader \
       $MODEL_ARGS \
       --train-data ${train_data} \
       --valid-data ${eval_data} \
       --distributed-backend nccl \
       --lr-decay-style cosine \
       --warmup .02 \
       --checkpoint-activations \
       --save-interval 300 \
       --eval-interval 100 \
       --save "./checkpoints" \
       --split 1 \
       --eval-iters 10 \
       --eval-batch-size 1 \
       --zero-stage 1 \
       --lr 0.0001 \
       --batch-size 2 \
       --skip-init \
       --fp16 \
       --use_lora
"

              

run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --include localhost:1 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}

set +x

微调参数说明:
这个脚本是一个用于配置和启动VisualGLM训练任务的Bash脚本,它包含了多个参数,用于设置训练环境和训练过程的各个方面。下面是对脚本中参数的详细说明:

NUM_WORKERS: 指定训练过程中使用的worker数量。每个worker可以运行在不同的机器上。
NUM_GPUS_PER_WORKER: 每个worker使用的GPU数量。这决定了每个训练进程可以利用的GPU资源。 MP_SIZE:
模型并行大小,即模型被分割成多少部分在不同的进程中并行处理。 script_path: 当前脚本的绝对路径。 script_dir:
当前脚本所在的目录。 main_dir: 脚本所在目录的上一级目录,通常用于存放项目的主要文件。 MODEL_TYPE:
使用的模型类型,这里是visualglm-6b,表示一个特定的视觉语言模型。 MODEL_ARGS: 一系列与模型相关的参数,包括:
–max_source_length 64: 输入序列的最大长度。
–max_target_length 256: 输出序列的最大长度。
–lora_rank 10: LoRA(局部重参数化)的秩。
–layer_range 0 14: 参与并行处理的模型层的范围。
–pre_seq_len 4: 预序列的长度。 OPTIONS_SAT: 环境变量设置,这里被注释掉了,如果取消注释,它将设置SAT_HOME环境变量。 OPTIONS_NCCL: 一系列与NVIDIA
Collective Communications Library (NCCL) 相关的环境变量设置,用于优化GPU之间的通信。
HOST_FILE_PATH: 指定主机文件的路径,该文件包含了参与训练的所有机器的列表。 train_data 和 eval_data:
分别指定训练数据和评估数据的路径。 gpt_options: 一系列用于配置训练任务的参数,包括:
–experiment-name: 实验名称。
–model-parallel-size: 模型并行大小。
–mode: 训练模式,这里是finetune。
–train-iters: 训练迭代次数。
–resume-dataloader: 是否恢复数据加载器的状态。
–train-data 和 --valid-data: 分别指定训练和验证数据的路径。
–distributed-backend: 分布式训练的后端,这里是nccl。
–lr-decay-style: 学习率衰减方式。
–warmup: 学习率预热的比例。
–save-interval 和 --eval-interval: 分别指定保存和评估的间隔。
–save: 模型保存的路径。
–split: 数据集分割的比例。
–eval-iters 和 --eval-batch-size: 分别指定评估的迭代次数和批次大小。
–zero-stage: 零优化的阶段。
–lr: 学习率。
–batch-size: 批次大小。
–skip-init: 是否跳过模型初始化。
–fp16: 是否使用半精度浮点数。
–use_lora: 是否使用LoRA技术。 run_cmd: 构建用于启动训练任务的命令字符串,包括NCCL和SAT的环境变量设置,以及deepspeed训练脚本的调用。 echo
${run_cmd}: 打印构建的命令字符串。 eval ${run_cmd}: 执行构建的命令字符串,启动训练任务。 set +x:
在脚本的最后,这行命令用于关闭xtrace(调试模式),这样在执行脚本时不会打印出所有的命令。

6.3 训练过程

在终端进入项目根目录,使用下面的命令启动训练:
bash finetune/finetune_visualglm.sh
在微调训练过程中会打印当前训练步数、学习率、损失值:
在这里插入图片描述

在eval-interval步数后开始验证,并打印在验证集上的损失值和PPL(PPL全称Perplexity ,指困惑度,用来衡量语言模型好坏的指标。简单说,perplexity值刻画的是语言模型预测一个语言样本的能力。在一个测试集上得到的perplexity值越低,说明建模效果越好):
在这里插入图片描述

在save-interval步数后进行保存,并打印保存路径:
在这里插入图片描述

在train-iters步数后结束训练,保存最后的权重并关闭GPU连接:
在这里插入图片描述
在上述过程中,还会通过tensorboard将更多日志信息保存在runs文件夹,使用下面的命令打开查看:

tensorboard --logdir==runs --port 6007 --bind_all

在这里插入图片描述

这里能看到训练损失、验证损失、验证指标的变化情况(PPL值越小越好)

6.4 微调权重推理

微调后的权重存放在checkpoints目录下,加载方式web_demo.py相同,只需要把主函数中的from_pretrained参数修改为自己本地的微调权重路径:

.....
# 主程序入口
if __name__ == '__main__':
    parser = argparse.ArgumentParser()  # 设置命令行参数解析器
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None)  # 设置量化位数参数
    parser.add_argument("--share", action="store_true", default=True)  # 是否共享应用
    parser.add_argument("--from_pretrained", type=str, default="./checkpoints/finetune-visualglm-6b-02-27-14-41", help='pretrained ckpt')  # 预训练模型路径
    args = parser.parse_args()  # 解析命令行参数
    main(args)  # 调用主函数
......

运行后上传图片进行推理:
在这里插入图片描述

七、模型部署

这里为大家提供一种基于Flask API的方式部署VisualGLM服务。

步骤一:主服务功能配置;
步骤二:API接口调用;
好处:调用速度快,可以灵活地在不同场景下使用。

7.1 实现主服务

主服务是指基于Flask构建一个后台运行的算法服务,使用可以直接使用API的方式调用。

主服务代码image_caption_server.py:

代码16行:from_pretrained参数可以设置为预训练权重路径或者微调权重路径; loguru
用来记录推理结果或者记录程序异常日志,方便后期进行bug排查;

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"

import time
import torch
import argparse
from flask import Flask, request
from flask_cors import cross_origin
from loguru import logger
from web_demo import get_model
from model import is_chinese, generate_input, chat

parser = argparse.ArgumentParser()
parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
parser.add_argument("--share", action="store_true", default=False)
parser.add_argument("--from_pretrained", type=str, default="checkpoints/finetune-visualglm-6b-02-27-14-41",
                    help='pretrained ckpt')
args = parser.parse_args()
model, tokenizer = get_model(args)

ti_ = time.localtime()
date_ = f"{ti_[0]}_{ti_[1]}_{ti_[2]}"
current_time_path = f"./logs/runtime/{date_}"
os.makedirs(current_time_path, exist_ok=True)
logger.remove(handler_id=None)
logger.add(os.path.join(current_time_path, "runtime_{time}.log"), rotation='1 day')

app = Flask(__name__)


@app.route('/image_caption', methods=['POST'])
@logger.catch()
@cross_origin()
def image_caption():
    if request.method == "POST":
        try:
            print("Start to process request")
            request_data = request.get_json()
            input_text, input_image_encoded, history = request_data['text'], request_data['image'], request_data[
                'history']
            input_para = {
                "max_length": 2048,
                "min_length": 50,
                "temperature": 0.8,
                "top_p": 0.4,
                "top_k": 100,
                "repetition_penalty": 1.2
            }
            input_para.update(request_data)

            is_zh = is_chinese(input_text)
            input_data = generate_input(input_text, input_image_encoded, history, input_para)
            input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
            with torch.no_grad():
                answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
                                          max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
                                          top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
                                          english=not is_zh)
        except Exception as e:
            logger.info(e)
            answer = "暂无相关描述,请检查图像内容!"
        response = {
            "result": answer,
            "history": history,
            "status": 200,
        }
        logger.info(answer)
        return response


if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0', port=9500)

7.2 实现API接口调用

主服务相当于随时准备推理的模块,我们还需要通过API调用主服务的功能。

API调用代码image_caption_api.py:

import base64
import requests


def visual_api(imgb64):
    # 请求数据
    data = {
        "text": "请详细描述这张图片的内容",
        "image": imgb64,
        "history": []
    }
    # 发起POST请求
    response = requests.post('http://127.0.0.1:9500/image_caption', json=data)
    result = response.json()
    return result


if __name__ == '__main__':
    one_img = "examples/3.jpeg"
    imgbase64 = base64.b64encode(open(one_img, 'rb').read()).rstrip().decode('utf-8')
    res = visual_api(imgbase64)
    print("回答:", res['result'])

首先运行主服务代码image_caption_server.py,运行成功后会显示调用的链接:
在这里插入图片描述

然后在API调用程序image_caption_api.py里面设置图片路径和提问的text,text默认是"请详细描述这张图片的内容",运行程序就可以得到推理结果:
在这里插入图片描述

八、网盘链接汇总

1、VisualGLM代码文件:
链接:https://pan.baidu.com/s/15GfpCubvSgnkbSNtVnrKbQ?pwd=xy2t

2、预训练权重文件
链接:https://pan.baidu.com/s/1MLMIb42AOJnqa4R1so6dxg?pwd=23uf

3、chatglm tokenizer文件
链接:https://pan.baidu.com/s/1kq25Jdab5_p1umDgJaSdyQ?pwd=9ir8

4、python虚拟环境压缩包
链接:https://pan.baidu.com/s/1-7VA9s2iwjFMyqSOZhiMgQ?pwd=u39c

九、常见错误及解决方法

1、运行web_demo.py报错:Permission denied: /tmp/gradio/xx
解决方法:更改/tmp/gradio目录权限为可读可写:sudo chmod -R 777 /tmp/gradio/

2、运行cli_demo.py或者web_demo.py报错:model_class FineTuneVisualGLMModel not
found
解决方法:from finetune_visualglm import FineTuneVisualGLMModel

3、ChatGLMTokenizer’ object has no attribute 'tokenizer
解决方法:重装transformers到4.33.1

写在后面

如果大家对多模态大模型感兴趣,可以扫码加群学习交流,二维码失效可以添加微信:lzz9527288
在这里插入图片描述

最近更新

  1. TCP协议是安全的吗?

    2024-03-25 16:58:01       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-03-25 16:58:01       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-03-25 16:58:01       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-03-25 16:58:01       20 阅读

热门阅读

  1. 开源与闭源语言模型的较量:技术分析

    2024-03-25 16:58:01       16 阅读
  2. 大数据安全分析相关与安全分析的场景

    2024-03-25 16:58:01       15 阅读
  3. IOS面试题编程机制 46-50

    2024-03-25 16:58:01       15 阅读
  4. SGD优化器和Adam区别

    2024-03-25 16:58:01       18 阅读
  5. 我的算法刷题笔记(3.18-3.22)

    2024-03-25 16:58:01       21 阅读
  6. 什么是微任务?什么是宏任务?

    2024-03-25 16:58:01       19 阅读
  7. IOS面试题编程机制 31-35

    2024-03-25 16:58:01       17 阅读
  8. JVM G1垃圾回收器的工作内容

    2024-03-25 16:58:01       17 阅读
  9. 5.86 BCC工具之tcpstates.py解读

    2024-03-25 16:58:01       17 阅读
  10. 1928递归去处理压缩字符串

    2024-03-25 16:58:01       18 阅读
  11. P5963 [BalticOI ?] Card 卡牌游戏 贪心

    2024-03-25 16:58:01       20 阅读
  12. [Repo Git] manifests的写法

    2024-03-25 16:58:01       23 阅读