pytorch-nn.Module

1. nn.Module

是所有nn.类的父类,其中包括nn.Linear nn.BatchNorm2d nn.Conv2d nn.ReLU nn.Sigmoid等等

2. nn.Sequential容器

如下图,定义一个net网络,将所有继承自nn.Module的子类定义的网络层加入到了nn.Sequential容器中,与一层一层的单独调用模块组成序列相比,nn.Sequential() 可以允许将整个容器视为单个模块(即相当于把多个模块封装成一个模块),forward()方法接收输入之后,nn.Sequential()按照内部模块的顺序自动依次计算并输出结果。因此可以利用nn.Sequential()搭建模型架构

在这里插入图片描述

3. 网络参数parameters

如下图,通过net.parameters()可以获取到net的参数,转换成list后,通过index访问第几个参数,比如:图中的list(net.named_parameters())[0]就可以获取到网络的第一个参数,也就是网络第一层的w参数。
通过list(net.named_parameters()).items()获取到所有网络层,从获取结果可以看到,每一层都被pytorch命名了,比如:‘0.weight’,‘0.bias’,即第一层网络的weight和bias.
在这里插入图片描述

4. Modules内部管理

与根节点相连的直系亲属叫children,其他再与children连接的节点都叫modules
如下图,nn.Sequential是Net的children,其他的是modules,包括nn.ReLU、nn.Linear、BasicNet
在这里插入图片描述
从下面这张截图可以看出,Net本身和Children也都是modules
在这里插入图片描述

5. checkpoint

为了防止train过程意外停止,需从头train的问题,train过程需要定期保持checkpoint,而一旦出现train意外停止,就可以从最后一次checkpoint接着训练。
torch.save保存checkpoint
torch.load_state_dict(torch.load(‘chpt.md’))用于load checkpoint
在这里插入图片描述

6. train/test状态切换

所有nn.类都继承自nn.Module,因此在切换train和test状态时,只需要调用一次net.train()或net.eval即可,而不需要那些train和test(dropout)行为不一致的类每个单独去切换.
在这里插入图片描述

6. 实现自己的网络层

6.1 实现打平操作

全连接层层需要打平输入,打平操作通过.view方法实现,由于Flatten继承自nn.Module,因此可以直接放到nn.Sequential中。
在这里插入图片描述

6.2 实现自己的线性层

通过net.parameters()可以将网络参数加到优化器中。
在这里插入图片描述
troch.tensor是不会自动加到nn.parameters中,因此需要使用nn.Parameter将tensor加到nn.parameters,从而才能加到SGD等优化器中。

在这里插入图片描述

7. 代码

import  torch
from    torch import nn
from    torch import optim

class MyLinear(nn.Module):

    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()

        # requires_grad = True
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x


class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)

class TestNet(nn.Module):

    def __init__(self):
        super(TestNet, self).__init__()

        self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
                                 nn.MaxPool2d(2, 2),
                                 Flatten(),
                                 nn.Linear(1*14*14, 10))

    def forward(self, x):
        return self.net(x)


class BasicNet(nn.Module):

    def __init__(self):
        super(BasicNet, self).__init__()

        self.net = nn.Linear(4, 3)

    def forward(self, x):
        return self.net(x)

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.net = nn.Sequential(BasicNet(),
                                 nn.ReLU(),
                                 nn.Linear(3, 2))

    def forward(self, x):
        return self.net(x)





def main():
    device = torch.device('cuda')
    net = Net()
    net.to(device)

    net.train()

    net.eval()

    # net.load_state_dict(torch.load('ckpt.mdl'))
    #
    #
    # torch.save(net.state_dict(), 'ckpt.mdl')

    for name, t in net.named_parameters():
        print('parameters:', name, t.shape)

    for name, m in net.named_children():
        print('children:', name, m)


    for name, m in net.named_modules():
        print('modules:', name, m)

if __name__ == '__main__':
    main()

相关推荐

  1. <span style='color:red;'>Pytorch</span>

    Pytorch

    2024-06-08 11:20:06      51 阅读
  2. PyTorch

    2024-06-08 11:20:06       52 阅读
  3. PytorchPytorch入门基础

    2024-06-08 11:20:06       37 阅读
  4. 入门 PyTorch

    2024-06-08 11:20:06       63 阅读
  5. PyTorch】概述

    2024-06-08 11:20:06       53 阅读
  6. pytorch RNN

    2024-06-08 11:20:06       43 阅读
  7. Python:PyTorch

    2024-06-08 11:20:06       51 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-08 11:20:06       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-08 11:20:06       100 阅读
  3. 在Django里面运行非项目文件

    2024-06-08 11:20:06       82 阅读
  4. Python语言-面向对象

    2024-06-08 11:20:06       91 阅读

热门阅读

  1. mac前端com+f与com+shift+f查找文章内容

    2024-06-08 11:20:06       25 阅读
  2. 图论方法学习

    2024-06-08 11:20:06       30 阅读
  3. Tomcat 启动闪退问题解决方法

    2024-06-08 11:20:06       25 阅读
  4. tomcat 启动闪退问题解决方法

    2024-06-08 11:20:06       23 阅读
  5. Mysql 快速入门指南

    2024-06-08 11:20:06       23 阅读
  6. Linux关闭SSH延迟连接和超时自动注销

    2024-06-08 11:20:06       28 阅读
  7. 力扣76.最小覆盖子串

    2024-06-08 11:20:06       29 阅读
  8. 物联网的应用——医疗健康

    2024-06-08 11:20:06       27 阅读
  9. 【Redis】Redis集群脑裂的原因及解决方案

    2024-06-08 11:20:06       34 阅读