Pytorch中nn.Linear使用方法

nn.Linear定义一个神经网络的线性层:

torch.nn.Linear(in_features,             # 输入的神经元个数
                out_features,            # 输出神经元个数
                bias=True                # 是否包含偏置
                )

nn.Linear其实就是对输入x_{n\times i}(n表示样本数量,i表示样本特征数)执行了一个线性变换,即:

Y_{n\times o } = X_{n\times i}W_{i\times o} + b

其中W矩阵是模型要学习的参数,b是1*O的向量偏置(即1行O列),n表示输入向量的个数(也可以理解为行数,比如一次输入100个样本数据,则n=100),i为每个样本的特征数,也可以理解为神经元的个数,O为输出样本的特征数,即输出神经元的个数。

from torch import nn
import torch

model = nn.Linear(3, 1)           # 每个样本输入特征数设置为3,输出特征数设置为1

input = torch.Tensor([2, 4, 6])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6
output = model(input)

print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")

print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():
    print(param)

输出:
nn.Linear 输出大小:torch.Size([1])    #输出结果表示只有一个样本输出,且该样本只有一个特征值1
tensor([-0.7842], grad_fn=<AddBackward0>)

查看模型参数W和b的值
Parameter containing:
tensor([[ 0.2353, -0.5686,  0.1759]], requires_grad=True)
Parameter containing:
tensor([-0.0356], requires_grad=True)

可以看到,模型有4个参数,分别为W的三个权重和b的一个偏置。手动计算验证结果:

0.2353*2 + (-0.5686)*4 + 0.1759*6 + (-0.0356) = -0.7839999999999997

假设有5个输入样本A、B、C、D、E(即batch_size为5),每个样本的特征数量为3,定义线性层时,输入特征为3,所以in_feature=3,想让下一层的神经元个数为5,所以out_feature=5,则模型参数为:

model = nn.Linear(in_features=3, out_features=5, bias=True)

此时参数矩阵W大小为3行3列

from torch import nn
import torch

model = nn.Linear(3, 5)           # 每个样本输入特征数设置为3,输出特征数设置为1

input = torch.Tensor([[2, 4, 6],[8,10,12],[14,16,18],[20,22,24],[26,28,30]])   # 给一个样本,该样本有3个特征,这3个特征分别是2、4、6

print(input)

output = model(input)

print("nn.Linear 输出大小:{}".format(output.shape))
print(output)
print("")

print("查看模型参数W和b的值")
# 查看模型参数
for param in model.parameters():
    print(param)

输出:
tensor([[ 2.,  4.,  6.],
        [ 8., 10., 12.],
        [14., 16., 18.],
        [20., 22., 24.],
        [26., 28., 30.]])
nn.Linear 输出大小:torch.Size([5, 5])
tensor([[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],
        [ -1.7251,  -4.4417,   5.9969,  -1.3649, -11.0200],
        [ -2.4886,  -7.9090,   9.3673,  -2.1692, -17.8163],
        [ -3.2522, -11.3763,  12.7376,  -2.9736, -24.6127],
        [ -4.0157, -14.8436,  16.1079,  -3.7779, -31.4090]],
       grad_fn=<AddmmBackward>)

查看模型参数W和b的值
Parameter containing:
tensor([[ 0.0714,  0.1456, -0.3443],
        [-0.5098, -0.0893,  0.0211],
        [ 0.3489, -0.2682,  0.4811],
        [ 0.0768, -0.3863,  0.1755],
        [-0.2832, -0.4325, -0.4170]], requires_grad=True)
Parameter containing:
tensor([ 0.3789,  0.2753,  0.1153, -0.2216,  0.5748], requires_grad=True)

第一个样本特征为[2、4、6],输出为[ -0.9616,  -0.9744,   2.6266,  -0.5605,  -4.2236],验证过程如下:

%w是模型参数矩阵
w = [[ 0.0714,  0.1456, -0.3443],
     [-0.5098, -0.0893,  0.0211],
     [ 0.3489, -0.2682,  0.4811],
     [ 0.0768, -0.3863,  0.1755],
     [-0.2832, -0.4325, -0.4170]];
x = [2,4,6];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];   %偏置向量
x*w'+b

输出:
 -0.9617   -0.9749    2.6269   -0.5602   -4.2236

第2个样本验证:

w = [[ 0.0714,  0.1456, -0.3443],
        [-0.5098, -0.0893,  0.0211],
        [ 0.3489, -0.2682,  0.4811],
        [ 0.0768, -0.3863,  0.1755],
        [-0.2832, -0.4325, -0.4170]];
x = [8,10,12];
b = [0.3789,  0.2753,  0.1153, -0.2216,  0.5748];
x*w'+b

输出:
-1.7255   -4.4429    5.9977   -1.3642  -11.0198

第3、4、5个样本的验证过程类似,从以上验证可以看出,所有样本共享参数矩阵W和偏置b

因为有5个样本,所以相当于依次进行了5次以上操作。

该操作重复了5次,每个样本重复一次:Y_{1\times 5}=X_{1\times 3}W_{3\times 5} + b_{1\times 5}

然后再将5个Y _{1 \times 5}叠加在一起,得到5*5的输出
 

相关推荐

  1. PyTorch2TorchText的基本使用方法

    2024-04-09 12:42:04       40 阅读
  2. Pytorch】在多进程使用 CUDA

    2024-04-09 12:42:04       53 阅读
  3. pytorch nn.ModuleList()使用说明

    2024-04-09 12:42:04       42 阅读
  4. 16、pytorch张量的8种创建方法

    2024-04-09 12:42:04       54 阅读
  5. Pytorch保存模型的两种方法

    2024-04-09 12:42:04       31 阅读
  6. PyTorch】torch.distributed()的含义和使用方法

    2024-04-09 12:42:04       30 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-04-09 12:42:04       99 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-04-09 12:42:04       107 阅读
  3. 在Django里面运行非项目文件

    2024-04-09 12:42:04       90 阅读
  4. Python语言-面向对象

    2024-04-09 12:42:04       98 阅读

热门阅读

  1. 【Linux】手搓shell

    2024-04-09 12:42:04       43 阅读
  2. python实现网络爬虫

    2024-04-09 12:42:04       33 阅读
  3. 从零开始精通RTSP之初识实时流协议

    2024-04-09 12:42:04       40 阅读
  4. 计算机网络---第三天

    2024-04-09 12:42:04       35 阅读
  5. SpringBoot通过token实现用户互踢功能

    2024-04-09 12:42:04       37 阅读
  6. C++:万能进制转换

    2024-04-09 12:42:04       41 阅读
  7. iOS MT19937随机数生成,结合AES-CBC加密算法实现。

    2024-04-09 12:42:04       28 阅读
  8. 头歌:共享单车之数据可视化

    2024-04-09 12:42:04       39 阅读
  9. 计算机网络-ICMP和ARP协议——沐雨先生

    2024-04-09 12:42:04       40 阅读