PyTorch中matmul函数的矩阵相乘原则和注意事项

PyTorch中matmul函数的矩阵相乘原则和注意事项

一、高维张量乘法规则

1. 选择乘法的维度: 选择最后两个维度进行乘
2. 维度匹配规则: 最后两个维度按照普通矩阵乘法计算
3. 广播机制:torch.matmul 函数支持广播机制,即在满足乘法维度匹配规则的前提下,可以通过扩展(广播)其他维度来实现矩阵相乘。这使得可以对不同形状的张量进行相乘。

4. 结果张量的形形状:

最后2维为矩阵乘法正常计算完成应该有的维度,而高维则以参与计算的两个矩阵中,维度更大的那个矩阵的维度为准。

为什么是这样,因为其实高维矩阵的乘法就是分别从高维中选取对应位置的一对矩阵(普通矩阵)相乘把高维都遍历完了,整个高维矩阵乘法也就完成了。

那么原参与计算的矩阵形状,高维有多大,计算结果就应该有多大(因为高维只遍历)

⭐⭐⭐一言以蔽之:除最后两维外,每一维的分量数必须对应相等(每个分量对应相乘) 或 有一方为1(broadcast-广播机制)

二、二维矩阵相乘

线性代数基本知识,就不多讲了

import torch

# 创建两个二维矩阵
A = torch.tensor([[1, 2],
                  [3, 4]])
B = torch.tensor([[5, 6],
                  [7, 8]])

# 使用 matmul 进行二维矩阵相乘
C = torch.matmul(A, B)
print("二维矩阵相乘结果:")
print(C)

输出结果:

二维矩阵相乘结果:
tensor([[19, 22],
        [43, 50]])

三、三维张量相乘

对于两个三维张量 A 和 B,我们可以选择其中的最后两个维度进行相乘。

import torch

# 创建两个三维张量
A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)

# 使用 matmul 进行三维张量相乘
C = torch.matmul(A, B)
print("三维张量相乘结果的形状:")
print(C.shape)

输出结果:

三维张量相乘结果的形状:
torch.Size([2, 3, 5])

在这个示例中,张量 A 的形状是 [2, 3, 4],张量 B 的形状是 [2, 4, 5],我们对最后两个维度进行了矩阵相乘,得到的结果张量 C 的形状是 [2, 3, 5]

四、三维张量广播相乘示例

import torch

# 创建两个可以广播的张量
a = torch.randn(2, 3, 4)
b = torch.randn(4, 5)

# 使用 torch.matmul 进行广播机制的矩阵乘法
result = torch.matmul(a, b)
print("广播机制下的矩阵乘法结果的形状:")
print(result.sha

输出

广播机制下的矩阵乘法结果的形状:
torch.Size([2, 3, 5])

五、高维张量相乘

最后,我们考虑更高维度的情况,例如四维张量。对于四维张量 A 和 B,我们选择最后两个维度进行相乘。

# 创建两个四维张量
A = torch.randn(2, 3, 4, 5)
B = torch.randn(2, 3, 5, 6)

# 使用 matmul 进行四维张量相乘
C = torch.matmul(A, B)
print("四维张量相乘结果的形状:")
print(C.shape)

输出结果:

四维张量相乘结果的形状:
torch.Size([2, 3, 4, 6])

在这个示例中,张量 A 的形状是 [2, 3, 4, 5],张量 B 的形状是 [2, 3, 5, 6],我们对最后两个维度进行了矩阵相乘,得到的结果张量 C 的形状是 [2, 3, 4, 6]

相关推荐

  1. PyTorchmatmul函数矩阵相乘原则注意事项

    2024-07-09 19:54:05       31 阅读
  2. PyTorch学习(12):PyTorch张量相乘(torch.matmul

    2024-07-09 19:54:05       25 阅读
  3. Pytorch之DatasetDataLoader注意事项

    2024-07-09 19:54:05       37 阅读
  4. free函数用法注意事项

    2024-07-09 19:54:05       25 阅读
  5. 【GoLang基础】函数注意事项细节讨论

    2024-07-09 19:54:05       28 阅读
  6. [hive] sqldistinct用法注意事项

    2024-07-09 19:54:05       59 阅读
  7. CUDA | 核函数编写注意事项

    2024-07-09 19:54:05       33 阅读
  8. Pytorch forward 函数内部原理

    2024-07-09 19:54:05       36 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-09 19:54:05       50 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-09 19:54:05       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-09 19:54:05       43 阅读
  4. Python语言-面向对象

    2024-07-09 19:54:05       54 阅读

热门阅读

  1. 使用 Conda 管理 Python 环境的详细指南

    2024-07-09 19:54:05       22 阅读
  2. 从零开始!Jupyter Notebook的安装教程

    2024-07-09 19:54:05       23 阅读
  3. UI 自动化分布式测试 -- Docker Selenium Grid

    2024-07-09 19:54:05       18 阅读
  4. Spring Cloud Gateway报sun.misc.Unsafe.park(Native Method)

    2024-07-09 19:54:05       30 阅读
  5. Spring Cloud Gateway如何匹配某路径并进行路由转发

    2024-07-09 19:54:05       24 阅读
  6. 裸金属服务器与物理服务器之间的区别

    2024-07-09 19:54:05       18 阅读
  7. 精准注入:掌握Conda包依赖注入的艺术

    2024-07-09 19:54:05       26 阅读