Pytorch:多模态大模型预训练、大模型微调:加载数据的正确姿势

对于近期兴起的多模态大模型的预训练和微调,常见情况是训练数据规模极大,通常可以达到1m-100m级别。此时,训练数据通常用一个上百万行的jsonl文件存储,每行对应一条json格式的训练数据,其中可能包括数据关联的其他图、音、视频数据的索引。例如,阿里通义千问多模态大模型QWen-VL的一条示例数据可能如下所示:

{
  "input": "Picture 1:<img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>这是什么?",
  "output": "图中是一名女子在沙滩上和狗玩耍,旁边是一只拉布拉多犬,它们处于沙滩上。"
}

由于训练数据集过大,在训练读取数据时,直接使用Dataset类可能会带来性能问题。Pytorch的Dataset类在初始化时会将整个数据集加载到内存中,如果数据集非常大,没法全部放在内存里,使用Dataset类会显著增加硬盘io次数,带来性能下降。此时的对策是使用IterableDataset类,可以按需加载数据,而不是一次性将整个数据集加载到内存中。
基于IterableDataset的数据加载,代码实现如下:

import torch
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                sample = process_line(line)
                yield sample

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)

for batch in dataloader:
    # Train your model using the batch of data
    pass

在实际训练中还会遇到两个问题:

  1. 大模型一般需要使用多机多卡训练,需要避免多个进程中dataloader读取数据的竞争,并保证不同进程之间不会重复读取数据;
  2. 数据文件中某些行无法正确被解析,或者引用的外部资源找不到,导致process_line成员函数报错。数据集需要handle这类错误,防止因为报错中断训练。

以上问题对策如下:

  1. 在多机多卡的DDP训练中,可以使用DistributedSampler来处理多进程读数据的情形。DistributedSampler可以确保不同进程之间不会重复读取数据。具体的代码实现如下:
# Usage
data_file = 'data.txt'
dataset = MyIterableDataset(data_file)

# Create a DistributedSampler
sampler = DistributedSampler(dataset)

# Create a DataLoader using the DistributedSampler
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for batch in dataloader:
    # Train your model using the batch of data
    pass
  1. 可以在调用process_line的时候试图handle一个错误,如果出错就跳过这条数据,改为(试图)获取下一条数据。具体的代码实现如下:
import torch
import logger
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self, data_file):
        self.data_file = data_file

    def __iter__(self):
        return iter(self._load_data())

    def _load_data(self):
        with open(self.data_file, 'r') as file:
            for line in file:
                try:
                    sample = process_line(line)
                    yield sample
                except Exception as e:
                    # Print the detailed error information
                    logger.error(line)
                    logger.error(e)
                    pass

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

如果使用的是普通的Dataset,则参考以下代码,在__getitem__里面加入报错逻辑:

class MyDataset(Dataset):
    def __init__(self, file_path):
        self.data = []
        with open(file_path, 'r') as file:
            for line in file:
                self.data.append(line)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        line = self.data[index]
        try:
            sample = self.process_line(line)
            return sample
        except Exception as e:
            # Print the detailed error information
            logger.error(line)
            logger.error(e)
            return self.__getitem__((index+1) % self.__len__())

    def process_line(self, line):
        # Process the line to convert it to a sample
        ...
        return sample

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-19 11:46:05       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-19 11:46:05       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-19 11:46:05       82 阅读
  4. Python语言-面向对象

    2024-03-19 11:46:05       91 阅读

热门阅读

  1. 2022蓝桥杯/李白打酒加强版/c\c++

    2024-03-19 11:46:05       34 阅读
  2. windows平台Qt5连接wifi

    2024-03-19 11:46:05       34 阅读
  3. C++ 11

    C++ 11

    2024-03-19 11:46:05      33 阅读
  4. 一个j简单显示框架及简单实现再探编程_C++

    2024-03-19 11:46:05       35 阅读
  5. csv编辑器是干什么的?

    2024-03-19 11:46:05       38 阅读
  6. C++/CLI学习笔记12(快速打通c++与c#相互调用的桥梁)

    2024-03-19 11:46:05       36 阅读
  7. 【 React 】Real DOM 和Virtual DOM的区别?优缺点?

    2024-03-19 11:46:05       37 阅读
  8. React+umi+dva 项⽬实战-lesson6

    2024-03-19 11:46:05       44 阅读
  9. HCIA_IP路由基础问题?

    2024-03-19 11:46:05       41 阅读