YOLO11 改进 – 注意力机制 _ DAT (Deformable Attention) 可变形注意力:动态感知关键区域破解固定注意力模式,增强特征捕捉能力
前言
本文介绍了Deformable Attention Transformer(DAT)及其在YOLOv11中的结合应用。DAT是一种用于图像分类和密集预测任务的通用主干模型,其核心是可变形自注意力模块,通过数据依赖的位置选择、灵活的偏移学习、全局键共享和空间自适应机制,解决了传统Transformer在扩大感受野时带来的内存、计算成本高和特征受无关部分影响等问题。该模块通过对键和值进行DCN偏移后再计算注意力,提升了性能。我们将DAttention模块集成进YOLOv11,替代部分原有模块。实验表明,DAT在综合基准测试中取得了持续改进的结果。
文章目录: YOLOv11改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总
专栏链接: YOLOv11改进专栏
文章目录
[TOC]
介绍
摘要
Transformer架构近年来在各类视觉任务中展现出卓越性能表现,其较大的甚至全局感受野赋予模型相较于卷积神经网络(CNN)更强的表征学习能力。然而,单纯扩展感受野规模亦引发若干关键问题:一方面,采用密集注意力机制(如ViT中所用)将导致显著的内存与计算开销,且特征提取过程易受目标区域外无关信息的干扰;另一方面,PVT或Swin Transformer等架构采用的稀疏注意力机制对数据分布不敏感,可能制约其长距离依赖关系建模能力。针对上述挑战,本文提出一种新型可变形自注意力模块,该模块基于数据驱动方式自适应选择自注意力中键值对的位置分布。这种灵活的设计策略使自注意力机制能够动态聚焦于语义相关区域,从而捕获更具信息量的特征表示。基于此创新模块,我们构建了Deformable Attention Transformer(可变形注意力Transformer)架构,该通用骨干网络适用于图像分类及密集预测任务,并集成可变形注意力机制。大量实验验证表明,所提出模型在综合性基准测试中实现了持续性的性能提升,相关代码已在https://github.com/LeapLabTHU/DAT开源发布。
文章链接
论文地址: 论文地址
代码地址: 代码地址
参考代码: 代码地址
基本原理
关键
-
数据依赖的位置选择 :Deformable Attention允许在自注意力机制中以数据依赖的方式选择键和值对的位置,使模型能够根据输入数据动态调整注意力的焦点。
-
灵活的偏移学习 :通过学习偏移量,Deformable Attention可以将关键点和值移动到重要区域,从而提高模型对关键特征的捕获能力。
-
全局键共享 :Deformable Attention学习一组全局键,这些键在不同的视觉标记之间共享,有助于模型捕获长距离的相关性。
-
空间自适应机制 :Deformable Attention可以根据输入数据的特征动态调整注意力模式,从而适应不同的视觉任务和场景。
通过相对于Swin-Transformer和PVT的改进,加入了可变形机制,同时控制网络不增加太多的计算量。作者认为,缩小q对应的k的范围,能够减少无关信息的干扰,增强信息的捕捉,于是引入了DCN机制到注意力模块中,提出了一种新的注意力模块:可变形多头注意力模块。该模块通过对k和v进行DCN偏移后再计算注意力,从而提升了性能。
在可变形多头注意力模块中,输入特征图像 $x \in \mathbb{R}^{H \times W \times C}$ 生成一个参考网格,其中参考点 $p \in \mathbb{R}^{H_G \times W_G \times 2}$ 。该网格是从输入特征图 $x$ 降采样而来,降采样系数为 $r$ , $H_G = H / r, W_G = W / r$ 。参考点的值代表的是坐标值 $(0, 0), \ldots, (H_G - 1, W_G - 1)$ ,再归一化到 $[-1, +1]$ 。
输入特征图像 $x$ 通过线性投影得到 $q = x Wq$ ,再输入到一个轻量级子网络offset network,生成偏移量 $\Delta p = \theta{\text{offset}}(q)$ 。为了稳定训练过程,使用了一些预定义的因子来衡量 $\Delta p$ 的振幅,以防止太大的offset,即 $\Delta p \leftarrow \text{sinh}(\Delta p)$ 。
然后将获得的offset作用在参考点上,获得变形点的位置,进行特征采样(双线性插值)得到 $\hat{x}$ ,再通过投影矩阵生成Key和Value, $\hat{k} = \hat{x} W_k, \hat{v} = \hat{x} W_v$ 。
$qkv$ 进行多头注意力计算,同时加入相对位置偏移嵌入。最后将获得的多头特征拼接起来,通过投影矩阵获得最终的注意力模块输出 $Z$ 。
核心代码
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
from timm.models.layers import to_2tuple, trunc_normal_
# 定义一个LayerNormProxy类,继承自nn.Module
class LayerNormProxy(nn.Module):
def __init__(self, dim):
# 初始化函数,传入参数dim为输入张量的维度
super().__init__()
self.norm = nn.LayerNorm(dim) # 定义LayerNorm层
def forward(self, x):
# 前向传播函数
x = einops.rearrange(x, 'b c h w -> b h w c') # 重排列tensor的维度,将通道维度移到最后
x = self.norm(x) # 进行LayerNorm操作
return einops.rearrange(x, 'b h w c -> b c h w') # 将维度恢复原状
# 定义一个DAttentionBaseline类,继承自nn.Module
class DAttentionBaseline(nn.Module):
def __init__(
self, q_size, kv_size, n_heads, n_head_channels, n_groups,
attn_drop, proj_drop, stride,
offset_range_factor, use_pe, dwc_pe,
no_off, fixed_pe, ksize, log_cpb
):
# 初始化函数,定义了所需的参数
super().__init__()
self.dwc_pe = dwc_pe # 是否使用深度卷积位置编码
self.n_head_channels = n_head_channels # 每个头的通道数
self.scale = self.n_head_channels ** -0.5 # 缩放因子,等于每个头的通道数的负0.5次方
self.n_heads = n_heads # 多头注意力机制中的头数
self.q_h, self.q_w = q_size # query的高和宽
self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride # 计算键值对的高和宽
self.nc = n_head_channels * n_heads # 总的通道数
self.n_groups = n_groups # 分组数
self.n_group_channels = self.nc // self.n_groups # 每组的通道数
self.n_group_heads = self.n_heads // self.n_groups # 每组的头数
self.use_pe = use_pe # 是否使用位置编码
self.fixed_pe = fixed_pe # 是否使用固定的位置编码
self.no_off = no_off # 是否禁用偏移
self.offset_range_factor = offset_range_factor # 偏移范围因子
self.ksize = ksize # 卷积核尺寸
self.log_cpb = log_cpb # 是否使用对数相对位置偏置
self.stride = stride # 步幅
kk = self.ksize
pad_size = kk // 2 if kk != stride else 0 # 计算填充大小
# 定义卷积偏移网络
self.conv_offset = nn.Sequential(
nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
LayerNormProxy(self.n_group_channels), # 使用LayerNormProxy进行归一化
nn.GELU(), # 使用GELU激活函数
nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False) # 输出偏移量
)
if self.no_off:
for m in self.conv_offset.parameters():
m.requires_grad_(False) # 如果不使用偏移,禁用偏移网络的参数更新
# 定义投影层
self.proj_q = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0 # query投影
)
self.proj_k = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0 # key投影
)
self.proj_v = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0 # value投影
)
self.proj_out = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0 # 输出投影
)
self.proj_drop = nn.Dropout(proj_drop, inplace=True) # 投影层的Dropout
self.attn_drop = nn.Dropout(attn_drop, inplace=True) # 注意力层的Dropout
# 相对位置嵌入的定义
if self.use_pe and not self.no_off:
if self.dwc_pe:
self.rpe_table = nn.Conv2d(
self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc) # 深度卷积位置编码
elif self.fixed_pe:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
)
trunc_normal_(self.rpe_table, std=0.01) # 截断正态分布初始化
elif self.log_cpb:
# 借用自Swin-V2
self.rpe_table = nn.Sequential(
nn.Linear(2, 32, bias=True),
nn.ReLU(inplace=True),
nn.Linear(32, self.n_group_heads, bias=False)
)
else:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
)
trunc_normal_(self.rpe_table, std=0.01) # 截断正态分布初始化
else:
self.rpe_table = None
@torch.no_grad()
def _get_ref_points(self, H_key, W_key, B, dtype, device):
# 获取参考点
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
indexing='ij' # 保持矩阵索引一致
)
ref = torch.stack((ref_y, ref_x), -1) # 堆叠y和x坐标
ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0) # 归一化到[-1, 1]范围
ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0) # 归一化到[-1, 1]范围
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # 扩展维度,适应批量和分组
return ref
@torch.no_grad()
def _get_q_grid(self, H, W, B, dtype, device):
# 获取query网格
ref_y, ref_x = torch.meshgrid(
torch.arange(0, H, dtype=dtype, device=device),
torch.arange(0, W, dtype=dtype, device=device),
indexing='ij' # 保持矩阵索引一致
)
ref = torch.stack((ref_y, ref_x), -1) # 堆叠y和x坐标
ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0) # 归一化到[-1, 1]范围
ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0) # 归一化到[-1, 1]范围
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # 扩展维度,适应批量和分组
return ref
def forward(self, x):
# 前向传播函数
B, C, H, W = x.size() # 获取输入的尺寸
dtype, device = x.dtype, x.device
q = self.proj_q(x) # 对输入x进行query投影
q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels) # 重排列tensor的维度
offset = self.conv_offset(q_off).contiguous() # 计算偏移量
Hk, Wk = offset.size(2), offset.size(3) # 获取偏移量的高和宽
n_sample = Hk * Wk # 计算采样点数量
if self.offset_range_factor >= 0 and not self.no_off:
offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)
offset = einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
if self.no_off:
offset = offset.fill_(0.0)
if self.offset_range_factor >= 0:
pos = offset + reference
else:
pos = (offset + reference).clamp(-1., +1.)
if self.no_off:
x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
else:
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # 进行双线性插值采样
x_sampled = x_sampled.reshape(B, C, 1, n_sample)
q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
attn = torch.einsum('b c m, b c n -> b m n', q, k) # 计算注意力权重
attn = attn.mul(self.scale)
if self.use_pe and (not self.no_off):
if self.dwc_pe:
residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
elif self.fixed_pe:
rpe_table = self.rpe_table
attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
elif self.log_cpb:
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0) # 计算位移
displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
attn_bias = self.rpe_table(displacement) # 计算相对位置嵌入偏置
attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
else:
rpe_table = self.rpe_table
rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
attn_bias = F.grid_sample(
input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
grid=displacement[..., (1, 0)],
mode='bilinear', align_corners=True) # 双线性插值计算相对位置偏置
attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
attn = attn + attn_bias
attn = F.softmax(attn, dim=2) # 对注意力权重进行softmax
attn = self.attn_drop(attn)
out = torch.einsum('b m n, b c n -> b c m', attn, v) # 计算注意力输出
if self.use_pe and self.dwc_pe:
out = out + residual_lepe
out = out.reshape(B, C, H, W)
y = self.proj_drop(self.proj_out(out)) # 投影输出并进行Dropout
return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)
实验
脚本
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
if __name__ == '__main__':
# 修改为自己的配置文件地址
model = YOLO('/root/ultralytics-main/ultralytics/cfg/models/11/yolov11-DAttention.yaml')
# 修改为自己的数据集地址
model.train(data='/root/ultralytics-main/ultralytics/cfg/datasets/coco8.yaml',
cache=False,
imgsz=640,
epochs=10,
single_cls=False, # 是否是单类别检测
batch=8,
close_mosaic=10,
workers=0,
optimizer='SGD',
amp=True,
project='runs/train',
name='DAttention',
)