YOLO11 改进 – 注意力机制 _ DAT (Deformable Attention) 可变形注意力:动态感知关键区域破解固定注意力模式,增强特征捕捉能力

前言

本文介绍了Deformable Attention Transformer(DAT)及其在YOLOv11中的结合应用。DAT是一种用于图像分类和密集预测任务的通用主干模型,其核心是可变形自注意力模块,通过数据依赖的位置选择、灵活的偏移学习、全局键共享和空间自适应机制,解决了传统Transformer在扩大感受野时带来的内存、计算成本高和特征受无关部分影响等问题。该模块通过对键和值进行DCN偏移后再计算注意力,提升了性能。我们将DAttention模块集成进YOLOv11,替代部分原有模块。实验表明,DAT在综合基准测试中取得了持续改进的结果。

文章目录: YOLOv11改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLOv11改进专栏

文章目录

[TOC]

介绍

摘要

Transformer架构近年来在各类视觉任务中展现出卓越性能表现,其较大的甚至全局感受野赋予模型相较于卷积神经网络(CNN)更强的表征学习能力。然而,单纯扩展感受野规模亦引发若干关键问题:一方面,采用密集注意力机制(如ViT中所用)将导致显著的内存与计算开销,且特征提取过程易受目标区域外无关信息的干扰;另一方面,PVT或Swin Transformer等架构采用的稀疏注意力机制对数据分布不敏感,可能制约其长距离依赖关系建模能力。针对上述挑战,本文提出一种新型可变形自注意力模块,该模块基于数据驱动方式自适应选择自注意力中键值对的位置分布。这种灵活的设计策略使自注意力机制能够动态聚焦于语义相关区域,从而捕获更具信息量的特征表示。基于此创新模块,我们构建了Deformable Attention Transformer(可变形注意力Transformer)架构,该通用骨干网络适用于图像分类及密集预测任务,并集成可变形注意力机制。大量实验验证表明,所提出模型在综合性基准测试中实现了持续性的性能提升,相关代码已在https://github.com/LeapLabTHU/DAT开源发布。

文章链接

论文地址: 论文地址

代码地址: 代码地址

参考代码: 代码地址

基本原理

关键

  1. 数据依赖的位置选择 :Deformable Attention允许在自注意力机制中以数据依赖的方式选择键和值对的位置,使模型能够根据输入数据动态调整注意力的焦点。

  2. 灵活的偏移学习 :通过学习偏移量,Deformable Attention可以将关键点和值移动到重要区域,从而提高模型对关键特征的捕获能力。

  3. 全局键共享 :Deformable Attention学习一组全局键,这些键在不同的视觉标记之间共享,有助于模型捕获长距离的相关性。

  4. 空间自适应机制 :Deformable Attention可以根据输入数据的特征动态调整注意力模式,从而适应不同的视觉任务和场景。

通过相对于Swin-Transformer和PVT的改进,加入了可变形机制,同时控制网络不增加太多的计算量。作者认为,缩小q对应的k的范围,能够减少无关信息的干扰,增强信息的捕捉,于是引入了DCN机制到注意力模块中,提出了一种新的注意力模块:可变形多头注意力模块。该模块通过对k和v进行DCN偏移后再计算注意力,从而提升了性能。

在可变形多头注意力模块中,输入特征图像 $x \in \mathbb{R}^{H \times W \times C}$ 生成一个参考网格,其中参考点 $p \in \mathbb{R}^{H_G \times W_G \times 2}$ 。该网格是从输入特征图 $x$ 降采样而来,降采样系数为 $r$ , $H_G = H / r, W_G = W / r$ 。参考点的值代表的是坐标值 $(0, 0), \ldots, (H_G - 1, W_G - 1)$ ,再归一化到 $[-1, +1]$ 。

输入特征图像 $x$ 通过线性投影得到 $q = x Wq$ ,再输入到一个轻量级子网络offset network,生成偏移量 $\Delta p = \theta{\text{offset}}(q)$ 。为了稳定训练过程,使用了一些预定义的因子来衡量 $\Delta p$ 的振幅,以防止太大的offset,即 $\Delta p \leftarrow \text{sinh}(\Delta p)$ 。

然后将获得的offset作用在参考点上,获得变形点的位置,进行特征采样(双线性插值)得到 $\hat{x}$ ,再通过投影矩阵生成Key和Value, $\hat{k} = \hat{x} W_k, \hat{v} = \hat{x} W_v$ 。

$qkv$ 进行多头注意力计算,同时加入相对位置偏移嵌入。最后将获得的多头特征拼接起来,通过投影矩阵获得最终的注意力模块输出 $Z$ 。

核心代码

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from timm.models.layers import to_2tuple, trunc_normal_

# 定义一个LayerNormProxy类,继承自nn.Module
class LayerNormProxy(nn.Module):

    def __init__(self, dim):
        # 初始化函数,传入参数dim为输入张量的维度
        super().__init__()
        self.norm = nn.LayerNorm(dim)  # 定义LayerNorm层

    def forward(self, x):
        # 前向传播函数
        x = einops.rearrange(x, 'b c h w -> b h w c')  # 重排列tensor的维度,将通道维度移到最后
        x = self.norm(x)  # 进行LayerNorm操作
        return einops.rearrange(x, 'b h w c -> b c h w')  # 将维度恢复原状

# 定义一个DAttentionBaseline类,继承自nn.Module
class DAttentionBaseline(nn.Module):

    def __init__(
        self, q_size, kv_size, n_heads, n_head_channels, n_groups,
        attn_drop, proj_drop, stride, 
        offset_range_factor, use_pe, dwc_pe,
        no_off, fixed_pe, ksize, log_cpb
    ):
        # 初始化函数,定义了所需的参数
        super().__init__()
        self.dwc_pe = dwc_pe  # 是否使用深度卷积位置编码
        self.n_head_channels = n_head_channels  # 每个头的通道数
        self.scale = self.n_head_channels ** -0.5  # 缩放因子,等于每个头的通道数的负0.5次方
        self.n_heads = n_heads  # 多头注意力机制中的头数
        self.q_h, self.q_w = q_size  # query的高和宽
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride  # 计算键值对的高和宽
        self.nc = n_head_channels * n_heads  # 总的通道数
        self.n_groups = n_groups  # 分组数
        self.n_group_channels = self.nc // self.n_groups  # 每组的通道数
        self.n_group_heads = self.n_heads // self.n_groups  # 每组的头数
        self.use_pe = use_pe  # 是否使用位置编码
        self.fixed_pe = fixed_pe  # 是否使用固定的位置编码
        self.no_off = no_off  # 是否禁用偏移
        self.offset_range_factor = offset_range_factor  # 偏移范围因子
        self.ksize = ksize  # 卷积核尺寸
        self.log_cpb = log_cpb  # 是否使用对数相对位置偏置
        self.stride = stride  # 步幅
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0  # 计算填充大小

        # 定义卷积偏移网络
        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),  # 使用LayerNormProxy进行归一化
            nn.GELU(),  # 使用GELU激活函数
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)  # 输出偏移量
        )
        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)  # 如果不使用偏移,禁用偏移网络的参数更新

        # 定义投影层
        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0  # query投影
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0  # key投影
        )

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0  # value投影
        )

        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0  # 输出投影
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)  # 投影层的Dropout
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)  # 注意力层的Dropout

        # 相对位置嵌入的定义
        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)  # 深度卷积位置编码
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
            elif self.log_cpb:
                # 借用自Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        # 获取参考点
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'  # 保持矩阵索引一致
        )
        ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):
        # 获取query网格
        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'  # 保持矩阵索引一致
        )
        ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

        return ref

    def forward(self, x):
        # 前向传播函数
        B, C, H, W = x.size()  # 获取输入的尺寸
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)  # 对输入x进行query投影
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)  # 重排列tensor的维度
        offset = self.conv_offset(q_off).contiguous()  # 计算偏移量
        Hk, Wk = offset.size(2), offset.size(3)  # 获取偏移量的高和宽
        n_sample = Hk * Wk  # 计算采样点数量

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
                grid=pos[..., (1, 0)],  # y, x -> x, y
                mode='bilinear', align_corners=True)  # 进行双线性插值采样

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # 计算注意力权重
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):
            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0)  # 计算位移
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement)  # 计算相对位置嵌入偏置
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True)  # 双线性插值计算相对位置偏置

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)  # 对注意力权重进行softmax
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)  # 计算注意力输出

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))  # 投影输出并进行Dropout

        return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

实验

脚本

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
#     修改为自己的配置文件地址
    model = YOLO('/root/ultralytics-main/ultralytics/cfg/models/11/yolov11-DAttention.yaml')
#     修改为自己的数据集地址
    model.train(data='/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml',
                cache=False,
                imgsz=640,
                epochs=10,
                single_cls=False,  # 是否是单类别检测
                batch=8,
                close_mosaic=10,
                workers=0,
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='DAttention',
                )
THE END