【python深度学习】——torch.einsum|torch.bmm

【python深度学习】——torch.einsum|torch.bmm

1. 基本用法与示例

基本用法:

torch.einsum(equation, *operands)
  • equation: 一个字符串,定义了张量操作的模式。
    使用逗号来分隔输入张量的索引,然后是一个箭头(->),接着是输出张量的索引
  • operands: 要操作的张量。
    示例代码:
import torch
A = torch.randn(2, 3)

B = torch.einsum('ij->ji', A)
# 等价于 B = A.transpose(0, 1)

C = torch.einsum('ik,kj->ij', A, B)
# 等价于 C = torch.matmul(A, B)

a = torch.randn(3)
b = torch.randn(3)
c = torch.einsum('i,i->', a, b)
# 等价于 c = torch.dot(a, b)


A = torch.randn(5, 2, 3)
B = torch.randn(5, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
# 等价于 C = torch.bmm(A, B)


a = torch.randn(3)
b = torch.randn(4)
c = torch.einsum('i,j->ij', a, b)
# 结果是一个3x4的矩阵,等价于 c = a.unsqueeze(1) * b.unsqueeze(0)


A = torch.randn(3, 3)
trace = torch.einsum('ii->', A)
# 等价于 trace = torch.trace(A)


2. torch.bmm

全称为: batch matrix-matrix product, 批量矩阵乘法, 适用于三维张量,其中第一维表示批量大小,第二维和第三维表示矩阵的行和列

torch.bmm(input, mat2, *, out=None) -> Tensor
  • input: 一个形状为 (b, n, m) 的三维张量,表示一批矩阵。
  • mat2: 一个形状为 (b, m, p) 的三维张量,表示另一批矩阵。
  • out (可选): 存储输出结果的张量。
    输出是一个形状为 (b, n, p) 的张量,其中每个矩阵是对应批次的矩阵乘法结果。

例如:

import torch

# 定义两个形状为 (b, n, m) 和 (b, m, p) 的三维张量
batch_size = 10
n, m, p = 3, 4, 5

A = torch.randn(batch_size, n, m)
B = torch.randn(batch_size, m, p)

# 进行批量矩阵乘法
C = torch.bmm(A, B)

print(C.shape)  # 输出: torch.Size([10, 3, 5])

再具体的:

A = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
B = torch.tensor([[[9, 10], [11, 12]], [[13, 14], [15, 16]]])

# A.shape = (2, 2, 2)
# B.shape = (2, 2, 2)
C = torch.bmm(A, B)

print(C)
# 输出:
# tensor([[[ 31,  34],
#          [ 73,  80]],
#
#         [[155, 166],
#          [211, 226]]])

其数学计算为:
请添加图片描述

相关推荐

  1. Python实现深度学习

    2024-06-08 19:16:05       10 阅读
  2. Python深度学习代码简介

    2024-06-08 19:16:05       11 阅读
  3. 深度学习常用指令(Anaconda、Python

    2024-06-08 19:16:05       42 阅读
  4. python深度学习搭环境技巧

    2024-06-08 19:16:05       29 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-06-08 19:16:05       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-08 19:16:05       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-08 19:16:05       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-08 19:16:05       20 阅读

热门阅读

  1. git 下载openNeuro大文件

    2024-06-08 19:16:05       11 阅读
  2. 哈希表(Hash table)

    2024-06-08 19:16:05       8 阅读
  3. C++协程

    2024-06-08 19:16:05       9 阅读
  4. 【vuejs】vm.$set() 的原理解析和方法以及应用场景

    2024-06-08 19:16:05       8 阅读
  5. 设计模式 —— 装饰器模式

    2024-06-08 19:16:05       8 阅读
  6. 深度学习-10-测试

    2024-06-08 19:16:05       8 阅读
  7. git 怎么让一个文件不提交

    2024-06-08 19:16:05       8 阅读
  8. 算法题 — 可可喜欢吃香蕉(二分查找法)

    2024-06-08 19:16:05       10 阅读
  9. PostgreSQL的内存结构

    2024-06-08 19:16:05       9 阅读