深度学习pytorch——高阶OP(where & gather)(持续更新)

where

1、我们为什么需要where?

我们经常需要一个数据来自好几个的取值,而这些取值通常是不规律的,这就会导致使用传统的拆分和合并会非常的麻烦。我们也可以使用for循环嵌套来取值,也是可以的,但是使用for循环就意味着是python,那并没有很好的利用pytorch提供的使用gpu加速计算,当数据量非常大的话,会很大的拉低效率,因此我们使用pytorch提供的where。

2、where的使用

语法:torch.where(condition, x, y)  ------>  tensor

返回值:最后的返回值是一个张量,最后每个元素来自数据x,还是数据y依赖于条件。

使用where的条件:x.shape = y.shape = c.shape = condition.shape(c为结果,condition为0 1矩阵)

代码示例:

cond = torch.tensor([[0.6,0.7],[0.8,0.4]])
a = torch.zeros(2,2)
b = torch.ones(2,2)
print(torch.where(cond>0.5,a,b))
# tensor([[0., 0.],
#         [0., 1.]])

gather

1、我们为什么需要gather?

gather:根据index收集数据。

不使用gather的情况:

可以从上图中看出,索引是非常繁琐的,而且不小心就看错了,虽说也不是很难,但是深度学习处理的数据都是非常庞大的,比如一个1024*1024的图片,这时候内心是崩溃的🌹。还有一点,我们可以使用gpu帮助我们加快数据处理的效率。 

2、gather的使用

语法:torch.gather(input, dim, index, out=None) -----> tensor

input:表

dim:在哪个维度查表

index:索引表

代码示例:

prob=torch.randn(4,10)
idx=prob.topk(dim=1,k=3)
idx=idx[1]
# 以上为了得到索引表
label=torch.arange(10)+100
print(torch.gather(label.expand(4,10),dim=1,index=idx))

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-21 21:54:03       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-21 21:54:03       101 阅读
  3. 在Django里面运行非项目文件

    2024-03-21 21:54:03       82 阅读
  4. Python语言-面向对象

    2024-03-21 21:54:03       91 阅读

热门阅读

  1. c++简介

    2024-03-21 21:54:03       44 阅读
  2. web高可用集群(lvs负载均衡+keepalved高可用)

    2024-03-21 21:54:03       37 阅读
  3. 算法刷题day32

    2024-03-21 21:54:03       32 阅读
  4. Linux 安装RabbitMQ及RabbitMQ Web界面管理

    2024-03-21 21:54:03       39 阅读
  5. 注解的原理

    2024-03-21 21:54:03       35 阅读
  6. 浅谈Spring框架

    2024-03-21 21:54:03       44 阅读
  7. C 语言中常量和变量的区别

    2024-03-21 21:54:03       46 阅读
  8. 【生命周期】简述及部分软件知识补充

    2024-03-21 21:54:03       40 阅读
  9. IM服务集群与跨服务器消息路由策略

    2024-03-21 21:54:03       33 阅读
  10. sqllab通关笔记(汇总)

    2024-03-21 21:54:03       39 阅读
  11. Docker 极简入门指南

    2024-03-21 21:54:03       120 阅读
  12. Spring(概念)

    2024-03-21 21:54:03       44 阅读