torch.where()中并行方式的实现

torch.where()中一般有三个参数。

第一个参数是一个判断条件。

第二个参数是条件成立时的值。

第三个参数是条件不成立时的值。

        for batch in range(2):
            for i in range(256):
                for j in range(256):
                    output[batch][i][j] = 0 if tensor_count_0[A_arg[batch,i,j]][B_arg[batch,i,j]].item() >= tensor_count_1[A_arg[batch,i,j]][B_arg[batch,i,j]].item() else 1

output,A_arg,B_arg尺寸为[2,256,256]   tensor_count_0和tensor_count_1的尺寸为[15,15],它们都是tensor数据,且都在GPU上。所以可以改为并行方式:

        output = torch.where(tensor_count_0[A_arg, B_arg] >= tensor_count_1[A_arg, B_arg], torch.zeros_like(output),torch.ones_like(output))

相关推荐

  1. torch.where()并行方式实现

    2024-04-22 11:44:03       31 阅读
  2. Go怎么实现map并发安全三种方式

    2024-04-22 11:44:03       10 阅读
  3. Spring Boot 实现跨域几种方式

    2024-04-22 11:44:03       35 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-22 11:44:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-22 11:44:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-22 11:44:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-22 11:44:03       18 阅读

热门阅读

  1. http和https区别与上网过程

    2024-04-22 11:44:03       14 阅读
  2. SQLite去除.db-shm和.db-wal文件【已解决】

    2024-04-22 11:44:03       12 阅读
  3. Spring Boot 中整合 Redisson 实现分布式锁

    2024-04-22 11:44:03       10 阅读
  4. 三年经验!你还不知道KVM虚拟化技术???

    2024-04-22 11:44:03       11 阅读
  5. python内存泄漏解决

    2024-04-22 11:44:03       12 阅读
  6. 工程师每日刷题-7

    2024-04-22 11:44:03       12 阅读
  7. Vue模版语法(初学Vue之v-指令语法)

    2024-04-22 11:44:03       14 阅读
  8. 什么是 ORM(对象关系映射)

    2024-04-22 11:44:03       13 阅读
  9. web开发

    web开发

    2024-04-22 11:44:03      12 阅读
  10. 【数学建模】建筑工地开工问题

    2024-04-22 11:44:03       13 阅读
  11. 速盾:cdn都能防御哪些攻击?

    2024-04-22 11:44:03       12 阅读