神经网络搭建(1)----nn.Sequential

神经网络模型构建

采用CIFAR10中的数据,并对其进行简单的分类。以下图为例

输入:3通道,32×32 ( 经过一个5×5的卷积)
→ 变成32通道,32×32的图像 (经过2×2的最大池化)
→ 变成32通道,16×16的图像 ( 经过一个5×5的卷积)
→ 变成32通道,16×16的图像 (经过2×2的最大池化)
→ 变成32通道,8×8的图像 ( 经过一个5×5的卷积)
→ 变成64通道,8×8的图像(经过2×2的最大池化)
→ 变成64通道,4×4的图像
→ 把图像展平(Flatten)(变成64通道,1×1024 (64×4×4) 的图像)
→先通过一个线性层Linear(in_features=64 * 4 * 4, out_features=64)
→再经过一个线性层Linear(64, 10)→ 得到最终图像

以上就是一个神经网络模型的构建

神经网络中的参数设计及计算

卷积层的参数设计(以第一个卷积层conv2为例)

  • 输入图像为3通道,输出图像为32通道,故:in_channels=3,  out_channels=32
  • 卷积核尺寸为 kernel_size=5  (5*5)

  • 图像经过卷积层conv2前后的尺寸均为32×32,根据公式:

可得:
 

即:

若stride[0]或stride[1]设置为2,那么上面的padding也会随之扩展为-个很大的数,这很不合理。所以这里设置: stride[0] = stride[1] = 1,由此可得: padding[0] = padding[1]= 2

其余卷积层的参数设计及计算方法均同上

最大池化操作的参数设计(以第一个池化操作maxpool为例)

根据计算神经网络 torch.nn---Pooling layers(nn.MaxPool2d)-CSDN博客

可以得到卷积核尺寸为 kernel_size=2

其他参数为默认值

线性层的参数设计

  • 通过三次卷积和最大池化操作后,图像尺寸变为64通道4×4。之后使用Flatten()函数将图像展成一列,此时图像尺寸变为:1×(64×4×4),即1×1024

  • 因此,之后通过第一个线性层,(in_features=64 * 4 * 4, out_features=64)

  • 第二个线性层,(in_features=64 , out_features=10)

程序代码

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, Flatten
from torch.utils.tensorboard import SummaryWriter

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(in_features=64 * 4 * 4, out_features=64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

tudui = Tudui()
print(tudui)

input = torch.ones(64, 3, 32, 32)
output = tudui(input)
print(output.shape)

# 可视化神经网络 
writer = SummaryWriter('logs')
writer.add_graph(tudui, input)
writer.close()

这样就可以清晰地看到神经网络的相关参数

相关推荐

  1. Keras库神经网络

    2024-06-07 02:02:02       20 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-07 02:02:02       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-07 02:02:02       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-07 02:02:02       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-07 02:02:02       20 阅读

热门阅读

  1. 中介子方程

    2024-06-07 02:02:02       7 阅读
  2. LE Audio音频广播新功能Auracast介绍

    2024-06-07 02:02:02       10 阅读
  3. Git | SSH 密钥连接到 GitHub

    2024-06-07 02:02:02       11 阅读
  4. lsof 命令

    2024-06-07 02:02:02       8 阅读
  5. Nacos控制台服务安装

    2024-06-07 02:02:02       12 阅读
  6. Meta Llama 3 大型语言模型的超参数

    2024-06-07 02:02:02       10 阅读
  7. 源代码先转字节码,再转机器码的过程

    2024-06-07 02:02:02       10 阅读
  8. 【redis】set和zset常用命令

    2024-06-07 02:02:02       12 阅读