优化器(一)torch.optim.SGD-随机梯度下降法

torch.optim.SGD-随机梯度下降法

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x


tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        running_loss += result_loss
    print(running_loss)


在这里插入图片描述

相关推荐

  1. 随机梯度下降算法

    2024-01-08 08:58:03       43 阅读
  2. 随机梯度下降(SGD)

    2024-01-08 08:58:03       32 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-08 08:58:03       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-08 08:58:03       106 阅读
  3. 在Django里面运行非项目文件

    2024-01-08 08:58:03       87 阅读
  4. Python语言-面向对象

    2024-01-08 08:58:03       96 阅读

热门阅读

  1. python笔记-自用

    2024-01-08 08:58:03       57 阅读
  2. React07-路由管理器react-router

    2024-01-08 08:58:03       49 阅读
  3. MySQL第一讲:MySQL知识体系详解(P6精通)

    2024-01-08 08:58:03       54 阅读
  4. 企业云安全能力建设的要点

    2024-01-08 08:58:03       54 阅读
  5. es6中箭头函数 原型

    2024-01-08 08:58:03       56 阅读
  6. shtml与html的区别

    2024-01-08 08:58:03       53 阅读
  7. oracle xml_data 包的使用

    2024-01-08 08:58:03       60 阅读
  8. 【面试】Redis基础知识

    2024-01-08 08:58:03       57 阅读