diffusers 使用脚本导入自定义数据集

在训练扩散模型时,如果附加额外的条件图片数据,则需要我们准备相应的数据集。此时我们可以使用官网提供的脚本模板来控制导入我们需要的数据。

您可以参考官方的教程来实现具体的功能需求,为了更加简洁,我将简单描述一下整个流程的关键点:

  1. 首先按照您的需求准备好所有的数据集文件,统一放到一个dataset_name(可以自己定义)目录下,可以划分多个子文件夹,但是需要在您的matadata.json中描述好相对路径位置;这一步和平时准备数据集的过程一样,只是多了额外的条件图片数据。
  2. 在dataset_name下创建同名的dataset_name.py脚本文件,该脚本文件的类名要和脚本名一致,并复制下文的模板内容,然后修改特定位置:
import pandas as pd
from huggingface_hub import hf_hub_url
import datasets
import os

_VERSION = datasets.Version("0.0.2")

_DESCRIPTION = "TODO"
_HOMEPAGE = "TODO"
_LICENSE = "TODO"
_CITATION = "TODO"

_FEATURES = datasets.Features(
    {
        "image": datasets.Image(),
        "conditioning_image": datasets.Image(),
        "text": datasets.Value("string"),
    },
)

METADATA_URL = hf_hub_url(
    "fusing/fill50k",
    filename="train.jsonl",
    repo_type="dataset",
)

IMAGES_URL = hf_hub_url(
    "fusing/fill50k",
    filename="images.zip",
    repo_type="dataset",
)

CONDITIONING_IMAGES_URL = hf_hub_url(
    "fusing/fill50k",
    filename="conditioning_images.zip",
    repo_type="dataset",
)

_DEFAULT_CONFIG = datasets.BuilderConfig(name="default", version=_VERSION)


class Fill50k(datasets.GeneratorBasedBuilder):
    BUILDER_CONFIGS = [_DEFAULT_CONFIG]
    DEFAULT_CONFIG_NAME = "default"

    def _info(self):
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=_FEATURES,
            supervised_keys=None,
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        metadata_path = dl_manager.download(METADATA_URL)
        images_dir = dl_manager.download_and_extract(IMAGES_URL)
        conditioning_images_dir = dl_manager.download_and_extract(
            CONDITIONING_IMAGES_URL
        )

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                # These kwargs will be passed to _generate_examples
                gen_kwargs={
                    "metadata_path": metadata_path,
                    "images_dir": images_dir,
                    "conditioning_images_dir": conditioning_images_dir,
                },
            ),
        ]

    def _generate_examples(self, metadata_path, images_dir, conditioning_images_dir):
        metadata = pd.read_json(metadata_path, lines=True)

        for _, row in metadata.iterrows():
            text = row["text"]

            image_path = row["image"]
            image_path = os.path.join(images_dir, image_path)
            image = open(image_path, "rb").read()

            conditioning_image_path = row["conditioning_image"]
            conditioning_image_path = os.path.join(
                conditioning_images_dir, row["conditioning_image"]
            )
            conditioning_image = open(conditioning_image_path, "rb").read()

            yield row["image"], {
                "text": text,
                "image": {
                    "path": image_path,
                    "bytes": image,
                },
                "conditioning_image": {
                    "path": conditioning_image_path,
                    "bytes": conditioning_image,
                },
            }

  1. 修改时主要关注两个函数,和一些命名:
  • 第一个是_split_generators(),把所有download相关的内容注释掉,这里会让你去下载官方的数据集,我们的需求是准备自己的数据集,所以为了方便直接把这个函数中的关键文件路径改为自己的绝对路径,比如metadata_path,就是你的metadata.json的路径,images_dir和conditioning_images_dir是你的图片的上级目录的绝对路径。这里我曾经测试过使用相对路径,发现是行不通的,主要的问题是diffuers在项目运行时会把当前的脚本先拷贝到c盘,然后再加载入内存,所以相对路径会不起作用。
  • 第二个是_generate_examples(),我们需要按照上个函数给出的路径依次加载图片文件和文本,这里主要是把所有的数据集内容修改为你需要的信息。这里有个关键点是,你必须保证metadata.json中第一列image的内容是不重复的,因为该列会作为索引的key值出现,否则会报错。
  • 最后是把脚本中所有与数据集信息相关的名称校对为你需要的。

在训练过程中,指定好数据集dataset_name的位置,diffusers会自动调用dataset_name.py来读取数据集中的数据。

相关推荐

  1. diffusers 使用脚本导入定义数据

    2024-06-09 03:10:02       12 阅读
  2. DataLoader定义数据制作

    2024-06-09 03:10:02       18 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-09 03:10:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-09 03:10:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-09 03:10:02       19 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-09 03:10:02       20 阅读

热门阅读

  1. 【设计模式】装饰器模式(结构型)⭐⭐

    2024-06-09 03:10:02       8 阅读
  2. 电商API在促进销售与营销中的影响

    2024-06-09 03:10:02       11 阅读
  3. Zookeeper 详解:分布式协调服务的核心概念与实践

    2024-06-09 03:10:02       11 阅读
  4. Pytorch中的广播机制

    2024-06-09 03:10:02       8 阅读
  5. Access数据中的SQL偏移注入

    2024-06-09 03:10:02       9 阅读
  6. 在Spark SQL中,fillna函数

    2024-06-09 03:10:02       11 阅读
  7. SELinux:安全增强型Linux

    2024-06-09 03:10:02       10 阅读
  8. 嵌入式C中Hex与Bin文件对比分析

    2024-06-09 03:10:02       11 阅读
  9. 数据结构学习笔记-二叉树

    2024-06-09 03:10:02       10 阅读