Pytorch中nn.Sequential()函数创建网络的几种方法

1. 创作灵感

在创建大型网络的时候,如果使用nn.Sequential()将几个有紧密联系的运算组成一个序列,可以使网络的结构更加清晰。

2.应用举例

为了记录nn.Sequential()的用法,搭建以下测试网络:

2.1 方法一

把网络分成3个Sequential序列,分别实现:第一个是卷积序列,第二个是铺平成一维的操作,第3个包含了两个线性层。


class TestNet(nn.Module):
    def __init__(self, in_channels, mid_channels):
        super(TestNet,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels,in_channels*2,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(in_channels*2),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels*2,in_channels*4,kernel_size=3,stride=2,padding=1),
            nn.BatchNorm2d(in_channels*4),
            nn.ReLU(inplace=True)
        )
        self.flat = nn.Flatten()
        self.linaer = nn.Sequential(
            nn.Linear(49*4, 64),
            nn.Linear(64, 10),
            nn.Linear(10, 1)
        )

    def forward(self,x):
        x = self.conv(x)
        print(x.shape)
        x = self.flat(x)
        print(x.shape)
        x = self.linaer(x)
        print(x.shape)
        return x

运行结果:

从维度上判断,网络符合预期。

2.2  方法二

第二种方法,在网络结构比较复杂且重复的单元比较多,为了自动化生成网络,通常会先定义一个列表,在列表中添加网络,再使用nn.Sequential()。

使用方法二所需要的代码如下:

class TestNet2(nn.Module):
    def __init__(self, in_channels):
        super(TestNet2,self).__init__()
        layer1 = []
        layer2 = []
        layer1.append(nn.Conv2d(in_channels,in_channels*2,kernel_size=3,stride=2,padding=1))
        layer1.append(nn.BatchNorm2d(in_channels*2))
        layer1.append(nn.ReLU(inplace=True))
        layer1.append(nn.Conv2d(in_channels*2,in_channels*4,kernel_size=3,stride=2,padding=1))
        layer1.append(nn.BatchNorm2d(in_channels*4))
        layer1.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layer1)
        self.flat = nn.Flatten()
        layer2.append(nn.Linear(49*4, 64))
        layer2.append(nn.Linear(64, 10))
        layer2.append(nn.Linear(10, 1))
        self.linaer = nn.Sequential(*layer2)

    def forward(self,x):
        x = self.conv(x)
        print(x.shape)
        x = self.flat(x)
        print(x.shape)
        x = self.linaer(x)
        print(x.shape)
        return x

运行结果如下:

与第一种方法的结果相同。

参考文献:

nn.Sequential、nn.ModuleList、nn.ModuleDict区别及使用技巧-CSDN博客

相关推荐

  1. 16、pytorch张量8创建方法

    2024-07-13 17:58:04       50 阅读
  2. QT 创建线程方法

    2024-07-13 17:58:04       50 阅读
  3. 03 创建图像窗口方式

    2024-07-13 17:58:04       49 阅读
  4. CSS常用清除浮动方法

    2024-07-13 17:58:04       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-13 17:58:04       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-13 17:58:04       71 阅读
  3. 在Django里面运行非项目文件

    2024-07-13 17:58:04       58 阅读
  4. Python语言-面向对象

    2024-07-13 17:58:04       69 阅读

热门阅读

  1. 算法提高第二章 线段树基础

    2024-07-13 17:58:04       18 阅读
  2. django orm中value和value_list以及转成list

    2024-07-13 17:58:04       21 阅读
  3. C# .Net Core Zip压缩包中文名乱码的解决方法

    2024-07-13 17:58:04       22 阅读
  4. live555关于RTSP协议交互流程

    2024-07-13 17:58:04       15 阅读
  5. EXPORT_SYMBOL

    2024-07-13 17:58:04       24 阅读
  6. 【车载开发系列】汽车开发常见概念理解

    2024-07-13 17:58:04       19 阅读
  7. 深入理解Spring Boot中的定时任务调度

    2024-07-13 17:58:04       17 阅读
  8. 大数据平台建设概要

    2024-07-13 17:58:04       21 阅读
  9. python文件

    2024-07-13 17:58:04       22 阅读