pytorch学习笔记6

想要找一些官方的小工具数据集,可以进入pytorch官网,DOCS-》pytorch下拉至libraries,点击torchversion,调整版本至0.9.0就可以找到相应的一些数据集,训练集
ctrl+p可以看一个函数中需要设置哪些参数

下载数据集可以参考官方文档中的描述对数据集进行下载
在这里插入图片描述

import torchvision

train_set=torchvision.datasets.CIFAR10(root='./dataset',train=True,download=True)
#root表示数据集存放在那个位置./表示当前目录
#train如果为True,则从训练集创建数据集,否则从测试集创建数据集
#download如果为true,则从互联网下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
#这些参数都可以从官方文档中获取
test_set=torchvision.datasets.CIFAR10(root='./dataset',train=False,download=True)

下载慢的话可以用迅雷下载
在这里插入图片描述
使用这个地址
在这里插入图片描述
<PIL.Image.Image image mode=RGB size=32x32 at 0x179A0B20A90> 表示这个样本是一个 32x32 像素的 RGB 彩色图像,使用 PIL 库表示。
3 是这个图像对应的标签,即这个图像所代表的物体类别在 CIFAR-10 数据集中的索引(CIFAR-10 数据集共有 10 个类别,索引从 0 到 9)

torchvision.transforms.Compose 是一个方便的工具,用于将多个图像变换操作组合成一个单一的变换。在图像处理和深度学习模型的训练过程中,通常需要对图像进行一系列的预处理操作,例如裁剪、缩放、归一化等。Compose 允许你将这些操作串联起来,使其按顺序应用于每个输入图像。

import torchvision.transforms as transforms

dataset_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

如下进行

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform=torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set=torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=dataset_transform,download=True)
test_set=torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=dataset_transform,download=True)
print(test_set[0])

加上transform的流程对数据进行处理
这样批量将导入的数据进行处理

在图像处理中和深度学习模型的训练过程中,Normalize 变换的主要作用是对图像数据进行标准化处理。具体来说,Normalize 会调整图像的像素值,使其符合某个特定的分布。这通常有助于加速模型的训练过程并提高模型的性能。下面是 Normalize 的作用和原理的详细解释。

作用
加速模型收敛:通过将输入数据标准化,可以使模型的梯度更加稳定,避免某些特征对模型训练造成过大的影响,从而加速模型的收敛速度。
提高模型性能:标准化可以使不同特征的数据分布更加一致,有助于模型更好地理解和学习数据的特征,提高模型的性能。
防止梯度消失和梯度爆炸:标准化可以将输入数据的范围限制在一个较小的范围内,防止梯度在传播过程中变得过大或过小,稳定模型的训练过程。

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

对于红色通道,减去均值 0.485,再除以标准差 0.229。
对于绿色通道,减去均值 0.456,再除以标准差 0.224。
对于蓝色通道,减去均值 0.406,再除以标准差 0.225。

ctrl+/可以快速对多行进行注释

相关推荐

  1. PyTorch学习笔记(一)

    2024-06-14 00:50:02       24 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-14 00:50:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-14 00:50:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-14 00:50:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-14 00:50:02       18 阅读

热门阅读

  1. MySQL CHECK约束

    2024-06-14 00:50:02       7 阅读
  2. Android基础-运行时权限

    2024-06-14 00:50:02       4 阅读
  3. 013-Linux交换分区管理

    2024-06-14 00:50:02       6 阅读
  4. ios CCDelete.m

    2024-06-14 00:50:02       5 阅读
  5. 项目经验:别啥事都跟甲方讲

    2024-06-14 00:50:02       5 阅读
  6. 【设计模式之享元模式 -- C++】

    2024-06-14 00:50:02       6 阅读
  7. 文件已经删除但磁盘空间未释放

    2024-06-14 00:50:02       5 阅读
  8. TikTok限流封号要如何处理

    2024-06-14 00:50:02       7 阅读
  9. 关于自学编程的9点忠告

    2024-06-14 00:50:02       6 阅读
  10. vue中v-bind控制class和style

    2024-06-14 00:50:02       11 阅读
  11. 使用Python多线程批量压缩图片文件

    2024-06-14 00:50:02       6 阅读