【深度学习】PyTorch:MNIST手写数字分类

efficientnet_b0的迁移学习


import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import models
from torchvision.datasets import MNIST

# Training hyperparameters.
batch_size = 240        # samples per mini-batch
learning_rate = 1e-3    # step size handed to the optimizer
num_epochs = 10         # full passes over the training set

# Preprocessing applied to every MNIST image: resize to 224 pixels,
# replicate the single grey channel into three channels, convert to a
# tensor and normalise with the ImageNet per-channel statistics.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Load the MNIST train/test splits (downloaded into ./data when missing);
# every image goes through the `transform` pipeline.
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = MNIST(root='./data', train=False, transform=transform, download=True)

# Never request more loader workers than the machine has CPU cores: a fixed
# count of 32 oversubscribes smaller machines and triggers a PyTorch warning.
# os.cpu_count() may return None, hence the fallback to 1.
# NOTE(review): with num_workers > 0 on platforms that start workers via
# "spawn" (e.g. Windows), the PyTorch docs recommend keeping the script body
# under an `if __name__ == '__main__':` guard -- this script has none.
num_workers = min(32, os.cpu_count() or 1)

# Batch the datasets; only the training split is shuffled.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load EfficientNet-B0 with ImageNet-pretrained weights.
# `pretrained=True` is deprecated since torchvision 0.13 in favour of the
# `weights=` enum; use the enum when this torchvision provides it and fall
# back to the legacy flag on older releases.
_b0_weights = getattr(models, 'EfficientNet_B0_Weights', None)
if _b0_weights is not None:
    model = models.efficientnet_b0(weights=_b0_weights.IMAGENET1K_V1)
else:
    model = models.efficientnet_b0(pretrained=True)

# Swap the final classifier layer for a 10-way head (MNIST has 10 classes),
# keeping the input width of the layer it replaces.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
model.to(device)

# Loss function and optimizer: cross-entropy over the class logits, and Adam
# updating every model parameter with the configured learning rate.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Per-epoch history, kept for the plots drawn after training.
train_losses = []
test_accuracies = []

# Train for num_epochs; after each epoch, measure accuracy on the test set.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # Read the scalar loss once and reuse it for both the running sum
        # and the progress line.
        batch_loss = loss.item()
        running_loss += batch_loss
        # The leading "\r" only rewrites the line in place when print() does
        # not append a newline, so suppress it and flush explicitly.
        print(
            f"\rEpoch {epoch + 1}/{num_epochs}, Batch {batch_idx + 1}/{len(train_loader)}, Loss: {batch_loss:.4f}",
            end='',
            flush=True,
        )
    # Terminate the in-place progress line before the epoch summary.
    print()

    # Mean training loss over the batches of this epoch.
    avg_loss = running_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Evaluate on the test set with gradients disabled.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # Predicted class = index of the largest logit per sample.
            predicted = output.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    accuracy = 100 * correct / total
    test_accuracies.append(accuracy)
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

# Persist only the model's state dict (not the whole module object) to disk.
torch.save(obj=model.state_dict(), f='mnist_efficientnetb0.pth')

# Plot the per-epoch training loss (left) and test accuracy (right)
# side by side in one figure.
plt.figure(figsize=(12, 5))

# (series, title/legend label, y-axis label) for each panel, left to right.
panels = [
    (train_losses, 'Training Loss', 'Loss'),
    (test_accuracies, 'Test Accuracy', 'Accuracy (%)'),
]
for position, (series, caption, y_label) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.plot(series, label=caption)
    plt.title(caption)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend()

plt.show()

训练10轮后,测试准确率达到约99.6%:

Epoch 10/10, Loss: 0.0087, Test Accuracy: 99.60%

(图:训练损失与测试准确率随训练轮数变化的曲线)

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-25 05:20:10       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-25 05:20:10       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-25 05:20:10       82 阅读
  4. Python语言-面向对象

    2024-03-25 05:20:10       91 阅读

热门阅读

  1. [AIGC] OkHttp:轻松实现网络请求

    2024-03-25 05:20:10       43 阅读
  2. 智能写作利器ChatGPT:提升论文写作效率

    2024-03-25 05:20:10       49 阅读
  3. 数据分析-Pandas分类数据的比较如何避坑

    2024-03-25 05:20:10       44 阅读
  4. 在Flink SQL中使用watermark进阶功能

    2024-03-25 05:20:10       43 阅读
  5. 使用docker搭建dockge

    2024-03-25 05:20:10       40 阅读
  6. 自学python指导教程

    2024-03-25 05:20:10       35 阅读
  7. Nodejs版本管理工具nvm

    2024-03-25 05:20:10       42 阅读
  8. Chinese-LLaMA-Alpaca-2模型量化部署&测试

    2024-03-25 05:20:10       34 阅读
  9. 【Python】复习12:标准库与第三方库

    2024-03-25 05:20:10       40 阅读
  10. Postgresql中常见的执行计划解释

    2024-03-25 05:20:10       37 阅读
  11. vue3模板引用介绍

    2024-03-25 05:20:10       49 阅读
  12. 数据结构面试常见问题

    2024-03-25 05:20:10       43 阅读