python 学习: 矩阵运算

摘要: 本贴通过例子描述 python 的矩阵运算.

1. 一般乘法 (mm 与 matmul)

代码:

    input_mat1 = torch.tensor([[1, 2, 3, 4],
            [1, 2, 2, 3]])

    input_mat2 = torch.tensor([[1, 2, 3, 3],
            [2, 1, 2, 3],
            [3, 1, 2, 2],
            [2, 3, 2, 3]])
    print("input_mat1: ", input_mat1)
    print("input_mat2: ", input_mat2)

    output_mat1 = torch.mm(input_mat1, input_mat2)
    print("torch.mm() test, output_mat1 = ", output_mat1)

    output_mat2 = torch.matmul(input_mat1, input_mat2)
    print("torch.matmul() test, output_mat2 = ", output_mat2)

结果:

input_mat1:  tensor([[1, 2, 3, 4],
        [1, 2, 2, 3]])
input_mat2:  tensor([[1, 2, 3, 3],
        [2, 1, 2, 3],
        [3, 1, 2, 2],
        [2, 3, 2, 3]])
torch.mm() test, output_mat1 =  tensor([[22, 19, 21, 27],
        [17, 15, 17, 22]])
torch.matmul() test, output_mat2 =  tensor([[22, 19, 21, 27],
        [17, 15, 17, 22]])

分析:

  • 利用 torch.tensor 来定义张量 (含矩阵);
  • torch.mm() 和 torch.matmul() 在这个例子里面的作用相同, 都是将 m × n m \times n m×n n × k n \times k n×k 的矩阵进行乘法, 获得 m × k m \times k m×k 的矩阵.

2. 逐点乘法 (乘法符号)

2.1 一维数组

代码:

    print("---torch.tensor star product test---")
    input_array1 = torch.tensor([1, 2, 3, 4])
    input_array2 = torch.tensor([4, 3, 2, 1])
    star_product = input_array1 * input_array2
    print("star_product: ", star_product)

结果:

---torch.tensor star product test---
star_product:  tensor([4, 6, 6, 4])

分析:

  • 不改变向量尺寸.

2.2 二维矩阵

代码:

    print("---element_wise_product  test---")
    input_matrix = np.array([[1, 2], [3, 4]])
    element_wise_product = input_matrix * input_matrix
    print("element_wise_product : ", element_wise_product)

结果:

---element_wise_product  test---
element_wise_product :  [[ 1  4]
 [ 9 16]]

分析:

  • 不改变矩阵尺寸.

3. 点乘 (dot)

3.1 一维数组

代码:

    print("---torch.tensor dot_product test---")
    input_array1 = torch.tensor([1, 2, 3, 4])
    input_array2 = torch.tensor([4, 3, 2, 1])
    dot_product = torch.dot(input_array1, input_array2)
    print("array dot_product: ", dot_product)

    print("---np.array dot_product test---")
    input_array1 = np.array([1, 2, 3, 4])
    input_array2 = np.array([4, 3, 2, 1])
    dot_product = np.dot(input_array1, input_array2)
    print("array dot_product: ", dot_product)

结果:

---torch.tensor dot_product test---
array dot_product:  tensor(20)
---np.array dot_product test---
array dot_product:  20

分析:

  • 相当于内积;
  • torch.tensor 和 np.array 都支持 dot;
  • torch 返回结果是一个 1 × 1 1 \times 1 1×1 tensor, np 返回的是一个标量.

3.2 矩阵

与 torch.matmul 相同.

4. 拼接 (cat)

4.1 向量

代码:

    print("---array test---")
    input_mat1 = torch.tensor([1, 2, 3, 4])
    print("input: ", input_mat1)

    horizontal_stack = torch.cat((input_mat1, input_mat1), 0)
    #vertical_stack = torch.cat((input_mat1, input_mat1), 1)

    print("horizontal_cat = ", horizontal_stack)
    #print("vertical_cat = ", vertical_stack)

结果:

---array test---
input:  tensor([1, 2, 3, 4])
horizontal_cat =  tensor([1, 2, 3, 4, 1, 2, 3, 4])

分析:

  • cat 的第 2 个参数指定方向, 0 表示水平, 1 表示垂直;
  • 向量支持水平叠加, 不支持垂直叠加, 否则向量变成二维矩阵, 不合适.

4.2 矩阵

代码:

    print("---matrix test---")
    input_mat1 = torch.tensor([[1, 2, 3, 4],
            [1, 2, 2, 3]])
    print("input: ", input_mat1)

    horizontal_cat = torch.cat((input_mat1, input_mat1), 0)
    vertical_cat = torch.cat((input_mat1, input_mat1), 1)

    print("horizontal_cat = ", horizontal_cat)
    print("vertical_cat = ", vertical_cat)
    print("shape: ", np.shape(input_mat1), np.shape(horizontal_cat), np.shape(vertical_cat))

结果:

---matrix test---
input:  tensor([[1, 2, 3, 4],
        [1, 2, 2, 3]])
horizontal_cat =  tensor([[1, 2, 3, 4],
        [1, 2, 2, 3],
        [1, 2, 3, 4],
        [1, 2, 2, 3]])
vertical_cat =  tensor([[1, 2, 3, 4, 1, 2, 3, 4],
        [1, 2, 2, 3, 1, 2, 2, 3]])
shape:  torch.Size([2, 4]) torch.Size([4, 4]) torch.Size([2, 8])       

分析:

  • 水平叠加两个 m × n m \times n m×n 矩阵, 将获得一个 2 m × n 2m \times n 2m×n 矩阵; 垂直叠加两个 m × n m \times n m×n 矩阵, 将获得一个 m × 2 n m \times 2n m×2n 矩阵.

4.3 张量

    print("---tensor test---")
    input_tensor1 = torch.tensor([[[1, 2, 3, 4], [1, 2, 2, 3]],
                                  [[5, 6, 7, 8], [8, 7, 6, 5]]])
    print("input: ", input_tensor1)

    horizontal_cat = torch.cat((input_tensor1, input_tensor1), 0)
    vertical_cat = torch.cat((input_tensor1, input_tensor1), 1)

    print("horizontal_cat = ", horizontal_cat)
    print("vertical_cat = ", vertical_cat)
    print("shape: ", np.shape(input_tensor1), np.shape(horizontal_cat), np.shape(vertical_cat))   

结果:

---tensor test---
input:  tensor([[[1, 2, 3, 4],
         [1, 2, 2, 3]],

        [[5, 6, 7, 8],
         [8, 7, 6, 5]]])
horizontal_cat =  tensor([[[1, 2, 3, 4],
         [1, 2, 2, 3]],

        [[5, 6, 7, 8],
         [8, 7, 6, 5]],

        [[1, 2, 3, 4],
         [1, 2, 2, 3]],

        [[5, 6, 7, 8],
         [8, 7, 6, 5]]])
vertical_cat =  tensor([[[1, 2, 3, 4],
         [1, 2, 2, 3],
         [1, 2, 3, 4],
         [1, 2, 2, 3]],

        [[5, 6, 7, 8],
         [8, 7, 6, 5],
         [5, 6, 7, 8],
         [8, 7, 6, 5]]])
shape:  torch.Size([2, 2, 4]) torch.Size([4, 2, 4]) torch.Size([2, 4, 4])

分析:

  • 水平叠加两个 m × n × k m \times n \times k m×n×k 张量, 将获得一个 2 m × n × k 2m \times n \times k 2m×n×k 张量; 垂直叠加两个 m × n m \times n m×n 矩阵, 将获得一个 m × 2 n × k m \times 2n \times k m×2n×k 矩阵;

5. 堆叠 (stack)

相关推荐

  1. python 学习: 矩阵运算

    2024-05-02 13:26:04       33 阅读
  2. 矩阵运算

    2024-05-02 13:26:04       55 阅读
  3. Jones矩阵符号运算

    2024-05-02 13:26:04       39 阅读
  4. 机器人--矩阵运算

    2024-05-02 13:26:04       30 阅读
  5. 37、matlab矩阵运算

    2024-05-02 13:26:04       24 阅读
  6. Python学习系列之三目运算

    2024-05-02 13:26:04       27 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-05-02 13:26:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-05-02 13:26:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-05-02 13:26:04       82 阅读
  4. Python语言-面向对象

    2024-05-02 13:26:04       91 阅读

热门阅读

  1. Android 修改Camera的最大变焦倍数

    2024-05-02 13:26:04       31 阅读
  2. 三生随记——午夜医院的诡异回声

    2024-05-02 13:26:04       26 阅读
  3. 美国国防部数据网格参考架构概述(下)

    2024-05-02 13:26:04       31 阅读
  4. 文件上传知识

    2024-05-02 13:26:04       30 阅读
  5. k8s面试29连问

    2024-05-02 13:26:04       25 阅读
  6. solidity(16)

    2024-05-02 13:26:04       34 阅读
  7. 【刷爆力扣之二叉树】107. 二叉树的层序遍历 II

    2024-05-02 13:26:04       35 阅读
  8. LeetCode //C - 44. Wildcard Matching

    2024-05-02 13:26:04       35 阅读