基础神经网络模型搭建

nn 包提供通用深度学习网络的模块集合,接收输入张量,计算输出张量,并保存权重。通常使用两种途径搭建 PyTorch 中的模型:nn.Sequential和 nn.Module。

nn.Sequential通过线性层有序组合搭建模型;nn.Module通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

目录

搭建线性层

通过nn.Sequential搭建

通过nn.Module搭建

获取模型摘要


搭建线性层

使用 nn 包搭建线性层。线性层接收 64*1000 维的输入,保存 1000*100 维的权重,并计算 64*100 维的输出。

import torch
input_tensor = torch.randn(64, 1000)
linear_layer = nn.Linear(1000, 100)
output = linear_layer(input_tensor)
print(input_tensor.size())
print(output.size())

通过nn.Sequential搭建

考虑一个两层的神经网络,四个节点作为输入,五个节点在隐藏层,一个节点作为输出

from torch import nn
model = nn.Sequential(
 nn.Linear(4, 5),
 nn.ReLU(),
 nn.Linear(5, 1),
)
print(model)

 

通过nn.Module搭建

在 PyTorch 中搭建模型的另一种方法是对 nn.Module 类进行子类化,通过__init__ 函数指定层,然后通过 forward 函数将层应用于输入,更灵活地构建自定义模型。

考虑两个卷积层和两个完全连接层搭建的模型:

import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
    def forward(self, x):
         pass

定义__init__ 函数和forward 函数

def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(1, 20, 5, 1)
    self.conv2 = nn.Conv2d(20, 50, 5, 1)
    self.fc1 = nn.Linear(4*4*50, 500)
    self.fc2 = nn.Linear(500, 10)
def forward(self, x):
    x = F.relu(self.conv1(x))
    x = F.max_pool2d(x, 2, 2)
    x = F.relu(self.conv2(x))
    x = F.max_pool2d(x, 2, 2) 
    x = x.view(-1, 4*4*50)
    x = F.relu(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)

重写两个类函数并打印模型

重写:子类中实现一个与父类的成员函数原型完全相同的函数

Net.__init__ = __init__
Net.forward = forward
model = Net()
print(model)

 

 查看模型位置

print(next(model.parameters()).device)

 

将模型移动至CUDA设备 

device = torch.device("cuda:0")
model.to(device)
print(next(model.parameters()).device)

 

 

获取模型摘要

借助torchsummary包查获取模型摘要

pip install torchsummary
from torchsummary import summary
summary(model, input_size=(1, 28, 28))

 

 

相关推荐

  1. Keras库神经网络

    2024-07-15 17:42:01       41 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-15 17:42:01       67 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-15 17:42:01       72 阅读
  3. 在Django里面运行非项目文件

    2024-07-15 17:42:01       58 阅读
  4. Python语言-面向对象

    2024-07-15 17:42:01       69 阅读

热门阅读

  1. Redis① —— Redis基础

    2024-07-15 17:42:01       20 阅读
  2. LeetCode 445.两数相加||

    2024-07-15 17:42:01       16 阅读
  3. openstack

    2024-07-15 17:42:01       19 阅读
  4. Memcached与Redis:缓存解决方案的较量与选择

    2024-07-15 17:42:01       21 阅读
  5. pandas读取超过16位的excle

    2024-07-15 17:42:01       16 阅读
  6. blinker库

    2024-07-15 17:42:01       19 阅读
  7. 如何使用断点续传方式上传大文件到阿里云 OSS

    2024-07-15 17:42:01       14 阅读
  8. Web打点技术的攻击手段和渗透测试工具

    2024-07-15 17:42:01       20 阅读
  9. 游戏开发面试题2

    2024-07-15 17:42:01       18 阅读
  10. linux系统调用

    2024-07-15 17:42:01       21 阅读
  11. git安装

    git安装

    2024-07-15 17:42:01      19 阅读