深度学习pytorch——索引与切片

indexing

import torch 
a = torch.rand(4,3,28,28)    # 表示4张28*28的rgb图
print(a[0].shape)            # a[0]获得第一张图片
print(a[0,0].shape)          # a[0,0]获得第一张图片的r图
print(a[0,0,2,4])            # 获得第一张图片第一个通道的一个像素点,因此得到的是一个标量

select first/last N 

# select first/last N
print(a[:2].shape)          # :2 => 0,1
print(a[:2,:1,:,:].shape)   # :1 => 0
print(a[:2,1:,:,:].shape)   # 1: => 1,2
print(a[:2,-1:,:,:].shape)  # -1: => 2

select by steps 

# select by steps
print(a[:,:,0:28:2,0:28:2].shape)   # 0:28:2 => 从0-28,步长为2
print(a[:,:,::2,::2].shape)         # ::2 => 从0-28,步长为2
# 总结
# 1. : => all
# 2. :n => 从最开始到n,不包括n
# 3. n: => 从n到最后
# 4. start:end => 从start到end,不包含end
# 5. start:end:steps => 从start到end,不包含end,步长为2

select by specific index 

# select by specific index
print(a.index_select(0,torch.tensor([0,2])).shape)  # index_select() 第一个参数是维度,第二个参数是具体的索引号,但是索引号必须是tensor,所以要使用torch.tensor()
print(a.index_select(2,torch.arange(28)).shape)     # torch.arange(28) => tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        # 18, 19, 20, 21, 22, 23, 24, 25, 26, 27])

 ...

# ...
print(a[...].shape)         # a[...] => a[:,:,:,:]
print(a[0,...].shape)       # a[0,...] => a[0,:,:,:]
print(a[:,1,...].shape)     # a[:,1,...] => a[:,1,:,:]
print(a[...,:2].shape)      # a[...,:2] => a[:,:,:,:2]

select by mask 

# select by mask
# .masked_select() 会将数据默认打平->之所以打平是因为当满足某条件的位数是根据内容才能确定的
x = torch.randn(3,4)
mask = x.ge(0.5)                            # 将大于等于0.5的数取为ture
print(mask)                                 # 掩码
print(torch.masked_select(x,mask))          # 根据掩码取数据和原shape无关
print(torch.masked_select(x,mask).shape)

select by flatten index 

# select by flatten index   打平
src = torch.tensor([[4,3,5],[6,7,8]])
print(torch.take(src,torch.tensor([0,2,5])))        # 打平,取索引为0,2,5的数

相关推荐

  1. 深度学习pytorch——索引切片

    2024-03-17 06:38:04       44 阅读
  2. pytorch深度学习

    2024-03-17 06:38:04       35 阅读
  3. pytorch深度学习

    2024-03-17 06:38:04       28 阅读
  4. Pytorch深度学习

    2024-03-17 06:38:04       34 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-17 06:38:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-17 06:38:04       101 阅读
  3. 在Django里面运行非项目文件

    2024-03-17 06:38:04       82 阅读
  4. Python语言-面向对象

    2024-03-17 06:38:04       91 阅读

热门阅读

  1. PyTorch究竟是什么?

    2024-03-17 06:38:04       42 阅读
  2. PyTorch学习笔记之基础函数篇(十一)

    2024-03-17 06:38:04       44 阅读
  3. 微信小程序上传图片c# asp.net mvc端接收案例

    2024-03-17 06:38:04       43 阅读
  4. Spring核心方法:Refresh全解(WebMVC如何装配、关联)

    2024-03-17 06:38:04       38 阅读
  5. 阿里提前批(阿里云)一面30min

    2024-03-17 06:38:04       46 阅读
  6. MATLAB中的矩阵和数组,它们之间有什么区别?

    2024-03-17 06:38:04       41 阅读
  7. C语言每日一题—判断是否为魔方矩阵

    2024-03-17 06:38:04       45 阅读
  8. C++ 基础组件(1)定时器

    2024-03-17 06:38:04       41 阅读
  9. 【C++】每日一题 228 汇总区间

    2024-03-17 06:38:04       36 阅读
  10. Spring之底层架构核心概念解析

    2024-03-17 06:38:04       35 阅读
  11. c# 修改数据集

    2024-03-17 06:38:04       37 阅读
  12. 合作测试开发日志1

    2024-03-17 06:38:04       37 阅读
  13. go的fasthttp学习

    2024-03-17 06:38:04       39 阅读