YOLO11 改进 – 注意力机制 _ CoTAttention (Contextual Transformer Attention) 上下文转换器注意力:动静态上下文融合增强特征表征,优化多尺度目

前言

本文介绍了Contextual Transformer(CoT)块及其在YOLOv11中的结合应用。大多数现有Transformer风格架构设计未充分利用邻近键之间的上下文信息,而CoT块通过3×3卷积对输入键进行上下文编码得到静态表示,将编码后的键与输入查询连接,经两个连续的1×1卷积学习动态多头注意力矩阵,与输入值相乘得到动态表示,最终融合二者作为输出。基于此块构建的CoTNet可替换ResNet架构中的3×3卷积。我们将CoTAttention集成进YOLOv11,大量实验验证了CoTNet作为骨干网络的优越性。

文章目录: YOLOv11改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv11改进专栏

文章目录

[TOC]

介绍

摘要

Transformer自注意力机制已在自然语言处理领域引发革命性变革,并近期推动了Transformer风格架构在多种计算机视觉任务中取得竞争性成果。然而,现有方法大多直接在二维特征图上应用自注意力机制,基于各空间位置的孤立查询-键对计算注意力矩阵,未能充分利用邻近键值间的丰富上下文信息。针对此局限性,本研究设计了一种新颖的Transformer风格模块——上下文Transformer(Contextual Transformer, CoT)块,用于视觉识别任务。该设计充分挖掘输入键值间的上下文信息,引导动态注意力矩阵的学习过程,从而显著增强视觉表征能力。在技术实现上,CoT块首先通过3×3卷积对输入键进行上下文编码,生成静态上下文表示;随后将编码后的键与输入查询连接,经由两个连续的1×1卷积学习动态多头注意力矩阵;所得注意力矩阵与输入值相乘后产生动态上下文表示;最终融合静态与动态上下文表示作为模块输出。所提出的CoT模块具备高度实用性,可便捷替换ResNet架构中的3×3卷积层,构建出名为Contextual Transformer Networks(CoTNet)的Transformer风格骨干网络。通过在大规模图像识别、目标检测及实例分割等应用场景中的系统实验,验证了CoTNet作为更强骨干网络的优越性能。相关源代码已公开于https://github.com/JDAI-CV/CoTNet。

文章链接

论文地址: 论文地址

代码地址: 代码地址

基本原理

CoTNet是一种基于Contextual Transformer(CoT)模块的网络结构,其原理如下:

  1. CoTNet原理:

    • CoTNet采用Contextual Transformer(CoT)模块作为构建块,用于替代传统的卷积操作。

    • CoT模块利用3×3卷积来对输入键之间的上下文信息进行编码,生成静态上下文表示。

    • 将编码后的键与输入查询连接,通过两个连续的1×1卷积来学习动态多头注意力矩阵。

    • 学习到的注意力矩阵用于聚合所有输入数值,生成动态上下文表示。

    • 最终将静态和动态上下文表示融合作为输出。

  2. Contextual Transformer Attention在CoTNet中的作用和原理:

    Contextual Transformer Attention是Contextual Transformer(CoT)模块中的关键组成部分,用于引导动态学习注意力矩阵,从而增强视觉表示并提高计算机视觉任务的性能

    • Contextual Transformer Attention是CoT模块中的注意力机制,用于引导动态学习注意力矩阵。

    • 通过Contextual Transformer Attention,模型能够充分利用输入键之间的上下文信息,从而更好地捕捉动态关系。

    • 这种注意力机制有助于增强视觉表示,并提高计算机视觉任务的性能。

    • CoTNet通过整合Contextual Transformer Attention,实现了同时进行上下文挖掘和自注意力学习的优势,从而提升了深度网络的表征能力。

核心代码

import torch
from torch import flatten, nn
from torch.nn import functional as F

class CoTAttention(nn.Module):
    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim  # 输入通道数
        self.kernel_size = kernel_size  # 卷积核大小

        # 关键信息嵌入层,使用分组卷积提取特征
        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),  # 归一化层
            nn.ReLU()  # 激活函数
        )
        # 值信息嵌入层,使用1x1卷积进行特征转换
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)  # 归一化层
        )

        # 注意力机制嵌入层,先降维后升维,最终输出与卷积核大小和通道数相匹配的特征
        factor = 4  # 降维比例
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),  # 归一化层
            nn.ReLU(),  # 激活函数
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)  # 升维匹配卷积核形状
        )

    def forward(self, x):
        bs, c, h, w = x.shape  # 输入特征的尺寸
        k1 = self.key_embed(x)  # 应用关键信息嵌入
        v = self.value_embed(x).view(bs, c, -1)  # 应用值信息嵌入,并展平

        y = torch.cat([k1, x], dim=1)  # 将关键信息和原始输入在通道维度上拼接
        att = self.attention_embed(y)  # 应用注意力机制嵌入层
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # 计算平均后展平

        k2 = F.softmax(att, dim=-1) * v  # 应用softmax进行标准化,并与值信息相乘
        k2 = k2.view(bs, c, h, w)  # 重塑形状与输入相同

        return k1 + k2  # 将两部分信息相加并返回

实验

脚本

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
#     修改为自己的配置文件地址
    model = YOLO('/root/ultralytics-main/ultralytics/cfg/models/11/yolov11-CoTAttention.yaml')
#     修改为自己的数据集地址
    model.train(data='/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml',
                cache=False,
                imgsz=640,
                epochs=10,
                single_cls=False,  # 是否是单类别检测
                batch=8,
                close_mosaic=10,
                workers=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='CoTAttention',
                )

结果

THE END