试想一个简单的问题:一个维度是[1, 6]的tensor,我想其中的6个元素分成3组,每组2个元素,然后对每组中的元素求平均值,得到一个维度是3的输出。应该怎么用Python实现?
最直观的想法就是:将[1, 6]先reshape成[2, 3]或[3, 2],然后在2对应的维度上进行运算,最终得到维度是3的结果。但是,真的[2, 3]或[3, 2]都能行吗?
下面让我们来看看两者的区别,感受一下区别。
a = torch.tensor([[1., 2., 3., 4., 5., 6.]])
b1 = a.reshape(2, 3)
print(b1)
mean_b1 = torch.mean(b1, dim=0, keepdim=False)
print(mean_b1)
b2 = a.reshape(3, 2)
print(b2)
mean_b2 = torch.mean(b2, dim=1, keepdim=False)
print(mean_b2)
输出:
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([2.5000, 3.5000, 4.5000])
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
tensor([1.5000, 3.5000, 5.5000])
可以发现,[2, 3]和[3, 2]两种方式得到的结果是不同的!因为reshape后的元素排列方式不同。
总结一句,如果想将tensor分成n组,然后对每组进行运算(比如求和、求平均),那么,要记得把组数放在前面的维度上(也就是上面例子中的3),把每组的元素数放在后面的维度上(也就是上面例子中的2)。
但如果只是reshape,但是没有分组的运算,那么[2, 3]和[3, 2]都可以,反正都可以等价地reshape回原来的排列。