域适应/泛化中的GRL与NP方法

特征可视化:https://github.com/jindongwang/transferlearning/blob/aa6ee9fcbc7e2fa75ce6fb5f2f60d8193e1c6894/code/utils/feature_vis.py


一、GRL

绿色是特征提取,蓝色是分类损失,红色是域损失,因为目标是无法区分特征来自哪个域,所以域损失要最大化,叫梯度反转。
在这里插入图片描述

import torch

class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, eta=1.0):
        ctx.eta = eta
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return (grad_output * -ctx.eta), None


def grad_reverse(x, eta=1.0):
    return GradReverse.apply(x, eta)

这段代码定义了一个自定义的梯度反转层,主要用于深度学习中的对抗性训练或者域适应等任务,其中我们可能希望在反向传播时改变某个层的梯度方向或大小。
下面是对代码的详细解释:

  1. 导入必要的库:
import torch

这里导入了PyTorch库,一个流行的深度学习框架。

2.定义GradReverse类:
这个类继承自torch.autograd.Function,允许我们自定义一个具有特定前向和反向传播行为的层。

class GradReverse(torch.autograd.Function):
* forward方法:
@staticmethod
def forward(ctx, x, eta=1.0):
    ctx.eta = eta
    return x
#在forward方法中,我们仅仅保存了传入的eta值(用于控制梯度反转的程度)到上下文ctx中,并直接返回了输入x。这意味着在前向传播时,这个层不会对数据产生任何影响。
#* backward方法:
@staticmethod
def backward(ctx, grad_output):
    return (grad_output * -ctx.eta), None

在backward方法中,我们定义了该层的反向传播行为。当上游的梯度grad_output传到这里时,我们会用-ctx.eta来乘以这个梯度,从而实现梯度的反转(如果eta为正)和可能的缩放。None表示没有关于eta的梯度,因为我们通常不希望更新这个参数。

3.定义grad_reverse函数:

def grad_reverse(x, eta=1.0):
    return GradReverse.apply(x, eta)

这个函数为上述定义的GradReverse类提供了一个更简洁的接口。用户可以直接调用这个函数,而无需明确地使用GradReverse.apply方法。x是输入数据,eta是控制梯度反转程度的参数(其默认值为1.0,意味着梯度会完全反转)。

总结:这段代码实现了一个自定义的梯度反转层。在前向传播时,该层不会对数据产生任何影响;但在反向传播时,它会将传入的梯度乘以-eta,从而实现梯度的反转和可能的缩放。


二、NP

论文:《Towards Robust Object Detection Invariant to Real-world Domain Shifts》

自动驾驶等安全关键应用需要对真实世界的领域变化保持不变的鲁棒对象检测。这种转变可以被视为不同的领域风格,由于环境的变化,这些风格可能会有很大的变化,但深度模型只知道训练域风格。这种领域风格的差距阻碍了对象检测在不同现实世界领域的泛化。

现有的分类域泛化(DG)方法不能有效地解决鲁棒目标检测问题,因为它们要么依赖于具有大风格方差的多个源域,要么破坏了原始图像的内容结构。在本文中,我们分析并研究了在没有上述缺点的情况下克服域式过拟合的鲁棒目标检测的有效解决方案。我们的方法被称为归一化扰动(NP),它扰动源域低级特征的通道统计,以合成各种潜在风格,使训练后的深度模型能够感知不同的潜在域,并在训练中即使没有观察到目标域数据的情况下也能很好地泛化。该方法的动机是观察到目标域图像的特征通道统计数据偏离源域统计数据。

归一化扰动仅依赖于单个源域,并且出奇地简单有效,通过有效地将分类DG方法应用于鲁棒对象检测,提供了实用的解决方案。在这里插入图片描述
在这里插入图片描述

相关推荐

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-18 06:32:03       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-18 06:32:03       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-18 06:32:03       57 阅读
  4. Python语言-面向对象

    2024-07-18 06:32:03       68 阅读

热门阅读

  1. SpringSecurity + JWT 实现登录认证

    2024-07-18 06:32:03       15 阅读
  2. vue路由的钩子函数

    2024-07-18 06:32:03       24 阅读
  3. Socket、WebSocket 和 MQTT 的区别

    2024-07-18 06:32:03       22 阅读
  4. 深入探讨SQL Server端口设置:理论与实践

    2024-07-18 06:32:03       24 阅读
  5. kafka判断生产者是否向kafka集群成功发送消息

    2024-07-18 06:32:03       24 阅读
  6. mysql 安装配置 next 按钮为什么置灰点击不了

    2024-07-18 06:32:03       22 阅读