pytorch学习day2

1 数据加载Dataset

PyTorch的数据读取机制主要依赖于DatasetDataLoader这两个核心组件。它们用于加载和处理数据,以便在训练模型时进行高效的数据流动和处理。

Dataset

Dataset是一个抽象类,用户可以继承这个类并重载以下两个方法来创建自定义的数据集:

  1. __init__ 方法:

    • csv_file:指向包含图像路径和标签的CSV文件路径。
    • root_dir:包含所有图像的根目录路径。
    • transform:一个可选的变换,用于在返回样本之前处理数据。

    在初始化过程中,读取CSV文件并存储在self.data_frame中,还设置了图像的根目录和可选的变换。

  2. __len__ 方法:

    • 返回数据集中样本的数量,即CSV文件中记录的行数。
  3. __getitem__ 方法:

    • 接收一个索引 idx,从CSV文件中获取对应的图像路径和标签。
    • 使用PIL库打开图像文件,并将其转换为RGB格式。
    • 如果定义了变换,则将其应用到图像。
    • 返回处理后的图像和对应的标签。

自定义Dataset示例

import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.labels[index]
        return sample, label

# 示例数据
data = torch.randn(100, 3)  # 100个样本,每个样本3个特征
labels = torch.randint(0, 2, (100,))  # 100个标签

# 创建自定义数据集
dataset = CustomDataset(data, labels)

2 可迭代的数据装载器DataLoader

DataLoader 是 PyTorch 中一个非常重要的类,用于构建可迭代的数据装载器。它能够有效地加载数据并在训练模型时提供数据批次。下面我们详细介绍 DataLoader 的各个参数和使用方法。

DataLoader 的功能

DataLoader 主要用于在训练过程中,每个 for 循环中从数据集中获取一个指定大小(batch_size)的数据批次。

参数解释

  1. dataset:

    • 类型:Dataset 类实例
    • 功能:决定数据从哪里读取以及如何读取。Dataset 类定义了数据集的具体内容及访问方式。
  2. batch_size:

    • 类型:整数
    • 功能:每个数据批次的大小。例如,batch_size=32 表示每次从数据集中获取32个样本。
  3. num_workers:

    • 类型:整数
    • 功能:决定使用多少个子进程来加载数据。更多的进程数可以加快数据加载速度,但过多的进程数可能会导致系统资源不足,建议设置为 4、8、16 等。
  4. shuffle:

    • 类型:布尔值
    • 功能:决定每个 epoch 开始时是否打乱数据顺序。打乱数据可以增加训练过程的随机性,通常设置为 True
  5. drop_last:

    • 类型:布尔值
    • 功能:如果数据集中的样本数不能被 batch_size 整除,决定是否舍弃最后一个不完整的数据批次。设置为 True 表示舍弃。

重要概念

  1. Epoch:

    • 定义:所有训练样本都已输入到模型中,称为一个 epoch。
  2. Iteration:

    • 定义:一个批次的样本输入到模型中,称为一次 iteration。
  3. Batch Size:

    • 定义:批大小,决定一个 epoch 中有多少次 iteration。
# 创建 DataLoader 实例
dataloader = DataLoader(
    dataset=dataset,       # 自定义数据集
    batch_size=32,         # 每批次32个样本
    shuffle=True,          # 每个epoch开始时打乱数据
    num_workers=4,         # 使用4个子进程加载数据
    drop_last=True         # 当样本数不能被batch_size整除时,舍弃最后一批数据
)

# 训练循环示例
for epoch in range(num_epochs):
    for batch_idx, (data, labels) in enumerate(dataloader):
        # 模型训练代码
        pass

3 图像预处理transforms

在PyTorch中,transforms是一个用于图像预处理的模块。transforms提供了一组常用的图像变换方法,可以对图像进行数据增强、归一化、裁剪、缩放等操作。transforms主要用于将图像数据转换成适合模型输入的格式。

常用的Transforms

以下是一些常用的transforms操作:

  1. transforms.Compose:将多个变换组合起来。
  2. transforms.Resize:调整图像大小。
  3. transforms.CenterCrop:从图像中心裁剪。
  4. transforms.RandomCrop:随机裁剪图像。
  5. transforms.RandomHorizontalFlip:随机水平翻转图像。
  6. transforms.ToTensor:将PIL图像或Numpy数组转换为张量,并将像素值归一化到[0, 1]。
  7. transforms.Normalize:用均值和标准差归一化张量。
  8. transforms.ColorJitter:随机改变图像的亮度、对比度和饱和度。
  9. transforms.RandomRotation:随机旋转图像。
from torchvision import transforms
from PIL import Image

# 定义图像预处理变换
transform = transforms.Compose([
    transforms.Resize((128, 128)),             # 调整图像大小
    transforms.RandomHorizontalFlip(),         # 随机水平翻转
    transforms.RandomRotation(10),             # 随机旋转10度
    transforms.ColorJitter(brightness=0.5),    # 随机改变亮度
    transforms.ToTensor(),                     # 转换为张量并归一化到[0, 1]
    transforms.Normalize((0.5,), (0.5,))       # 用均值0.5和标准差0.5归一化
])

# 加载图像
image = Image.open("path_to_image.jpg").convert("RGB")

# 应用预处理变换
transformed_image = transform(image)

# 检查变换后的图像
print(transformed_image.size())

如果现有的transforms无法满足需求,可以自定义变换。只需实现__call__方法即可

import torch

class CustomTransform:
    def __call__(self, sample):
        # 自定义变换逻辑,例如将图像转换为灰度图
        return transforms.functional.rgb_to_grayscale(sample)

# 使用自定义变换
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    CustomTransform(),
    transforms.ToTensor()
])

image = Image.open("path_to_image.jpg").convert("RGB")
transformed_image = transform(image)
print(transformed_image.size())

4 综合数据读取和数据预处理

以下是一个综合示例,展示如何定义数据集并使用各种transforms进行图像预处理和数据增强。

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import pandas as pd

class CustomCSVImageDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.data_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.data_frame.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        label = self.data_frame.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)

        return image, label

# 定义图像预处理和数据增强
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 示例数据
csv_file = './data/labels.csv'
root_dir = './data/images'

# 创建数据集
dataset = CustomCSVImageDataset(csv_file=csv_file, root_dir=root_dir, transform=transform)

# 创建DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# 迭代DataLoader
for batch_idx, (data, labels) in enumerate(dataloader):
    print(f"Batch {batch_idx}:")
    print("数据大小:", data.size())
    print("标签大小:", labels.size())

相关推荐

  1. pytorch学习day2

    2024-06-07 01:32:04       34 阅读
  2. 学习基于pytorch的VGG图像分类 day2

    2024-06-07 01:32:04       36 阅读
  3. 【Bootstrap学习 day2

    2024-06-07 01:32:04       56 阅读
  4. 机器学习day2

    2024-06-07 01:32:04       29 阅读
  5. 求职学习day2

    2024-06-07 01:32:04       28 阅读
  6. PyTorch 2-深度学习-模块

    2024-06-07 01:32:04       26 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-07 01:32:04       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-07 01:32:04       106 阅读
  3. 在Django里面运行非项目文件

    2024-06-07 01:32:04       87 阅读
  4. Python语言-面向对象

    2024-06-07 01:32:04       96 阅读

热门阅读

  1. React ahooks库和React Query库使用场景分析

    2024-06-07 01:32:04       33 阅读
  2. [力扣题解] 257. 二叉树的所有路径

    2024-06-07 01:32:04       28 阅读
  3. WEB三大主流框架之Vue.js

    2024-06-07 01:32:04       27 阅读
  4. 黑马es数据同步mq解决方案

    2024-06-07 01:32:04       21 阅读
  5. 如何快速入门使用Vue.js

    2024-06-07 01:32:04       32 阅读
  6. 在Linux/Ubuntu/Debian系统中使用 `tar` 压缩文件

    2024-06-07 01:32:04       26 阅读
  7. Attention as an RNN

    2024-06-07 01:32:04       29 阅读