PyTorch-神经网络

神经网络,这也是深度学习的基石,所谓的深度学习,也可以理解为很深层的神经网络。说起这里,有一个小段子,神经网络曾经被打入了冷宫,因为SVM派的崛起,SVM不了解的同学可以去google一下,中文叫支持向量机,因为其有着完备的数学解释,并且之前神经网络运算复杂等问题,导致神经网络停步不前,这个时候任何以神经网络为题目的论文都发不出去,反向传播算法的鼻祖hinton为了解决这个问题,于是就想到了用深度学习为题目。

段子说完,接下来开始我们的简单神经网络。

Neural Network

其实简单的神经网络说起来很简单

通过图片就能很简答的看出来,其实每一层网络所做的就是 y=W×X+b,只不过W的维数由X和输出维书决定,比如X是10维向量,想要输出的维数,也就是中间层的神经元个数为20,那么W的维数就是20×10,b的维数就是20×1,这样输出的y的维数就为20。

中间层的维数可以自己设计,而最后一层输出的维数就是你的分类数目,比如我们等会儿要做的MNIST数据集是10个数字的分类,那么最后输出层的神经元就为10。

Code

有了前面两节的经验,这一节的代码就很简单了,数据的导入和之前一样

定义模型

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

class Neuralnetwork(nn.Module):

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):

        super(Neuralnetwork, self).__init__()

        self.layer1 = nn.Linear(in_dim, n_hidden_1)

        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)

        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):

        x = self.layer1(x)

        x = self.layer2(x)

        x = self.layer3(x)

        return x

model = Neuralnetwork(28*28, 300, 100, 10)

if torch.cuda.is_available():

    model = model.cuda()

criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=learning_rate)

上面定义了三层神经网络,输入是28×28,因为图片大小是28×28,中间两个隐藏层大小分别是300和100,最后是个10分类问题,所以输出层为10.

训练过程与之前完全一样,我就不再重复了,可以直接去github参看完整的代码

这是50次之后的输出结果,可以和上一节logistic回归比较一下

可以发现准确率大大提高,其实logistic回归可以看成简单的一层网络,从这里我们就可以看出为什么多层网络比单层网络的效果要好,这也是为什么深度学习要叫深度的原因。

相关推荐

  1. pytorch基础 神经网络构建

    2024-03-10 18:00:11       43 阅读
  2. pytorch神经网络入门代码

    2024-03-10 18:00:11       53 阅读
  3. 神经网络 | Pytorch神经网络ST-GNN

    2024-03-10 18:00:11       28 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-10 18:00:11       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-10 18:00:11       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-10 18:00:11       82 阅读
  4. Python语言-面向对象

    2024-03-10 18:00:11       91 阅读

热门阅读

  1. C++核心编程

    2024-03-10 18:00:11       40 阅读
  2. 力扣背包问题

    2024-03-10 18:00:11       37 阅读
  3. 【微软技术】介绍

    2024-03-10 18:00:11       44 阅读
  4. 面试题之——SpringBoot的好处?

    2024-03-10 18:00:11       43 阅读
  5. django 的 filter 使用技巧

    2024-03-10 18:00:11       42 阅读
  6. uniapp中使用LocalStorage实现本地存储缓存数据

    2024-03-10 18:00:11       45 阅读
  7. PokéLLMon 源码解析(四)

    2024-03-10 18:00:11       32 阅读