flatten函数就是对tensor类型进行扁平化处理,就是在不同维度上进行堆叠操作
a.flatten(m),这个意思是将a这个tensor,从第m(m取值从0开始)维度开始堆叠,一直堆叠到最后一个维度
import torch
a=torch.rand(2,3,2,3)
print(a)
x = a.flatten(0)
print(x)
print(x.size())
y = a.flatten(1)
print(y)
print(y.size())
z = a.flatten(2)
print(z)
print(z.size())
#a.flatten()这个括号里面的参数也不只是只有一个,在官方文档里面的说法,这个里面可以是两个参数
#start_dim (int) – the first dim to flatten
#end_dim (int) – the last dim to flatten
#将a的0维度和1维度合并
u = a.flatten(0,1)
print(u)
print(u.size())
输出:
tensor([[[[0.9807, 0.8278, 0.2853],
[0.2290, 0.3709, 0.6642]],
[[0.2521, 0.0556, 0.3562],
[0.3926, 0.9639, 0.3037]],
[[0.9804, 0.7069, 0.8673],
[0.0434, 0.5438, 0.7231]]],
[[[0.7031, 0.2287, 0.0640],
[0.5223, 0.0660, 0.5081]],
[[0.2562, 0.4229, 0.8700],
[0.1164, 0.5058, 0.2986]],
[[0.6062, 0.2247, 0.4474],
[0.2376, 0.5606, 0.5911]]]])
tensor([0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642, 0.2521, 0.0556, 0.3562,
0.3926, 0.9639, 0.3037, 0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231,
0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081, 0.2562, 0.4229, 0.8700,
0.1164, 0.5058, 0.2986, 0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911])
torch.Size([36])
tensor([[0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642, 0.2521, 0.0556, 0.3562,
0.3926, 0.9639, 0.3037, 0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231],
[0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081, 0.2562, 0.4229, 0.8700,
0.1164, 0.5058, 0.2986, 0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911]])
torch.Size([2, 18])
tensor([[[0.9807, 0.8278, 0.2853, 0.2290, 0.3709, 0.6642],
[0.2521, 0.0556, 0.3562, 0.3926, 0.9639, 0.3037],
[0.9804, 0.7069, 0.8673, 0.0434, 0.5438, 0.7231]],
[[0.7031, 0.2287, 0.0640, 0.5223, 0.0660, 0.5081],
[0.2562, 0.4229, 0.8700, 0.1164, 0.5058, 0.2986],
[0.6062, 0.2247, 0.4474, 0.2376, 0.5606, 0.5911]]])
torch.Size([2, 3, 6])
tensor([[[0.9807, 0.8278, 0.2853],
[0.2290, 0.3709, 0.6642]],
[[0.2521, 0.0556, 0.3562],
[0.3926, 0.9639, 0.3037]],
[[0.9804, 0.7069, 0.8673],
[0.0434, 0.5438, 0.7231]],
[[0.7031, 0.2287, 0.0640],
[0.5223, 0.0660, 0.5081]],
[[0.2562, 0.4229, 0.8700],
[0.1164, 0.5058, 0.2986]],
[[0.6062, 0.2247, 0.4474],
[0.2376, 0.5606, 0.5911]]])
torch.Size([6, 2, 3])
transpose是Tensor类的一个重要方法,同时它也是torch模块中的一个函数
返回一个张量,它是输入张量的转置版本,其中将给定的维度dim0和dim1交换
import random
import torch
#二维数据情况
arr = torch.rand(2,3)
print(arr)
print(arr.size())
a = arr.transpose(1, 0)
print(a)
print(a.size())
#三维数据情况
arr = torch.rand(2,3,4)
print(arr)
print(arr.size())
a = arr.transpose(1, 0)
print(a)
print(a.size())
b = arr.transpose(1, 2)
print(b)
print(b.size())
输出:
tensor([[0.3193, 0.1526, 0.0878],
[0.2070, 0.5021, 0.0383]])
torch.Size([2, 3])
tensor([[0.3193, 0.2070],
[0.1526, 0.5021],
[0.0878, 0.0383]])
torch.Size([3, 2])
tensor([[[0.9428, 0.8610, 0.7115, 0.2870],
[0.0846, 0.5500, 0.8890, 0.6003],
[0.2907, 0.1275, 0.9961, 0.9360]],
[[0.3068, 0.2193, 0.6061, 0.3032],
[0.3735, 0.1232, 0.4352, 0.2763],
[0.5179, 0.7830, 0.1859, 0.1262]]])
torch.Size([2, 3, 4])
tensor([[[0.9428, 0.8610, 0.7115, 0.2870],
[0.3068, 0.2193, 0.6061, 0.3032]],
[[0.0846, 0.5500, 0.8890, 0.6003],
[0.3735, 0.1232, 0.4352, 0.2763]],
[[0.2907, 0.1275, 0.9961, 0.9360],
[0.5179, 0.7830, 0.1859, 0.1262]]])
torch.Size([3, 2, 4])
tensor([[[0.6059, 0.7055, 0.8131],
[0.3136, 0.1284, 0.1374],
[0.8604, 0.0243, 0.3363],
[0.5041, 0.0764, 0.0649]],
[[0.6565, 0.1308, 0.7233],
[0.6803, 0.9431, 0.8020],
[0.2651, 0.7857, 0.4266],
[0.4035, 0.1960, 0.8238]]])
torch.Size([2, 4, 3])