其实这里的torch的nonzero() 和numpy的nonzero() 是一样的使用方法。
就是返回所有非0元素的index。一般用于CV中mask操作
首先定义一个5*4的tensor,其中为0的元素个数为3个,不为0的个数为17个
import torch
X=torch.tensor([[1, 2, 0, 4],
[5, 6, 7, 8],
[0, 10,11,12],
[13,14, 0,16],
[17,18,19,20]])
返回非0的所有元素的索引矩阵,下面两种写法都可以
写法1:
# 返回一个 17*2的矩阵,行代表非0的行索引,列代表非0的列索引
print(X.nonzero())
写法2:
print(torch.nonzero(X))
tensor([[0, 0],
[0, 1],
[0, 3],
[1, 0],
[1, 1],
[1, 2],
[1, 3],
[2, 1],
[2, 2],
[2, 3],
[3, 0],
[3, 1],
[3, 3],
[4, 0],
[4, 1],
[4, 2],
[4, 3]])
返回类型定义为tuple
# 返回维度维度与输入一样的N个元组,第一个元组代表行索引,第二个元组代表列索引,依次向后
print(X.nonzero(as_tuple=True))
(tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4]),
tensor([0, 1, 3, 0, 1, 2, 3, 1, 2, 3, 0, 1, 3, 0, 1, 2, 3]))
Numpy.nonzero() 使用方法简明解释_for j in np.nonzero(adj[i])[0]:-CSDN博客