pytorch神经网络训练(AlexNet)

  • 导包
import os

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from PIL import Image

from torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([

    transforms.Resize((227, 227)),  # AlexNet的输入图像大小

    transforms.RandomHorizontalFlip(),  # 随机水平翻转

    transforms.RandomRotation(10),  # 随机旋转

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_features

alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:

    alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):

    model.eval()  # 将模型设置为评估模式

    correct = 0

    total = 0

    with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度

        for images, labels in data_loader:

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)

            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total

    return accuracy
  • 训练模型
num_epochs = 10

for epoch in range(num_epochs):

    alexnet_model.train()

    running_loss = 0.0

    for images, labels in data_loader:

        images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

相关推荐

  1. pytorch简单神经网络模型训练

    2024-06-14 12:02:04       33 阅读
  2. Pytorch---实现神经网络模型在GPU上进行训练

    2024-06-14 12:02:04       59 阅读
  3. pytorch写一个神经网络训练示例代码

    2024-06-14 12:02:04       39 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-14 12:02:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-14 12:02:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-06-14 12:02:04       82 阅读
  4. Python语言-面向对象

    2024-06-14 12:02:04       91 阅读

热门阅读

  1. 每天一个数据分析题(三百六十六)- 5WHY分析法

    2024-06-14 12:02:04       28 阅读
  2. docker-compose安装freeradius

    2024-06-14 12:02:04       31 阅读
  3. 【通信协议-RTCM】RTCM信息组

    2024-06-14 12:02:04       33 阅读
  4. 输出数据到excel中

    2024-06-14 12:02:04       24 阅读
  5. MySQL入门学习-聚合和分组.子查询.相关子查询

    2024-06-14 12:02:04       33 阅读
  6. 设计模式之策略模式

    2024-06-14 12:02:04       23 阅读
  7. Ubuntu16-18网卡配置

    2024-06-14 12:02:04       30 阅读
  8. 使用ffmpeg进行音频处理

    2024-06-14 12:02:04       27 阅读
  9. React native新架构组成

    2024-06-14 12:02:04       25 阅读
  10. hive split 特殊用法

    2024-06-14 12:02:04       28 阅读