YOLO26改进 – C3k2 C3k2融合HMHA分层多头注意力机制:优化模型在复杂场景下的目标感知能力 CVPR 2025
# 前言
本文介绍了分层多头注意力驱动的Transformer模型HINT中的核心模块HMHA,并将其集成到YOLO26中。传统多头注意力机制(MHA)存在冗余问题,HMHA通过“通道重排序+分层子空间划分”,使注意力头在不同子空间学习,避免冗余,提取多样化上下文特征。其流程包括通道重排序、分层子空间划分与注意力计算、特征聚合三步。我们将HMHA集成到YOLO26,构建C3k2_HMHA模块,经注册和配置yaml文件后进行实验,以验证其在目标检测任务中的有效性。
文章目录: YOLO26改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总
专栏链接: YOLO26改进专栏
介绍

摘要
Transformer-based 方法在图像修复任务中受到了广泛关注,其中核心组件——多头注意力机制(Multi-Head Attention, MHA)在提取多样特征和恢复高质量图像方面发挥了关键作用。在 MHA 中,各个注意力头在统一划分的子空间上独立计算,这种方式会引发冗余问题,从而影响模型的最终表现。
为了解决这一问题,本文提出了一种改进的多头注意力机制,通过引入多样化的学习器和头之间的多种交互方式,构建了一个分层多头注意力驱动的 Transformer 模型,命名为 HINT,用于图像修复任务。HINT 包含两个核心模块:分层多头注意力模块(HMHA) 和 查询-键缓存更新模块(QKCU),旨在从根本上缓解传统 MHA 中的冗余问题。
具体而言,HMHA 通过使注意力头在不同规模、包含不同信息的子空间中进行学习,提取更加多样的上下文特征。而 QKCU 模块则包括层内与层间的更新机制,进一步通过增强注意力头之间的交互,有效降低冗余。
我们在图像低光增强、去雾、去雪、去噪和去雨这五类图像修复任务中的 12 个基准数据集上进行了大量实验,结果充分证明了 HINT 的优越性。
源代码将开放获取,地址为:https://github.com/joshyZhou/HINT。
创新
HMHA 是 HINT 模型的核心模块,核心目标是解决传统 MHA(多头注意力)的冗余问题——传统 MHA 中所有注意力头从尺寸统一、信息相似的子空间学习,导致多头聚焦同一区域、遗漏退化区域修复,最终影响图像恢复质量。其核心设计思路是通过“通道重排序+分层子空间划分”,让每个注意力头学习差异化、独立的上下文特征,同时保持高效计算。
一、核心设计背景与目标
1. 传统 MHA 的核心缺陷
传统 MHA 对输入特征的通道维度进行均匀划分(例如将总通道数 C 平均分为 h 份,每份尺寸为 C/h),所有注意力头在尺寸相同、信息重叠的子空间中独立计算注意力。这种设计导致:
- 注意力头倾向于关注图像中相同的“易修复区域”(如明亮、清晰区域),造成计算冗余;
- 忽略退化严重的区域(如暗角、雾气覆盖区域),导致修复不完整、细节丢失。
2. HMHA 的核心目标
- 让每个注意力头学习独立且多样化的上下文特征,避免多头冗余;
- 通过分层子空间设计,覆盖不同尺度、不同语义的特征信息(如细节纹理、全局结构);
- 保持与传统 MHA 相当的计算效率,不显著增加模型复杂度。
二、关键技术步骤与实现细节
HMHA 的工作流程可分为“通道重排序”和“分层子空间划分与注意力计算”两步,具体如下:
1. 步骤1:通道重排序(Reranking)—— 确保子空间独立性
传统 MHA 直接均匀划分通道,导致子空间信息重叠。HMHA 先对通道进行“相似度重排序”,核心目的是让相似信息的通道聚集,不同子空间的信息差异最大化:
- 排序依据:基于皮尔逊相关系数(Pearson Correlation)计算通道间的相似度;
- 操作逻辑:将相似度高的通道归为同一组,相似度低的通道分配到不同组,确保后续划分的每个子空间包含“独特语义信息”(如部分通道聚焦纹理、部分聚焦亮度);
- 核心作用:从源头减少子空间信息冗余,为每个注意力头分配“专属任务”。
2. 步骤2:分层子空间划分—— 差异化头分配
重排序后,HMHA 对通道进行非均匀的分层划分,而非传统 MHA 的均匀划分:
- 划分规则:将总通道数 C 划分为 h 个不同尺寸的子空间,即 ( C = [C_1, C_2, ..., C_h] ),且满足 ( C_1 ≤ C_2 ≤ ... ≤ C_h )(例如 4 个注意力头的子空间尺寸比例为 [1,2,2,3]);
- 头与子空间映射:每个注意力头独占一个子空间,在自身子空间内独立执行“缩放点积注意力”(Scaled Dot-Product Attention);
- 数学表达:
给定归一化输入特征 ( X \in \mathbb{R}^{H×W×C} )(H=高度、W=宽度、C=通道数),第 i 个注意力头的计算为:
$$
H_i = Attention(X W_Q^i, X W_K^i, X W_V^i)
$$
其中:
- $W_Q^i \in \mathbb{R}^{C×C_i}$、$W_K^i \in \mathbb{R}^{C×C_i}$、$W_V^i \in \mathbb{R}^{C×C_i}$ 是第 i 个头的查询(Q)、键(K)、值(V)投影矩阵,维度随子空间尺寸 $C_i$ 自适应调整;
- $Attention(\cdot)$为标准缩放点积注意力,即 $Attention(Q,K,V) = Softmax(\frac{QK^T}{\sqrt{C_i}})V$(分母为子空间尺寸的平方根,确保梯度稳定)。
3. 步骤3:特征聚合—— 整合多样化特征
所有注意力头计算完成后,通过“拼接(Concat)+ 线性投影(Projection)”整合所有头的输出,得到最终的 HMHA 特征: $$ HMHA(X) = Concat(H_1, H_2, ..., H_h) W_p $$ 其中 $W_p \in \mathbb{R}^{(C_1+C_2+...+C_h)×C}$是输出投影矩阵,将拼接后的特征映射回原始通道维度 C,确保与后续网络层兼容。
文章链接
论文地址:论文地址
代码地址:代码地址
基本原理
核心代码
class HMHA(nn.Module):
def __init__(self, dim, num_heads=8, bias=False):
super(HMHA, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(4, 1, 1))
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.group = [1, 2, 2, 3]
self.intra_modulator = Intra_CacheModulation(embed_dim=dim)
self.inter_modulator1 = Inter_CacheModulation(in_c=1 * dim // 8)
self.inter_modulator2 = Inter_CacheModulation(in_c=2 * dim // 8)
self.inter_modulator3 = Inter_CacheModulation(in_c=2 * dim // 8)
self.inter_modulator4 = Inter_CacheModulation(in_c=3 * dim // 8)
self.inter_modulators = [self.inter_modulator1, self.inter_modulator2, self.inter_modulator3,
self.inter_modulator4]
self.regroup = ReGroup(self.group)
self.dim = dim
def forward(self, x ):
b, c, h, w = x.shape
qv_cache = None
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b c h w -> b c (h w)')
k = rearrange(k, 'b c h w -> b c (h w)')
v = rearrange(v, 'b c h w -> b c (h w)')
qu, ke, va = self.regroup(q, k, v)
attScore = []
tmp_cache = []
for index in range(len(self.group)):
query_head = qu[index]
key_head = ke[index]
query_head = torch.nn.functional.normalize(query_head, dim=-1)
key_head = torch.nn.functional.normalize(key_head, dim=-1)
attn = (query_head @ key_head.transpose(-2, -1)) * self.temperature[index, :, :]
attn = attn.softmax(dim=-1)
attScore.append(attn) # CxC
t_cache = query_head.clone().detach() + key_head.clone().detach()
tmp_cache.append(t_cache)
tmp_caches = torch.cat(tmp_cache, 1)
out = []
if qv_cache is not None:
if qv_cache.shape[-1] != c:
qv_cache = F.adaptive_avg_pool2d(qv_cache, c)
for i in range(4):
if qv_cache is not None:
inter_modulator = self.inter_modulators[i]
attScore[i] = inter_modulator(attScore[i], qv_cache) + attScore[i]
out.append(attScore[i] @ va[i])
else:
out.append(attScore[i] @ va[i])
update_factor = 0.9
if qv_cache is not None:
update_elements = CalculateCurrentLayerCache(attScore, c, self.group)
qv_cache = qv_cache * update_factor + update_elements * (1 - update_factor)
else:
qv_cache = CalculateCurrentLayerCache(attScore, c, self.group)
qv_cache = qv_cache * update_factor
out_all = torch.concat(out, 1)
# Intra Modulation
out_all = self.intra_modulator(out_all, tmp_caches) + out_all
out_all = rearrange(out_all, 'b c (h w) -> b c h w', h=h, w=w)
out_all = self.project_out(out_all)
return out_all
实验
脚本
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
if __name__ == '__main__':
# 修改为自己的配置文件地址
model = YOLO('./ultralytics/cfg/models/26/yolo26-C3k2_HMHA.yaml')
# 修改为自己的数据集地址
model.train(data='./ultralytics/cfg/datasets/coco8.yaml',
cache=False,
imgsz=640,
epochs=10,
single_cls=False, # 是否是单类别检测
batch=8,
close_mosaic=10,
workers=0,
# optimizer='MuSGD',
optimizer='SGD',
amp=False,
project='runs/train',
name='yolo26-C3k2_HMHA',
)
结果
