Notes on PyTorch's flatten and transpose functions

The flatten function flattens a tensor by merging a range of its dimensions into a single one.

a.flatten(m) merges tensor a's dimensions from dimension m (dimensions are numbered from 0) through the last dimension into one.
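Before walking through the full example below, the shape rule can be summarized on its own: the merged trailing dimension is simply the product of the collapsed sizes. A minimal shape-only sketch (my own, not from the original post):

```python
import torch

# A 4-D tensor, matching the shape used in the example below
a = torch.rand(2, 3, 2, 3)

# flatten(m) merges dimensions m through the last into a single one,
# so the new trailing size is the product of the collapsed sizes.
print(a.flatten(0).shape)  # torch.Size([36])        (2*3*2*3)
print(a.flatten(1).shape)  # torch.Size([2, 18])     (3*2*3)
print(a.flatten(2).shape)  # torch.Size([2, 3, 6])   (2*3)
```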

import torch

a=torch.rand(2,3,2,3)

print(a)

x = a.flatten(0)
print(x)
print(x.size())

y = a.flatten(1)
print(y)
print(y.size())

z = a.flatten(2)
print(z)
print(z.size())


# a.flatten() does not take only one argument; according to the official docs
# it accepts two:
#   start_dim (int) – the first dim to flatten
#   end_dim (int) – the last dim to flatten
# Merge dimensions 0 and 1 of a:

u = a.flatten(0,1)
print(u)
print(u.size())

Output:
tensor([[[[0.9807, 0.8278, 0.2853],
          [0.2290, 0.3709, 0.6642]],

         [[0.2521, 0.0556, 0.3562],
          [0.3926, 0.9639, 0.3037]],

         [[0.9804, 0.7069, 0.8673],
          [0.0434, 0.5438, 0.7231]]],


        [[[0.7031, 0.2287, 0.0640],
          [0.5223, 0.0660, 0.5081]],

         [[0.2562, 0.4229, 0.8700],
          [0.1164, 0.5058, 0.2986]],

         [[0.6062, 0.2247, 0.4474],
          [0.2376, 0.5606, 0.5911]]]])
tensor([0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642, 0.2521, 0.0556, 0.3562,
        0.3926, 0.9639, 0.3037, 0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231,
        0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081, 0.2562, 0.4229, 0.8700,
        0.1164, 0.5058, 0.2986, 0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911])
torch.Size([36])

tensor([[0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642, 0.2521, 0.0556, 0.3562,
         0.3926, 0.9639, 0.3037, 0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231],
        [0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081, 0.2562, 0.4229, 0.8700,
         0.1164, 0.5058, 0.2986, 0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911]])
torch.Size([2, 18])

tensor([[[0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642],
         [0.2521, 0.0556, 0.3562, 0.3926, 0.9639, 0.3037],
         [0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231]],

        [[0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081],
         [0.2562, 0.4229, 0.8700, 0.1164, 0.5058, 0.2986],
         [0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911]]])
torch.Size([2, 3, 6])

tensor([[[0.9807, 0.8278, 0.2853],
         [0.2290, 0.3709, 0.6642]],

        [[0.2521, 0.0556, 0.3562],
         [0.3926, 0.9639, 0.3037]],

        [[0.9804, 0.7069, 0.8673],
         [0.0434, 0.5438, 0.7231]],

        [[0.7031, 0.2287, 0.0640],
         [0.5223, 0.0660, 0.5081]],

        [[0.2562, 0.4229, 0.8700],
         [0.1164, 0.5058, 0.2986]],

        [[0.6062, 0.2247, 0.4474],
         [0.2376, 0.5606, 0.5911]]])
torch.Size([6, 2, 3])
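As a sanity check (my own sketch, not part of the original post), the two-argument form a.flatten(start_dim, end_dim) behaves like a reshape in which the merged sizes are multiplied out:

```python
import torch

a = torch.rand(2, 3, 2, 3)

# flatten(0, 1) merges dims 0 and 1: (2, 3, 2, 3) -> (6, 2, 3)
u = a.flatten(0, 1)
print(u.shape)  # torch.Size([6, 2, 3])

# Equivalent reshape: the merged size is the product 2*3 = 6
v = a.reshape(6, 2, 3)
print(torch.equal(u, v))  # True
```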

transpose is an important method of the Tensor class; it is also available as a function in the torch module (torch.transpose).

It returns a tensor that is a transposed version of the input, with the given dimensions dim0 and dim1 swapped.

import torch

# 2-D case
arr = torch.rand(2,3)
print(arr)
print(arr.size())
 
a = arr.transpose(1, 0)
print(a)
print(a.size())

# 3-D case
arr = torch.rand(2,3,4)
print(arr)
print(arr.size())
 
a = arr.transpose(1, 0)
print(a)
print(a.size())

b = arr.transpose(1, 2)
print(b)
print(b.size())

Output:
tensor([[0.3193, 0.1526, 0.0878],
        [0.2070, 0.5021, 0.0383]])
torch.Size([2, 3])

tensor([[0.3193, 0.2070],
        [0.1526, 0.5021],
        [0.0878, 0.0383]])
torch.Size([3, 2])

tensor([[[0.9428, 0.8610, 0.7115, 0.2870],
         [0.0846, 0.5500, 0.8890, 0.6003],
         [0.2907, 0.1275, 0.9961, 0.9360]],

        [[0.3068, 0.2193, 0.6061, 0.3032],
         [0.3735, 0.1232, 0.4352, 0.2763],
         [0.5179, 0.7830, 0.1859, 0.1262]]])
torch.Size([2, 3, 4])

tensor([[[0.9428, 0.8610, 0.7115, 0.2870],
         [0.3068, 0.2193, 0.6061, 0.3032]],

        [[0.0846, 0.5500, 0.8890, 0.6003],
         [0.3735, 0.1232, 0.4352, 0.2763]],

        [[0.2907, 0.1275, 0.9961, 0.9360],
         [0.5179, 0.7830, 0.1859, 0.1262]]])
torch.Size([3, 2, 4])

tensor([[[0.9428, 0.0846, 0.2907],
         [0.8610, 0.5500, 0.1275],
         [0.7115, 0.8890, 0.9961],
         [0.2870, 0.6003, 0.9360]],

        [[0.3068, 0.3735, 0.5179],
         [0.2193, 0.1232, 0.7830],
         [0.6061, 0.4352, 0.1859],
         [0.3032, 0.2763, 0.1262]]])
torch.Size([2, 4, 3])
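One detail worth knowing that the example above does not show (my own addition, but standard documented PyTorch behavior): transpose returns a view that shares storage with the input, and that view is generally no longer contiguous in memory, so .view() on it fails unless you call .contiguous() first (or use .reshape()):

```python
import torch

arr = torch.rand(2, 3)
t = arr.transpose(0, 1)  # same result as torch.transpose(arr, 0, 1)

# transpose returns a view: it shares storage with the input,
# so writing through arr is visible through t
arr[0, 0] = 42.0
print(t[0, 0])  # tensor(42.)

# The transposed view is not contiguous in memory
print(t.is_contiguous())  # False

# .contiguous() copies into a contiguous layout, after which .view() works
print(t.contiguous().view(-1).shape)  # torch.Size([6])
```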
