(超详细)7-YOLOV5改进-添加 CoTAttention注意力机制

1、在yolov5/models下面新建一个CoTAttention.py文件,在里面放入下面的代码
在这里插入图片描述

代码如下:

import numpy as np
import torch
from torch import flatten, nn
from torch.nn import init
from torch.nn.modules.activation import ReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

2、找到yolo.py文件,进行更改内容
在29行加一个from models.CoTAttention import CoTAttention, 保存即可
在这里插入图片描述

3、找到自己想要更改的yaml文件,我选择的yolov5s.yaml文件(你可以根据自己需求进行选择),将刚刚写好的模块CoTAttention加入到yolov5s.yaml里面,并更改一些内容。更改如下
在这里插入图片描述

4、在yolo.py里面加入两行代码(335-337)
保存即可!
在这里插入图片描述
运行
在这里插入图片描述

最近更新

  1. TCP协议是安全的吗?

    2024-01-21 15:42:02       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-01-21 15:42:02       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-01-21 15:42:02       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-01-21 15:42:02       18 阅读

热门阅读

  1. Mybatis 44_调用传出参数是游标引用的存储过程

    2024-01-21 15:42:02       33 阅读
  2. Gin之gin介绍和安装

    2024-01-21 15:42:02       49 阅读
  3. Ubuntu-MarkText安装使用

    2024-01-21 15:42:02       39 阅读
  4. [go] 迭代器模式

    2024-01-21 15:42:02       32 阅读
  5. MVC的设计理念

    2024-01-21 15:42:02       35 阅读
  6. 野指针(C语言)

    2024-01-21 15:42:02       30 阅读
  7. rust嵌入式之用类函数宏简写状态机定义

    2024-01-21 15:42:02       31 阅读
  8. 小程序定制开发流程

    2024-01-21 15:42:02       35 阅读
  9. HTTP 第二章 发展历史

    2024-01-21 15:42:02       31 阅读
  10. Could not load library libcudnn_cnn_infer.so.8

    2024-01-21 15:42:02       35 阅读