PyTorch张量形状

🍃作者介绍:双非本科大三网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎人工智能和前端开发。
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

1、简介

1.1、形状(Shape)

形状描述张量在每个维度上的大小,用一个元组表示。例如:

  • 一个标量的形状是 ()
  • 一个长度为 5 的向量的形状是 (5,)
  • 一个 3x4 的矩阵的形状是 (3, 4)
  • 一个 2x3x4 的三维张量的形状是 (2, 3, 4)

1.2、形状的意义

形状决定了张量的结构和存储方式:

  • 存储空间:形状可以帮助确定存储张量所需的内存大小。例如,形状为 (2, 3, 4) 的张量需要存储 2 * 3 * 4 = 24 个元素。
  • 操作兼容性:许多张量操作需要输入张量具有兼容的形状。例如,矩阵乘法要求两个矩阵的内维度大小相同。
  • 数据访问:形状影响数据访问的方式。例如,形状为 (3, 4) 的张量可以通过两个索引来访问特定元素,如 tensor[1, 2] 访问第二行第三列的元素。

1.3、形状操作

在实际应用中,经常需要对张量的形状进行各种操作,包括:

  • 重塑(Reshape):改变张量的形状而不改变其数据。例如,将形状为 (6,) 的向量重塑为形状为 (2, 3) 的矩阵。
  • 扩展(Expand):在张量的特定位置添加新的维度。例如,将形状为 (3, 4) 的矩阵扩展为形状为 (1, 3, 4) 的三维张量。
  • 压缩(Squeeze):移除张量中大小为 1 的维度。例如,将形状为 (1, 3, 1, 4) 的张量压缩为形状为 (3, 4) 的张量。
  • 连接(Concat):将多个张量沿指定维度连接起来。例如,将两个形状为 (3, 4) 的张量沿行方向连接成一个形状为 (6, 4) 的张量。

理解张量的形状及其操作对于深度学习和数据处理中的张量运算至关重要。这些概念帮助我们有效地管理和操作多维数据。

1.4、本文摘要

掌握reshape, transpose, permute, view, contigous, squeeze, unsqueeze等函数使用

在后面搭建网络模型时,数据都是基于张量形式的表示,网络层与层之间很多都是以不同的 shape 的方式进行表现和运算,我们需要掌握对张量形状的操作,以便能够更好处理网络各层之间的数据连接。

2、reshape 函数

reshape 函数可以在保证张量数据不变的前提下改变数据的维度,将其转换成指定的形状,在后面的神经网络学习时,会经常使用该函数来调节数据的形状,以适配不同网络层之间的数据传递。

def test01():
    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print(data)
    # 1. 使用 shape 属性或者 size 方法都可以获得张量的形状
    print(data.shape, data.shape[0], data.shape[1])
    print(data.size(), data.size(0), data.size(1))
    # 2. 使用 reshape 函数修改张量形状
    new_data = data.reshape(1, 6)
    print(new_data)
    print(new_data.shape)

程序运行结果:

E:\anaconda3\python.exe D:\Python\AI\PyTorch\13-张量形状操作.py 
tensor([[10, 20, 30],
        [40, 50, 60]])
torch.Size([2, 3]) 2 3
torch.Size([2, 3]) 2 3
tensor([[10, 20, 30, 40, 50, 60]])
torch.Size([1, 6])

Process finished with exit code 0

3、transpose 和 permute 函数

transpose 函数可以实现交换张量形状的指定维度,。
例如: 一个张量的形状为 (2, 3, 4) 可以通过 transpose 函数把 3 和 4 进行交换, 将张量的形状变为 (2, 4, 3)
permute 函数可以一次交换更多的维度
torch.transpose(input, dim0, dim1)

  • input:需要进行维度交换的输入张量。
  • dim0:要交换的第一个维度。
  • dim1:要交换的第二个维度

torch.permute(input, dims)

  • input:需要重新排列维度的输入张量。
  • dims:一个包含新维度顺序的列表或元组。

torch.permute(data, [1, 2, 0]),1维和2维交换,交换之后的2维和0维交换

# transpose and permute
def test02():
    data = torch.tensor(np.random.randint(0, 10, [3, 4, 5]))
    print('data: ', data)
    print('data shape:', data.size())
    # 1. 交换1和2维度
    new_data = torch.transpose(data, 1, 2)
    print('data shape:', new_data.size())
    # 2. 将 data 的形状修改为 (4, 3, 5),然后改为(4, 5 ,3)
    new_data = torch.transpose(data, 0, 1)
    new_data = torch.transpose(new_data, 1, 2)
    print('new_data shape:', new_data.size())
    # 3. 使用 permute 函数将形状修改为 (4, 5, 3)
    new_data = torch.permute(data, [1, 2, 0])
    print('new_data shape:', new_data.size())

程序运行结果:

E:\anaconda3\python.exe D:\Python\AI\PyTorch\13-张量形状操作.py 
data:  tensor([[[3, 4, 6, 4, 5],
         [9, 8, 5, 6, 1],
         [9, 6, 7, 0, 6],
         [3, 6, 1, 2, 8]],

        [[7, 3, 0, 5, 0],
         [6, 6, 6, 3, 3],
         [3, 8, 9, 8, 5],
         [5, 1, 5, 2, 5]],

        [[2, 4, 0, 2, 8],
         [8, 3, 0, 0, 7],
         [6, 5, 3, 0, 8],
         [6, 3, 5, 8, 3]]], dtype=torch.int32)
data shape: torch.Size([3, 4, 5])
data shape: torch.Size([3, 5, 4])
new_data shape: torch.Size([4, 5, 3])
new_data shape: torch.Size([4, 5, 3])

Process finished with exit code 0

4、view 和 contigous 函数

  1. view 函数也可以用于修改张量的形状,但是其用法比较局限,只能用于存储在整块内存中的张量。
  2. contiguous 方法用于返回一个内存连续的张量副本。如果张量已经是连续的,则返回自身;如果不是连续的,则返回一个新的、连续的张量副本。

在 PyTorch 中,有些张量是由不同的数据块组成的,它们并没有存储在整块的内存中,view 函数无法对这样的张量进行变形处理,例如: 一个张量经过了** transpose 或者 permute **函数的处理之后,就无法使用 view 函数进行形状操作,那么就需要使用contiguous进行处理。

4.1、代码

# view and contiguous
def test03():
    data = torch.tensor([[10, 20, 30], [40, 50, 60]])
    print('data shape:', data.size())
    # 1. 使用 view 函数修改形状
    new_data = data.view(3, 2)
    print('new_data shape:', new_data.shape)
    # 2. 判断张量是否使用整块内存
    print('data:', data.is_contiguous())  # True
    # 3. 使用 transpose 函数修改形状
    new_data = torch.transpose(data, 0, 1)
    print('new_data shape:', new_data.shape)
    print('new_data:', new_data.is_contiguous())  # False
    # new_data = new_data.view(2, 3)  # RuntimeError

    # 需要先使用 contiguous 函数转换为整块内存的张量,再使用 view 函数
    print(new_data.contiguous().is_contiguous())
    new_data = new_data.contiguous().view(2, 3)
    print('new_data shape:', new_data.shape)

程序运行结果:

E:\anaconda3\python.exe D:\Python\AI\PyTorch\13-张量形状操作.py 
data shape: torch.Size([2, 3])
new_data shape: torch.Size([3, 2])
data: True
new_data shape: torch.Size([3, 2])
new_data: False
True
new_data shape: torch.Size([2, 3])

Process finished with exit code 0

4.2、内存连续与不连续⭐

张量内存不连续:Non-contiguous Tensor

在PyTorch中,张量在内存中的存储方式可能是连续的,也可能是非连续的。
内存不连续的张量意味着它们在内存中的数据存储顺序不按照张量元素的索引顺序排列。
这通常发生在对张量进行一些变换操作之后,比如转置、切片等。

什么是内存连续

一个张量的内存是连续的,意味着它的元素在内存中是按**行优先(row-major)列优先(column-major)**顺序存储的。
对于一个二维张量,按行优先顺序意味着第一行的所有元素存储在一起,接着是第二行,依此类推。

什么是内存不连续

不连续是指张量的数据在内存中的排列方式不再按照顺序连续存储,而是分散的。
当进行某些操作后,比如**转置(transpose)**操作,张量的数据在内存中的存储顺序可能会改变,导致内存不连续。

例如,对于一个二维张量,原本连续存储的数据可能在转置后变得不再连续存储。
转置操作会改变张量的维度顺序,但并不改变实际数据的存储方式,所以导致数据访问顺序不再连续。
这会影响某些操作的性能,并且一些需要连续内存的操作(如view)将无法直接应用。

5、squeeze 和 unsqueeze 函数

squeeze 函数用删除 shape 为 1 的维度,unsqueeze 在每个维度添加 1, 以增加数据的形状。

具体如下:

  1. torch.squeeze:用于移除尺寸为1的维度。

如果不指定 dim 参数,则移除所有尺寸为1的维度。
如果指定 dim 参数,则仅移除指定维度的尺寸为1的维度。

  1. torch.unsqueeze:用于在指定位置添加尺寸为1的维度。

dim 参数指定了在哪个位置添加尺寸为1的维度。

# squeeze and unsqueeze
def test04():
    data = torch.tensor(np.random.randint(0, 10, [1, 3, 1, 5]))
    print('data:', data)
    print('data shape:', data.size())
    # 1. 去掉值为1的维度
    new_data = data.squeeze()
    print('new_data:', new_data)
    print('new_data shape:', new_data.size())  # torch.Size([3, 5])
    # 2. 去掉指定位置为1的维度,注意: 如果指定位置不是1则不删除
    new_data = data.squeeze(2)
    print('new_data shape:', new_data.size())  # torch.Size([1, 3, 5])
    # 3. 在2维度增加一个维度
    print(data.shape)
    new_data = data.unsqueeze(-1)
    print('new_data shape:', new_data.size())  # torch.Size([1, 3, 1, 5, 1])

程序运行结果:

E:\anaconda3\python.exe D:\Python\AI\PyTorch\13-张量形状操作.py 
data: tensor([[[[0, 0, 6, 1, 5]],

         [[0, 3, 2, 4, 6]],

         [[7, 4, 9, 8, 7]]]], dtype=torch.int32)
data shape: torch.Size([1, 3, 1, 5])
new_data: tensor([[0, 0, 6, 1, 5],
        [0, 3, 2, 4, 6],
        [7, 4, 9, 8, 7]], dtype=torch.int32)
new_data shape: torch.Size([3, 5])
new_data shape: torch.Size([1, 3, 5])
torch.Size([1, 3, 1, 5])
new_data shape: torch.Size([1, 3, 1, 5, 1])

Process finished with exit code 0

相关推荐

  1. PyTorch形状

    2024-07-22 10:20:02       18 阅读
  2. PyTorch

    2024-07-22 10:20:02       39 阅读
  3. Pytorch广播

    2024-07-22 10:20:02       28 阅读
  4. Pytorch笔记】

    2024-07-22 10:20:02       19 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-22 10:20:02       52 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-22 10:20:02       54 阅读
  3. 在Django里面运行非项目文件

    2024-07-22 10:20:02       45 阅读
  4. Python语言-面向对象

    2024-07-22 10:20:02       55 阅读

热门阅读

  1. 深度学习落地实战:人脸面部表情识别

    2024-07-22 10:20:02       16 阅读
  2. Python中Selenium 和 keyboard 库的使用

    2024-07-22 10:20:02       12 阅读
  3. 【mybatis 一级缓存】

    2024-07-22 10:20:02       17 阅读
  4. QT表格显示MYSQL数据库源码分析(七)

    2024-07-22 10:20:02       16 阅读
  5. Github 2024-07-22开源项目日报Top10

    2024-07-22 10:20:02       13 阅读
  6. 十六、多任务

    2024-07-22 10:20:02       14 阅读
  7. 目标检测的隐形威胁:对抗攻击的深度解析

    2024-07-22 10:20:02       18 阅读
  8. ASP.NET Core Web深度探讨

    2024-07-22 10:20:02       15 阅读
  9. opencv—常用函数学习_“干货“_13

    2024-07-22 10:20:02       18 阅读
  10. 高精度-大整数计算模板

    2024-07-22 10:20:02       18 阅读
  11. Anonymous Informant

    2024-07-22 10:20:02       15 阅读
  12. Oracle(16)什么是视图(View)?

    2024-07-22 10:20:02       20 阅读