对vit(Vision Transformer)的注意力可视化。使用grad_cam方法

一、环境准备

注意安装包是pip install grad_cam而不是pytorch_grad_cam。一个是包名一个是导入名。之前发现怎么都装不上。

pip install "grad-cam==1.4.0"

导入时调用pytorch_grad_cam

```python
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus,AblationCAM, \
                            XGradCAM, EigenCAM, EigenGradCAM,LayerCAM,FullGrad
from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
import cv2
import numpy as np
import torch

二、加载预训练的vit模型(离线加载或在线加载)

因为网络问题,使用离线定义网络与离线加载模型方法,也可以在线torch.hub.load加载

#离线模型,模型定义具体省略
my_model =  models.init_models(myargs)#省略
model_pkl = "******/dino_finetune.pkl"#加载自己训练好的模型
my_model.load_state_dict(torch.load(model_pkl))
my_model.eval()
##在线模型加载
#my_model = torch.hub.load('facebookresearch/deit:main','deit_tiny_patch16_224', #pretrained=True)
#my_model.eval()
# 判断是否使用 GPU 加速
use_cuda = torch.cuda.is_available()
if use_cuda:
    my_model = my_model.cuda() #如果是gpu的话加速

三、选择目标层来计算grad_cam,由于 ViT 的最后一层只有类别标记对预测类别有影响,所以我们不能选择最后一层。我们可以选择倒数第二层中的任意一个 Transformer 编码器作为目标层。

#首先定义函数对vit输出的3维张量转换为传统卷积处理时的二维张量,gradcam需要。
#(B,H*W,feat_dim)转换到(B,C,H,W),其中H*W是分pathc数。具体参数根据自己模型情况
#我的输入为224*224,pathsize为(16*16),那么我的(H,W)就是(224/16,224/16),即14*14
def reshape_transform(tensor, height=14, width=14):
    # 去掉cls token
    result = tensor[:, 1:, :].reshape(tensor.size(0),
    height, width, tensor.size(2))
    # 将通道维度放到第一个位置
    result = result.transpose(2, 3).transpose(1, 2)
    return result
    
# 创建 GradCAM 对象
cam = GradCAM(model=model,
            target_layers=[model.blocks[-1].norm1],
            # 这里的target_layer要看模型情况,调试时自己打印下model吧
            # 比如还有可能是:target_layers = [model.blocks[-1].ffn.norm]
            # 或者target_layers = [model.blocks[-1].ffn.norm]
            use_cuda=use_cuda,
            reshape_transform=reshape_transform)

四、读入图片,预处理后送入网络。调用 cam 对象的 forward 方法,传入输入张量和预测类别(如果不指定,则默认为最高概率的类别),得到 Grad-CAM 的输出

# 读取输入图像
image_path = "xxx.jpg"
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, (224, 224))

# 预处理图像
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

# 看情况将图像转换为批量形式
# input_tensor = input_tensor.unsqueeze(0)
if use_cuda:
    input_tensor = input_tensor.cuda()

# 计算 grad-cam
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=input_tensor, targets=target_category)
grayscale_cam = grayscale_cam[0, :]

# 将 grad-cam 的输出叠加到原始图像上
#visualization = show_cam_on_image(rgb_img, grayscale_cam),借鉴的代码rgb格式不对,换成下面
visualization = show_cam_on_image(rgb_img.astype(dtype=np.float32)/255,grayscale_cam)

# 保存可视化结果
cv2.cvtColor(visualization, cv2.COLOR_RGB2BGR, visualization)
# cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB, visualization)#我个人代码用的bgr

cv2.imwrite('cam.jpg', visualization)

注意

报错一:如果grad-cam版本过高,会报错Grad-cam报错AttributeError: ‘GradCAM‘ object has no attribute ‘activations_and_grads‘:所以装1.4版本。

报错二、gradcam报错2 AttributeError: ‘list’ object has no attribute ‘cpu’,是因为grad_cam在通过分类层结果确认梯度贡献,而我的代码是做识别方向的,return的feat特征,而不是cls分类。改变网络最后return为分类层输出即可。同时如果用了circleproduct还要注意测试时是否传了label进入网络

报错三、RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn。是因为我为了省事,在工程文件的测试代码中加grad_cam可视化,然后with torch.no_grad()忘记注释了,导致grad_cam计算时没有梯度而报错。注释即可

注意是否红色和蓝色区域互换了,红色应该是注意力地方,如果反过来了,那就是图片rgb和bgr格式问题了

 cv2.cvtColor(visualization, cv2.COLOR_BGR2RGB, visualization)

#代码借鉴于:https://zhuanlan.zhihu.com/p/640450435

相关推荐

  1. 学习:WebGL基础使用

    2024-04-24 16:08:03       32 阅读
  2. 学习:WebGL基础使用

    2024-04-24 16:08:03       27 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-04-24 16:08:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-04-24 16:08:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-04-24 16:08:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-04-24 16:08:03       18 阅读

热门阅读

  1. 数字人成了大佬标配?再不上车就晚了

    2024-04-24 16:08:03       12 阅读
  2. 【动态规划】Leetcode 70. 爬楼梯【简单】

    2024-04-24 16:08:03       10 阅读
  3. Qt5中的常用模块

    2024-04-24 16:08:03       12 阅读
  4. 变电站综合监控系统系统组成分析

    2024-04-24 16:08:03       16 阅读
  5. 富格林:掌握鉴别阻挠虚假套路

    2024-04-24 16:08:03       12 阅读
  6. 5分钟快速搭建k8s集群1.29.x

    2024-04-24 16:08:03       13 阅读
  7. MySQL中的关键字深入比较:UNION vs UNION ALL

    2024-04-24 16:08:03       12 阅读
  8. 分组排序取第一条数据 SQL写法

    2024-04-24 16:08:03       11 阅读
  9. Redis 大KEY/慢查询问题的排查和解决

    2024-04-24 16:08:03       15 阅读
  10. flutter组件 InheritedWidget

    2024-04-24 16:08:03       15 阅读
  11. leetcode922-Sort Array By Parity II

    2024-04-24 16:08:03       11 阅读