Python PyTorch 获取 MNIST 数据

1 PyTorch 获取 MNIST 数据

import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_get():
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()

2 PyTorch 保存 MNIST 数据

import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_save(mnist_path):
    print(torch.__version__)
    # 定义数据转换
    transform = transforms.Compose([
        transforms.ToTensor(),  # 将图像转换为张量
        transforms.Normalize((0.5,), (0.5,))  # 归一化图像数据
    ])
    # 获取数据
    train_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    # 训练数据
    train_image = train_data.data.numpy()
    train_label = train_data.targets.numpy()
    # 测试数据
    test_image = test_data.data.numpy()
    test_label = test_data.targets.numpy()
    np.savez(mnist_path, train_data=train_image, train_label=train_label, test_data=test_image, test_label=test_label)

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_save(mnist_path)

3 PyTorch 显示 MNIST 数据

import torch
import numpy as np
import matplotlib.pyplot as plt # type: ignore
from torchvision import datasets, transforms

def mnist_show(mnist_path):
    data = np.load(mnist_path)
    image = data['train_data'][0:100]
    label = data['train_label'].reshape(-1, )
    plt.figure(figsize = (10, 10))
    for i in range(100):
        print('%f, %f' % (i, label[i]))
        plt.subplot(10, 10, i + 1)
        plt.imshow(image[i])
    plt.show()

mnist_path = 'C:\\Users\\Hyacinth\\Desktop\\mnist.npz'
mnist_show(mnist_path)

在这里插入图片描述

相关推荐

  1. MINIST数据集&手写数字识别

    2024-04-25 05:54:08       30 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-25 05:54:08       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-25 05:54:08       100 阅读
  3. 在Django里面运行非项目文件

    2024-04-25 05:54:08       82 阅读
  4. Python语言-面向对象

    2024-04-25 05:54:08       91 阅读

热门阅读

  1. python安装第三方包

    2024-04-25 05:54:08       34 阅读
  2. 搭建最新tensorflow 与pytorch环境

    2024-04-25 05:54:08       37 阅读
  3. 本地wsl的Ubuntu安装docker,不使用docker桌面版

    2024-04-25 05:54:08       37 阅读
  4. spring的扩展接口

    2024-04-25 05:54:08       36 阅读
  5. python实现爬虫例子2

    2024-04-25 05:54:08       26 阅读
  6. 十八、QGIS的作用和下载

    2024-04-25 05:54:08       38 阅读
  7. pandas保存dict字段再读取成DataFrame

    2024-04-25 05:54:08       32 阅读
  8. springboot针对thymeleaf的使用总结

    2024-04-25 05:54:08       35 阅读