【PyTorch】register_hook的使用

首先看一下正常的梯度计算例子:

#正常求导情况
v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)
z = v.sum()
z.backward()
print(v.grad)

输出:

tensor([[1., 1., 1.]])

 上面的代码中,当执行到z.backward()这一句代码的时候,就是计算变量z的偏导数,因为v=torch.randn(1,3)也就是1行3列,所以可以假设v=(v1, v2, v3),那么z=v1+v2+v3。所以z对v的偏微分就是:

\frac{\partial z}{\partial v}=\frac{\partial (v1+v2+v3)}{\partial v}=\frac{\partial (v1+v2+v3)}{\partial (v1,v2,v3)}=(\frac{\partial (v1+v2+v3)}{\partial v1},\frac{\partial (v1+v2+v3)}{\partial v2},\frac{\partial (v1+v2+v3)}{\partial v3})

 其中:

\frac{\partial (v1+v2+v3)}{\partial v1}=\frac{\partial v1}{\partial v2}+\frac{\partial v2}{\partial v2}+\frac{\partial v3}{\partial v2}=0+1+0=1

所以可以得出上面的偏微分结果为 :tensor([[1., 1., 1.]])。

如果我们需要对导数进行2倍的操作:

v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)
z = v.sum()
# lambda grad: grad*2是一个函数,即:
# def lambda(grad):
#    return grad*2
v.register_hook(lambda grad: grad*2)  
z.backward()
print(v.grad)

输出为:

tensor([[2., 2., 2.]])

可以看出v.register_hook()的作用是将反向传播过程中关于v的梯度给取出来,同时进行一些操作,上面代码所进行的操作是对关于v的梯度乘以2,当然,这里的梯度只是暂时取出来了,如果需要“长久的”保存梯度信息方便后续的计算的话,则可以如下代码所示:

grad_store = []
def function(grad):
    grad_store.append(grad)

v = torch.randn((1, 3), dtype=torch.float32, requires_grad=True)
z = v.sum()
v.register_hook(function)  
z.backward()

上面的代码即可将梯度保存到变量grad_store中,方便后面计算Grad-CAM等等。

相关推荐

  1. ThreadLocal使用以及使用场景

    2024-03-13 05:26:04       25 阅读
  2. git使用

    2024-03-13 05:26:04       73 阅读
  3. websoket 使用

    2024-03-13 05:26:04       56 阅读
  4. Logstash使用方法

    2024-03-13 05:26:04       67 阅读
  5. Auth使用、缓存

    2024-03-13 05:26:04       56 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-13 05:26:04       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-13 05:26:04       106 阅读
  3. 在Django里面运行非项目文件

    2024-03-13 05:26:04       87 阅读
  4. Python语言-面向对象

    2024-03-13 05:26:04       96 阅读

热门阅读

  1. docker直接下载太慢,更换国内靠谱镜像源

    2024-03-13 05:26:04       38 阅读
  2. vue双向绑定/小程序双向绑定?

    2024-03-13 05:26:04       44 阅读
  3. 从SQL质量管理体系来看SQL审核(1)

    2024-03-13 05:26:04       38 阅读
  4. 面试经典-4-LRU 缓存

    2024-03-13 05:26:04       40 阅读
  5. 使用SpringBoot实现定时任务

    2024-03-13 05:26:04       44 阅读
  6. 子查询的特殊用途

    2024-03-13 05:26:04       41 阅读
  7. 双指针算法———C++

    2024-03-13 05:26:04       45 阅读
  8. docker搭建upload-labs

    2024-03-13 05:26:04       44 阅读
  9. 大数据开发(HBase面试真题-卷二)

    2024-03-13 05:26:04       43 阅读