Pytorch学习笔记——认识数据

        最近在跟着小土堆pytorch的视频跟着学习python,根据自己的理解和课程上面的知识,写了这一篇学习笔记。

1、加载数据        

        数据的加载是学习pytorch的第一步,我们需要加载数据,完成特征工程,对加载数据存在的一些特征来进行分析和处理,进而利用相关算法训练得到模型。

        数据该如何加载呢?

        首先,如果是文本之类数据的话,可以使用open()函数进行文件读取操作,对于图片的话,可以使用PIL下面的一个api,调用里面的open()方法来打开图片,如果要使用,则需要进行导包的操作。

from PIL import Image

如果导需要从某种数据源加载数据,并对这些数据进行预处理和格式化的话,利用pytorch中的Dataset类是最为方便的,也需要导包。

from torch.utils.data import Dataset

Dataset类里面定义了两种方法:

  1. __len__(): 返回数据集中的样本数量。
  2. __getitem__(idx): 根据给定的索引 idx 返回一个样本

我们需要自己定义一个类,这个类继承Dataset类,并重写相关方法

需要调用系统路径,导入os模块,不要忘记了

import os

此时先定义一个自定义类,这个自定义类继承于Dataset类

class MyData(Dataset):

然后重写Dataset里面的__init__和__getitem__方法

class MyData(Dataset):

    # 初始化方法,当创建MyData对象时会被调用
    # root_dir: 数据集的根目录
    # label_dir: 自定义的类别目录,通常是某个类别的子目录
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir         # 存储根目录
        self.label_dir = label_dir       # 存储类别目录
        # 拼接根目录和类别目录,得到完整路径
        self.path = os.path.join(self.root_dir, self.label_dir)
        # 获取该类别目录下的所有文件/图片名,并存储到self.img_path列表中
        self.img_path = os.listdir(self.path)

    # 根据索引获取数据集中的单个样本
    # idx: 样本的索引
    def __getitem__(self, idx):
        # 获取索引对应的图片名
        img_name = self.img_path[idx]
        # 拼接完整的图片路径
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        # 打开图片并获取图片对象
        image = Image.open(img_item_path)
        # 使用类别目录作为标签
        label = self.label_dir
        # 返回图片对象和标签
        return image, label

记得在py文件统计目录下有相关文件,我是跟着小土堆的课程,所以就下载了dataset的数据集

最后定义好相关的变量和函数即可

root_dir = "dataset/train" 
ants_label_dir = "ants_image" 
bees_label_dir = "bees_image" 
ants_dataset = MyData(root_dir,ants_label_dir) 
bees_dataset = MyData(root_dir,bees_label_dir) 

这样就能得到ants_dataset和bees_dataset两个类别的数据集

2、查看数据

        查看数据集的话,需要用到DataLoader数据加载类来加载数据,调用Transform来对数据进行增强,通过使用 transforms.Compose 来组合多个转换操作(这是后面要学习的)

下面的代码可以看下,但不用深究,看个大致就行。

from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from torchvision import transforms

class MyData(Dataset):

    def __init__(self, root_dir, label_dir, transform=None):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_names = os.listdir(self.path)
        self.transform = transform
        self.label = self.label_dir.replace("_image", "")  # 假设标签是目录名的前缀

    def __getitem__(self, idx):
        img_name = self.img_names[idx]
        img_path = os.path.join(self.path, img_name)
        image = Image.open(img_path)
        
        if self.transform:
            image = self.transform(image)

        return image, self.label

    def __len__(self):
        return len(self.img_names)

root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"

# 定义数据增强
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整为适合模型输入的尺寸
    transforms.ToTensor(),  # 将 PIL Image 或 numpy.ndarray 转换为 torch.FloatTensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet 的均值和标准差
])

ants_dataset = MyData(root_dir, ants_label_dir, transform=transform)
bees_dataset = MyData(root_dir, bees_label_dir, transform=transform)

# 创建 DataLoader
batch_size = 4
train_loader = DataLoader(ants_dataset, batch_size=batch_size, shuffle=True)

# 查看数据集
for images, labels in train_loader:
    print("Images batch shape:", images.shape)
    print("Labels batch:", labels)

相关推荐

  1. PyTorch学习笔记(一)

    2024-04-30 20:34:02       27 阅读
  2. PyTorch学习笔记(三)

    2024-04-30 20:34:02       24 阅读
  3. PyTorch学习笔记(四)

    2024-04-30 20:34:02       21 阅读
  4. PyTorch学习笔记(六)

    2024-04-30 20:34:02       21 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-30 20:34:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-30 20:34:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-30 20:34:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-30 20:34:02       20 阅读

热门阅读

  1. MAKEFILE 从易到难

    2024-04-30 20:34:02       10 阅读
  2. 【华为OD机试】-(C卷+D卷)-2024最新真题目录

    2024-04-30 20:34:02       19 阅读
  3. 【个人博客搭建】(13)SqlSugar仓储实现

    2024-04-30 20:34:02       9 阅读
  4. 黑客眼中最简单的漏洞,弱口令暴力破解

    2024-04-30 20:34:02       12 阅读
  5. Spring中实现策略模式的几种方式

    2024-04-30 20:34:02       12 阅读
  6. Kafka集群搭建

    2024-04-30 20:34:02       11 阅读
  7. ndk编译android系统下运行的ffmpeg配置

    2024-04-30 20:34:02       10 阅读
  8. 使用通义千问,为汽车软件需求生成测试用例

    2024-04-30 20:34:02       13 阅读