【pytorch】nn.linear 中为什么是y=xA^T+b

我记得读教材的时候是y=Wx+b, 左乘矩阵W,这样才能表示线性变化。
但是pytorch中的nn.linear中,计算方式是y=xA^T+b,其中A是权重矩阵。
为什么右乘也能表示线性变化操作呢?因为pytorch中,照顾到输入是多个样本一起算的(第一个维度是多个样本数,所以输入默认是行向量),所以用y=xA^T+b,输出的y也是行向量。

在这里插入图片描述

我们的教材中默认输入是列向量的,而pytorch为了用户方便,输入当作列向量,维度为(batch, dim),每行是特征

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)

print(output.size())
>>> torch.Size([128, 30])
print(m.weight.shape)
>>>torch.Size([30, 20])  # 注意这里的权重维度

相关推荐

  1. 什么封装?什么要封装?

    2024-01-31 10:22:01       28 阅读
  2. React使用usePrevious的意义什么啥要用它

    2024-01-31 10:22:01       22 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-01-31 10:22:01       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-01-31 10:22:01       106 阅读
  3. 在Django里面运行非项目文件

    2024-01-31 10:22:01       87 阅读
  4. Python语言-面向对象

    2024-01-31 10:22:01       96 阅读

热门阅读

  1. SparkSQL之函数解析

    2024-01-31 10:22:01       49 阅读
  2. compose LazyColumn + items没有自动刷新问题

    2024-01-31 10:22:01       62 阅读
  3. CF97B Superset 题解 分治

    2024-01-31 10:22:01       58 阅读
  4. 【kafka-01数据保留时间设置】

    2024-01-31 10:22:01       64 阅读
  5. 华为HI模式与华为智选模式的左右互博

    2024-01-31 10:22:01       52 阅读