Pytorch--Hooks For Module


1.register_module_forward_pre_hook

在 PyTorch 中,register_module_forward_pre_hook 是一个方法,用于向模型的模块注册前向传播预钩子(forward pre-hook)。预钩子是在模块的前向传播之前被调用的函数,允许在模块接收输入之前对输入进行修改或记录

import torch
import torch.nn as nn

# 定义一个前向传播预钩子函数
def forward_pre_hook(module, input):
    print("Forward pre-hook called for module:", module)
    print("Input shape:", input[0].shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册前向传播预钩子
model.register_module_forward_pre_hook(forward_pre_hook)

# 输入数据
input_data = torch.randn(1, 10)

# 前向传播
output = model(input_data)
Forward pre-hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])

2.register_module_forward_hook

在 PyTorch 中,register_module_forward_hook 是一个方法,用于向模型的模块注册前向传播钩子(forward hook)。钩子是在模块的前向传播过程中被调用的函数,可以用于获取中间特征、对特征进行修改或记录等操作。

import torch
import torch.nn as nn

# 定义一个前向传播钩子函数
def forward_hook(module, input, output):
    print("Forward hook called for module:", module)
    print("Input shape:", input[0].shape)
    print("Output shape:", output.shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册前向传播钩子
model.register_forward_hook(forward_hook)

# 输入数据
input_data = torch.randn(1, 10)

# 前向传播
output = model(input_data)
Forward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 10])

3.register_module_backward_hook

在 PyTorch 中,register_module_backward_hook 是一个方法,用于向模型的模块注册反向传播钩子(backward hook)。钩子是在模块的反向传播过程中被调用的函数,可以用于获取梯度、对梯度进行修改或记录等操作。

import torch
import torch.nn as nn

# 定义一个反向传播钩子函数
def backward_hook(module, grad_input, grad_output):
    print("Backward hook called for module:", module)
    print("Grad input shape:", grad_input[0].shape)
    print("Grad output shape:", grad_output[0].shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册反向传播钩子
model.register_backward_hook(backward_hook)

# 输入数据
input_data = torch.randn(1, 10)
target = torch.randn(1, 10)

# 前向传播和反向传播
output = model(input_data)
loss = nn.MSELoss()(output, target)
loss.backward()
Backward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Grad input shape: torch.Size([1, 10])
Grad output shape: torch.Size([1, 10])

相关推荐

  1. <span style='color:red;'>Pytorch</span>

    Pytorch

    2024-06-14 21:46:02      51 阅读
  2. PyTorch

    2024-06-14 21:46:02       54 阅读
  3. PytorchPytorch入门基础

    2024-06-14 21:46:02       38 阅读
  4. 入门 PyTorch

    2024-06-14 21:46:02       64 阅读
  5. PyTorch】概述

    2024-06-14 21:46:02       55 阅读
  6. pytorch RNN

    2024-06-14 21:46:02       43 阅读
  7. Python:PyTorch

    2024-06-14 21:46:02       52 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-14 21:46:02       98 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-14 21:46:02       106 阅读
  3. 在Django里面运行非项目文件

    2024-06-14 21:46:02       87 阅读
  4. Python语言-面向对象

    2024-06-14 21:46:02       96 阅读

热门阅读

  1. SQL AND & OR 运算符的使用与区别

    2024-06-14 21:46:02       32 阅读
  2. 数据结构-单链表

    2024-06-14 21:46:02       38 阅读
  3. 人工智能在问题答疑领域的应用

    2024-06-14 21:46:02       31 阅读
  4. 从输入URL到页面加载完中间发生了什么?

    2024-06-14 21:46:02       26 阅读
  5. 优化SQL查询的策略和技巧 - AI提供

    2024-06-14 21:46:02       31 阅读
  6. 从 GPT2 到 ChatGPT

    2024-06-14 21:46:02       30 阅读
  7. sqlcoder:7b sqlcoder:15b sqlcoder:70b 有什么区别呢?

    2024-06-14 21:46:02       37 阅读
  8. Android RecyclerView使用

    2024-06-14 21:46:02       26 阅读
  9. C#面:抽象类和接口有什么异同

    2024-06-14 21:46:02       27 阅读