torch之从.datasets.CIFAR10解压出训练与测试图片 (附带网盘链接)

前言
从官网上下载的是长这个样子的
在这里插入图片描述
想看图片,咋办咧,看下面代码

import torch
import torchvision
import numpy as np
import os
import cv2
batch_size = 50

transform_predict = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
#-----#
# train 为True 则是解压出训练图片 为Fasle的时候则解压出测试图片
#------#
image_data = torchvision.datasets.CIFAR10(
    root='/home/netted/img_process_ml/temp', train=True, download=False, transform=transform_predict)
image_loader = torch.utils.data.DataLoader(
    image_data, batch_size, shuffle=True, num_workers=0)

path = '/home/netted/img_process_ml/temp/train'
os.makedirs(path,exist_ok=True)
for i in range(10):
    os.makedirs(f'{path}/{i}',exist_ok=True)


def format(image):
    image = image.clone().detach().cpu().squeeze(0)
    image = np.around(image.mul(255))
    image = np.uint8(image).transpose(1, 2, 0)
    return image


def data(image_loader):
    idx0 = 0
    idx1 = 0
    idx2 = 0
    idx3 = 0
    idx4 = 0
    idx5 = 0
    idx6 = 0
    idx7 = 0
    idx8 = 0
    idx9 = 0

    for i, (data, target) in enumerate(image_loader):

        for idx in range(len(data)):
            label = target[idx].item()
            image = format(data[idx])

            if label == 0:
                cv2.imwrite(f'{path}/{label}/plane_{idx0}.png',image)
                idx0 += 1

            if label == 1:
                cv2.imwrite(f'{path}/{label}/car_{idx1}.png', image)
                idx1 += 1

            if label == 2:
                cv2.imwrite(f'{path}/{label}/bird_{idx2}.png', image)
                idx2 += 1

            if label == 3:
                cv2.imwrite(f'{path}/{label}/cat_{idx3}.png', image)
                idx3 += 1

            if label == 4:
                cv2.imwrite(f'{path}/{label}/deer_{idx4}.png', image)
                idx4 += 1

            if label == 5:
                cv2.imwrite(f'{path}/{label}/dog_{idx5}.png', image)
                idx5 += 1

            if label == 6:
                cv2.imwrite(f'{path}/{label}/frog_{idx6}.png', image)
                idx6 += 1

            if label == 7:
                cv2.imwrite(f'{path}/{label}/horse_{idx7}.png', image)
                idx7 += 1

            if label == 8:
                cv2.imwrite(f'{path}/{label}/ship_{idx8}.png', image)
                idx8 += 1

            if label == 9:
                cv2.imwrite(f'{path}/{label}/truck_{idx9}.png', image)
                idx9 += 1

data(image_loader)

然后就解压出来了
在这里插入图片描述
在这里插入图片描述
当然可以自行调整将它们都合在一个文件夹里面,个人喜好

原包与自己生成好的链接如下:
链接:https://pan.baidu.com/s/1pkAFVjZ2f3ibPvMe4TtjOQ?pwd=noia
提取码:noia

欢迎大家点赞或收藏~
可以鼓励作者加快更新哟~

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-13 08:36:01       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-13 08:36:01       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-13 08:36:01       57 阅读
  4. Python语言-面向对象

    2024-07-13 08:36:01       68 阅读

热门阅读

  1. 如何实现一个二叉搜索树

    2024-07-13 08:36:01       26 阅读
  2. 小妙招使用sysctl hw.realmem查看实际物理内存@FreeBSD

    2024-07-13 08:36:01       18 阅读
  3. 网络设备安全

    2024-07-13 08:36:01       23 阅读
  4. sqlalchemy.orm中validates对两个字段进行联合校验

    2024-07-13 08:36:01       26 阅读
  5. Grafana

    Grafana

    2024-07-13 08:36:01      23 阅读
  6. VB 实例:掌握 Visual Basic 编程的精髓

    2024-07-13 08:36:01       20 阅读
  7. Spuer().__init__的意义

    2024-07-13 08:36:01       28 阅读
  8. 匿名函数与函数

    2024-07-13 08:36:01       28 阅读
  9. ios CCRuntime.m

    2024-07-13 08:36:01       23 阅读
  10. js项目生产环境中移除 console

    2024-07-13 08:36:01       24 阅读
  11. uniapp微信小程序授权登录实现

    2024-07-13 08:36:01       24 阅读
  12. 版本发布 | IvorySQL 3.3 发版

    2024-07-13 08:36:01       26 阅读
  13. 【分布式系统】Ceph对象存储系统之RGW接口

    2024-07-13 08:36:01       27 阅读
  14. 浅谈PostCSS

    2024-07-13 08:36:01       26 阅读