pytorch入门笔记二

torch.nn.Sequential

torch.nn.Sequential是一个容器,利用此容器可以快速简单的搭建一个简单的神经网络。这里以搭建一个三层神经网络为例。
首先该容器的参数分别是上一层到下一层的权重、激活函数,以此循环。
这里torch提供快速生成网络权重的方法:torch.nn.Linear(input,output)参数表示的分别是输入节点数和输出节点数。方法会根据输入输出节点数自动初始化一个权重矩阵。这里以输入节点数为10输出节点数为1为例

import torch
print(torch.nn.Linear(10,1).weight)#weight输出权重

输出结果为:

Parameter containing: tensor([[-0.2077, -0.2672, -0.1795, -0.1366,
0.1868, -0.0780, -0.2176, -0.0796,
0.1267, 0.0697]], requires_grad=True)

因此这里的权重可以直接使用Linear()方法随机生成。
同时还需要有激活函数。torch提供了非常多的激活函数。其中有常用的sigmoid()函数。
这里定义一个输入节点数为1,输出节点数为10,隐藏节点为100的神经网络为例

import torch
model=torch.nn.Sequential(
            torch.nn.Linear(28*28,100),#权重
            torch.nn.Sigmoid(),#激活函数
            torch.nn.Linear(100,10),#权重
            torch.nn.Softmax()#激活函数
        )

做完这些一个基本神经网络雏形出现了。
使用时只需要将输入值作为参数传入即可得到放回值

out=model(inputs)#inputs是输入内容,是tensor张量形式

在神经网络训练的时候为例更新权重,需要进行梯度传递。由于这里的权重封装,直接使用data.add_()方法很复杂,这里可以使用torch提供的优化器

torch.optim

torch.optim为我们提供了大量的优化器。不同的优化器有不同的特点。这里以常用优化器Adam(parameter,lr)为例,参数分别是是权重和初始学习率。
定义时将网络权重作为参数传入

opti=torch.optim.Adam(model.parameters(),lr=0.01)

在进行训练时,需要将梯度清零。目的是防止梯度累计,也就是在每一次训练前进行梯度清零。方法是

opti.zero_grad()

在完成一次训练输出时,将损失计算后进行梯度反传。就可以使用优化器方法step。进行权重更新

loss.backward()   #loss是计算出来的损失
opti.step()       #权重更新

以上就是基本的训练步骤。

model保存和加载

torch.nn.Sequential提供了一个方便快速定义神经网络的方式。
同时torch提供了一个保存和加载的方法。
对于一个训练好的神经网络可以使用方法save(model,‘name’)参数分别是 torch.nn.Sequential定义的容器和保存路径。该方法来保存权重和激活函数等信息。

torch.save(model,'model.data')  #将model保存到文件model.data中

同时也提供了加载函数。使用方法load(path)参数是文件路径,该方法可以直接从指定文件中读取保存的神经网络

model=torch.load('model.data')      #从文件model.data中读取一个神经网络

相关推荐

  1. pytorch入门笔记

    2024-02-18 23:52:03       49 阅读
  2. 入门 PyTorch

    2024-02-18 23:52:03       63 阅读
  3. PyTorch入门

    2024-02-18 23:52:03       48 阅读
  4. PytorchPytorch入门基础

    2024-02-18 23:52:03       36 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-02-18 23:52:03       91 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-02-18 23:52:03       97 阅读
  3. 在Django里面运行非项目文件

    2024-02-18 23:52:03       78 阅读
  4. Python语言-面向对象

    2024-02-18 23:52:03       88 阅读

热门阅读

  1. 关于数据库

    2024-02-18 23:52:03       65 阅读
  2. 面试浏览器框架八股文十问十答第一期

    2024-02-18 23:52:03       76 阅读
  3. 【SpringSecurity】2. 初学SpringSecurity

    2024-02-18 23:52:03       51 阅读
  4. C#系列-C#实现秒杀功能(14)

    2024-02-18 23:52:03       51 阅读
  5. python中函数的运用(1)

    2024-02-18 23:52:03       49 阅读
  6. STM32的三种下载方式

    2024-02-18 23:52:03       54 阅读
  7. 正则表达式速查表

    2024-02-18 23:52:03       45 阅读
  8. 工厂设计模式

    2024-02-18 23:52:03       42 阅读
  9. windows下Oracle 11g的安装和配置教程的详细步骤

    2024-02-18 23:52:03       56 阅读