YOLO26 Improvement - C3k2 Fused with FDConv Frequency Dynamic Convolution: Spatial-Frequency Cooperative Modulation Enhances Detail Capture, Improving Detection of Small and Blurred-Boundary Targets (CVPR 2025)

# Preface

This article presents Frequency Dynamic Convolution (FDConv), a new module designed to address the shortcomings of conventional dynamic convolution: highly similar frequency responses across parallel weights, large parameter overhead, and limited adaptability. FDConv learns a fixed parameter budget in the Fourier domain and partitions it into frequency-based groups to construct frequency-diverse weights. It further introduces Kernel Spatial Modulation (KSM) and Frequency Band Modulation (FBM) to improve adaptability in the spatial and frequency domains, respectively. Extensive experiments show that applying FDConv to ResNet-50 achieves superior performance with only 3.6M additional parameters. We integrate FDConv into YOLO26 by replacing part of its modules, and validate its effectiveness on object detection, segmentation, and classification tasks, providing a flexible and efficient solution for modern vision tasks.

Article index: YOLO26 Improvement Compendium: a full roundup of optimizations for convolution layers, lightweighting, attention mechanisms, loss functions, Backbone, SPPF, Neck, and detection heads

Column link: YOLO26 Improvement Column

# Introduction

(Figure: FDConv overview from the paper.)

## Abstract

Although dynamic convolution (DY-Conv) achieves adaptive weight selection through multiple parallel weights combined with an attention mechanism and shows promising performance, the frequency responses of these weights tend to be highly similar, resulting in high parameter cost but limited adaptability. In this paper, we introduce Frequency Dynamic Convolution (FDConv), a novel approach that mitigates these limitations by learning a fixed parameter budget in the Fourier domain. FDConv divides this budget into frequency-based groups with disjoint Fourier indices, enabling the construction of frequency-diverse weights without increasing the parameter cost. To further enhance adaptability, we propose Kernel Spatial Modulation (KSM) and Frequency Band Modulation (FBM): KSM dynamically adjusts the frequency response of each filter at the spatial level, while FBM decomposes weights into distinct frequency bands in the frequency domain and modulates them dynamically based on local content. Extensive experiments on object detection, segmentation, and classification validate the effectiveness of FDConv. The results show that when applied to ResNet-50, FDConv achieves superior performance with a modest increase of only 3.6M parameters, outperforming previous methods that require substantial increases in parameter budgets (e.g., CondConv +90M, KW +76.5M). Moreover, FDConv can be seamlessly integrated into a variety of architectures, including ConvNeXt and Swin Transformer, offering a flexible and efficient solution for modern vision tasks. The code is publicly available at https://github.com/LinweiChen/FDConv.

# Links

Paper link: paper link

Code link: https://github.com/LinweiChen/FDConv

# Basic Principles

FDConv is a new dynamic convolution module designed for dense image prediction tasks (object detection, segmentation, classification, etc.). Its core goal is to fix the shortcomings of conventional dynamic convolution (DY-Conv): similar frequency responses across parallel weights, parameter redundancy, and limited adaptability. Its key innovation is learning a fixed parameter budget in the Fourier domain; through a three-level design of frequency grouping, spatial modulation, and band adaptation, it greatly improves the frequency diversity and adaptivity of the weights without significantly increasing compute, and it can be integrated seamlessly into both ConvNets and vision Transformer architectures.

## I. Core Design Positioning

### 1. Pain Points Addressed

Conventional dynamic convolutions (e.g., CondConv, ODConv) achieve adaptivity via parallel weights fused by attention, but suffer from two major flaws:

  • Frequency redundancy: the frequency responses of the parallel weights are highly similar (Fig. 1(a)) and cluster tightly under t-SNE (Fig. 1(c)), making it hard to capture multi-band features;
  • Parameter inefficiency: improving diversity requires multiplying the number of weights, causing parameters to explode (e.g., CondConv +90M, ODConv +65.1M) and driving up deployment cost.

### 2. Design Goals

  • Frequency diversity: let the parallel weights cover different bands (low frequencies for denoising, high frequencies for capturing detail), strengthening feature expression;
  • Parameter efficiency: keep a fixed parameter budget with only a small increase (+3.6M), far below conventional dynamic convolution;
  • Spatial-frequency synergy: break the spatial invariance of standard convolution and adapt frequency bands dynamically per spatial location.

## II. The Three Core Modules in Detail

The core of FDConv consists of three cooperating modules that refine frequency adaptivity layer by layer, from the Fourier domain to the spatial domain and from global to local:

### 1. Fourier Disjoint Weight (FDW)

  • Role: generate parallel weights with diverse frequency responses at no extra parameter cost, solving the frequency homogeneity of conventional dynamic convolution weights.
  • Principle and steps (a toy sketch follows this list):
    1. Fourier-domain parameter grouping: treat the fixed parameter budget (identical to a standard convolution) as spectral coefficients in the Fourier domain, sort the Fourier indices by their L2 norm from low to high, and split them evenly into n disjoint groups (n can exceed 10, far more than the n < 10 of conventional dynamic convolution), each group covering its own band;
    2. Fourier-to-spatial conversion: apply the inverse discrete Fourier transform (iDFT) to each group, converting its spectral coefficients into a spatial-domain proto-weight, with all Fourier indices outside the group set to 0;
    3. Weight reassembly: crop the resulting spatial map into k×k patches (the kernel size) and reassemble them into the standard convolution weight shape (k×k×C_in×C_out).
  • Key advantage: each group of weights carries only its own band, so the frequency responses of different groups are fully independent (Fig. 1(b)(d)), multiplying diversity without adding parameters.
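To make the grouping concrete, below is a minimal, self-contained sketch of the FDW idea. This is my own simplification, not the repository implementation: the function name fourier_disjoint_weights and all sizes are illustrative, and the full FDConv additionally reshapes each map into k×k patches.

import torch

def fourier_disjoint_weights(d1: int, d2: int, n_groups: int):
    """Toy FDW: carve one (d1, d2) spectral parameter budget into
    n_groups disjoint frequency bands, one spatial weight per band."""
    # The fixed parameter budget, stored as rFFT spectral coefficients.
    coeffs = torch.randn(d1, d2 // 2 + 1, dtype=torch.complex64)

    # L2 norm of each Fourier index (normalized frequency coordinates).
    fy = torch.fft.fftfreq(d1).abs()                     # (d1,)
    fx = torch.fft.rfftfreq(d2)                          # (d2 // 2 + 1,)
    dist = (fy[:, None] ** 2 + fx[None, :] ** 2).sqrt().flatten()

    # Sort indices from low to high frequency, split into disjoint groups.
    groups = dist.argsort().chunk(n_groups)

    weights = []
    for idx in groups:
        spec = torch.zeros(d1 * (d2 // 2 + 1), dtype=torch.complex64)
        spec[idx] = coeffs.flatten()[idx]                # keep this band only
        # iDFT back to the spatial domain -> one frequency-specific weight.
        weights.append(torch.fft.irfft2(spec.reshape(d1, d2 // 2 + 1), s=(d1, d2)))
    return weights                                       # n_groups tensors of (d1, d2)

# In FDConv, d1 = C_out * k and d2 = C_in * k; each map is then cut into
# k×k patches and reassembled into a (C_out, C_in, k, k) weight.
ws = fourier_disjoint_weights(d1=32 * 3, d2=16 * 3, n_groups=4)
print([tuple(w.shape) for w in ws])                      # 4 maps of shape (96, 48)

Because the groups share one coefficient tensor and merely partition its indices, the total parameter count stays exactly at the standard-convolution budget regardless of n_groups.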

### 2. Kernel Spatial Modulation (KSM)

  • Role: apply element-wise fine-tuning to the FDW-generated weights, fixing the problem that FDW's weight-level fusion is too coarse to adapt the frequency response of each individual filter.
  • Principle and structure (a toy sketch follows this list):
    • Two-branch design:
      1. Local channel branch: a lightweight 1D convolution captures local channel information and predicts a dense modulation matrix (k×k×C_in×C_out), enabling fine-grained adjustment of every weight element at very low parameter cost;
      2. Global channel branch: global average pooling plus fully connected layers capture global information and predict modulation values along three dimensions (input channel, output channel, kernel spatial), complementing the local branch with a global view;
    • Fused modulation: multiply the outputs of the two branches to obtain the final modulation matrix, then take its Hadamard (element-wise) product with the FDW weights to complete the frequency-response fine-tuning.
  • Key advantage: combines local detail with global context, letting each filter's frequency response adapt to the input features and boosting expressive power.
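The following is a minimal sketch of the two-branch design, not the repo's KernelSpatialModulation classes: ToyKSM, the reduction ratio, and the att_multi = 2 scaling are assumptions chosen to mirror the description above.

import torch
import torch.nn as nn

class ToyKSM(nn.Module):
    """Simplified KSM: predicts an element-wise modulation for a
    (c_out, c_in, k, k) conv weight from the input feature map."""
    def __init__(self, c_in: int, c_out: int, k: int, reduction: int = 4):
        super().__init__()
        self.c_out, self.k = c_out, k
        # Local branch: 1D conv over the channel axis of the pooled vector,
        # expanded to c_out * k * k logits per input channel (dense matrix).
        self.local = nn.Conv1d(1, c_out * k * k, kernel_size=3, padding=1)
        # Global branch: GAP descriptor -> scales for c_in / c_out / k*k dims.
        self.fc = nn.Linear(c_in, c_in // reduction)
        self.to_cin = nn.Linear(c_in // reduction, c_in)
        self.to_cout = nn.Linear(c_in // reduction, c_out)
        self.to_spatial = nn.Linear(c_in // reduction, k * k)

    def forward(self, x, weight):
        b, c_in = x.shape[0], x.shape[1]
        g = x.mean(dim=(2, 3))                            # GAP: (b, c_in)
        # Local branch: dense modulation, (b, c_out, c_in, k, k).
        loc = self.local(g[:, None, :])                   # (b, c_out*k*k, c_in)
        loc = loc.view(b, self.c_out, self.k, self.k, c_in).permute(0, 1, 4, 2, 3)
        # Global branch: per-dimension scales, broadcast-multiplied together.
        h = torch.relu(self.fc(g))
        glob = (self.to_cout(h)[:, :, None, None, None]
                * self.to_cin(h)[:, None, :, None, None]
                * self.to_spatial(h).view(b, 1, 1, self.k, self.k))
        att = torch.sigmoid(loc * glob) * 2.0             # fuse branches, att_multi = 2
        return weight[None] * att                         # Hadamard product with weight

ksm = ToyKSM(c_in=16, c_out=32, k=3)
x = torch.randn(2, 16, 8, 8)
w = torch.randn(32, 16, 3, 3)
print(ksm(x, w).shape)  # torch.Size([2, 32, 16, 3, 3])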

### 3. Frequency Band Modulation (FBM)

  • Role: break the spatial invariance of standard convolution and adapt bands per spatial position, addressing the fact that different regions need different bands (e.g., boundaries need high frequencies, backgrounds need low noise).
  • Principle and steps (a toy sketch follows this list):
    1. Kernel band decomposition: pad the convolution kernel to the feature-map size and use binary masks $M_b$ in the Fourier domain to separate B bands (4 by default, split by octave: $[0, 1/16)$, $[1/16, 1/8)$, $[1/8, 1/4)$, $[1/4, 1/2]$), each band with its own kernel weight $W_b$;
    2. Fourier-domain convolution: using the convolution theorem (spatial-domain convolution equals frequency-domain point-wise multiplication), compute each band's convolution output $Y_b$ efficiently in the Fourier domain, avoiding the infinite spatial support that direct band separation would require;
    3. Spatially variant modulation: use convolution + Sigmoid to generate a spatial modulation map $A_b \in \mathbb{R}^{h \times w}$ per band, dynamically reweighting each band at every spatial position, and fuse all band outputs: $Y = \sum_{b=0}^{B-1} A_b \odot Y_b$.
  • Key advantage: selectively amplifies high-frequency components at object boundaries while suppressing high-frequency background noise (Fig. 6), adapting the frequency response across space and improving dense prediction accuracy.
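Below is a minimal sketch of the FBM idea applied at the feature level. It follows the repo's FrequencyBandModulation in spirit but not in interface: ToyFBM is hypothetical, the Chebyshev frequency radius is a simplification, and the octave cutoffs are the defaults quoted above.

import torch
import torch.nn as nn

class ToyFBM(nn.Module):
    """Simplified FBM: split features into octave bands via FFT masks,
    then reweight each band per spatial location with conv + sigmoid."""
    def __init__(self, channels: int, cutoffs=(1/16, 1/8, 1/4, 1/2)):
        super().__init__()
        self.cutoffs = cutoffs
        # One conv+sigmoid modulation map A_b per frequency band.
        self.att = nn.ModuleList(
            nn.Conv2d(channels, 1, kernel_size=3, padding=1) for _ in cutoffs)

    def forward(self, x):
        b, c, h, w = x.shape
        spec = torch.fft.fft2(x)
        # Normalized frequency radius (Chebyshev, for simplicity) per bin.
        fy = torch.fft.fftfreq(h, device=x.device).abs()
        fx = torch.fft.fftfreq(w, device=x.device).abs()
        radius = torch.maximum(fy[:, None], fx[None, :])   # (h, w)

        out, low = 0.0, 0.0
        for cut, att in zip(self.cutoffs, self.att):
            mask = (radius <= cut).float()                 # binary band mask M_b
            band = torch.fft.ifft2(spec * mask).real - low # isolate this octave Y_b
            low = low + band                               # cumulative low-pass
            a = torch.sigmoid(att(x))                      # A_b: (b, 1, h, w)
            out = out + a * band                           # Y += A_b ⊙ Y_b
        return out

fbm = ToyFBM(channels=16)
print(fbm(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 16, 32, 32])

Each band is obtained as the difference of two low-pass reconstructions, so the four bands exactly partition the spectrum and the unmodulated sum would reproduce the input.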

# Core Code

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

# Note: KernelSpatialModulation_Global, KernelSpatialModulation_Local,
# FrequencyBandModulation and get_fft2freq are helper modules from the
# official FDConv repository and must be imported alongside this class.

class FDConv(nn.Conv2d):
    def __init__(self, 
                 *args, 
                 reduction=0.0625, 
                 kernel_num=4,
                 use_fdconv_if_c_gt=16, # use FDConv only when min(in, out) channels exceed this, e.g., 64, 128, 256, 512
                 use_fdconv_if_k_in=[1, 3], #if kernel_size in the list
                 use_fdconv_if_stride_in=[1], #if stride in the list
                 use_fbm_if_k_in=[3], #if kernel_size in the list
                 use_fbm_for_stride=False,
                 kernel_temp=1.0,
                 temp=None,
                 att_multi=2.0,
                 param_ratio=1,
                 param_reduction=1.0,
                 ksm_only_kernel_att=False,
                 att_grid=1,
                 use_ksm_local=True,
                 ksm_local_act='sigmoid',
                 ksm_global_act='sigmoid',
                 spatial_freq_decompose=False,
                 convert_param=True,
                 linear_mode=False,
                 fbm_cfg={
                    'k_list':[2, 4, 8],
                    'lowfreq_att':False,
                    'fs_feat':'feat',
                    'act':'sigmoid',
                    'spatial':'conv',
                    'spatial_group':1,
                    'spatial_kernel':3,
                    'init':'zero',
                    'global_selection':False,
                 },
                 **kwargs,
                 ):
        super().__init__(*args, **kwargs)
        self.use_fdconv_if_c_gt = use_fdconv_if_c_gt
        self.use_fdconv_if_k_in = use_fdconv_if_k_in
        self.use_fdconv_if_stride_in = use_fdconv_if_stride_in
        self.kernel_num = kernel_num
        self.param_ratio = param_ratio
        self.param_reduction = param_reduction
        self.use_ksm_local = use_ksm_local
        self.att_multi = att_multi
        self.spatial_freq_decompose = spatial_freq_decompose
        self.use_fbm_if_k_in = use_fbm_if_k_in

        self.ksm_local_act = ksm_local_act
        self.ksm_global_act = ksm_global_act
        assert self.ksm_local_act in ['sigmoid', 'tanh']
        assert self.ksm_global_act in ['softmax', 'sigmoid', 'tanh']

        ### Kernel num & Kernel temp setting
        if self.kernel_num is None:
            self.kernel_num = self.out_channels // 2
            kernel_temp = math.sqrt(self.kernel_num * self.param_ratio)
        if temp is None:
            temp = kernel_temp

        if min(self.in_channels, self.out_channels) <= self.use_fdconv_if_c_gt \
            or self.kernel_size[0] not in self.use_fdconv_if_k_in:
                return
        print('*** kernel_num:', self.kernel_num)
        self.alpha = min(self.out_channels, self.in_channels) // 2 * self.kernel_num * self.param_ratio / param_reduction
        self.KSM_Global = KernelSpatialModulation_Global(self.in_channels, self.out_channels, self.kernel_size[0], groups=self.groups, 
                                                        temp=temp,
                                                        kernel_temp=kernel_temp,
                                                        reduction=reduction, kernel_num=self.kernel_num * self.param_ratio, 
                                                        kernel_att_init=None, att_multi=att_multi, ksm_only_kernel_att=ksm_only_kernel_att, 
                                                        act_type=self.ksm_global_act,
                                                        att_grid=att_grid, stride=self.stride, spatial_freq_decompose=spatial_freq_decompose)

        # print(use_fbm_for_stride, self.stride[0] > 1)
        if self.kernel_size[0] in use_fbm_if_k_in or (use_fbm_for_stride and self.stride[0] > 1):
            self.FBM = FrequencyBandModulation(self.in_channels, **fbm_cfg)
            # self.channel_comp = ChannelPool(reduction=16)

        if self.use_ksm_local:
            self.KSM_Local = KernelSpatialModulation_Local(channel=self.in_channels, kernel_num=1, out_n=int(self.out_channels * self.kernel_size[0] * self.kernel_size[1]) )

        self.linear_mode = linear_mode
        self.convert2dftweight(convert_param)

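    # convert2dftweight: re-express the inherited spatial weight as rFFT
    # spectral coefficients and record the disjoint frequency-index groups
    # (the FDW parameter storage described above).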
    def convert2dftweight(self, convert_param):
        d1, d2, k1, k2 = self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]
        freq_indices, _ = get_fft2freq(d1 * k1, d2 * k2, use_rfft=True) # 2, d1 * k1 * (d2 * k2 // 2 + 1)
        # freq_indices = freq_indices.reshape(2, self.kernel_num, -1)
        weight = self.weight.permute(0, 2, 1, 3).reshape(d1 * k1, d2 * k2)
        weight_rfft = torch.fft.rfft2(weight, dim=(0, 1)) # d1 * k1, d2 * k2 // 2 + 1
        if self.param_reduction < 1:
            # freq_indices = freq_indices[:, torch.randperm(freq_indices.size(1), generator=torch.Generator().manual_seed(freq_indices.size(1)))] # 2, indices
            # freq_indices = freq_indices[:, :int(freq_indices.size(1) * self.param_reduction)] # 2, indices
            num_to_keep = int(freq_indices.size(1) * self.param_reduction)
            freq_indices = freq_indices[:, :num_to_keep] # keep the lowest-frequency indices
            weight_rfft = torch.stack([weight_rfft.real, weight_rfft.imag], dim=-1)
            weight_rfft = weight_rfft[freq_indices[0, :], freq_indices[1, :]]
            weight_rfft = weight_rfft.reshape(-1, 2)[None, ].repeat(self.param_ratio, 1, 1) / (min(self.out_channels, self.in_channels) // 2)
        else:
            weight_rfft = torch.stack([weight_rfft.real, weight_rfft.imag], dim=-1)[None, ].repeat(self.param_ratio, 1, 1, 1) / (min(self.out_channels, self.in_channels) // 2) #param_ratio, d1, d2, k*k, 2

        if convert_param:
            self.dft_weight = nn.Parameter(weight_rfft, requires_grad=True)
            del self.weight
        else:
            if self.linear_mode:
                assert self.kernel_size[0] == 1 and self.kernel_size[1] == 1
                self.weight = torch.nn.Parameter(self.weight.squeeze(), requires_grad=True)
        indices = []
        for i in range(self.param_ratio):
            indices.append(freq_indices.reshape(2, self.kernel_num, -1)) # paramratio, 2, kernel_num, d1 * k1 * (d2 * k2 // 2 + 1) // kernel_num
        self.register_buffer('indices', torch.stack(indices, dim=0), persistent=False)

    def get_FDW(self, ):
        d1, d2, k1, k2 = self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1]
        weight = self.weight.reshape(d1, d2, k1, k2).permute(0, 2, 1, 3).reshape(d1 * k1, d2 * k2)
        weight_rfft = torch.fft.rfft2(weight, dim=(0, 1)).contiguous() # d1 * k1, d2 * k2 // 2 + 1
        weight_rfft = torch.stack([weight_rfft.real, weight_rfft.imag], dim=-1)[None, ].repeat(self.param_ratio, 1, 1, 1) / (min(self.out_channels, self.in_channels) // 2) #param_ratio, d1, d2, k*k, 2
        return weight_rfft

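    # forward: compute KSM attentions from the pooled input, assemble a
    # per-sample spectral map from the grouped DFT weights, apply an inverse
    # rFFT to obtain spatial kernels, optionally filter x with FBM, then run
    # one grouped convolution over the whole batch.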
    def forward(self, x):
        if min(self.in_channels, self.out_channels) <= self.use_fdconv_if_c_gt or self.kernel_size[0] not in self.use_fdconv_if_k_in:
            return super().forward(x)
        global_x = F.adaptive_avg_pool2d(x, 1)
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.KSM_Global(global_x)
        if self.use_ksm_local:
            # global_x_std = torch.std(x, dim=(-1, -2), keepdim=True)
            hr_att_logit = self.KSM_Local(global_x) # b, kn, cin, cout * ratio, k1*k2, 
            hr_att_logit = hr_att_logit.reshape(x.size(0), 1, self.in_channels, self.out_channels, self.kernel_size[0], self.kernel_size[1])
            # hr_att_logit = hr_att_logit + self.hr_cin_bias[None, None, :, None, None, None] + self.hr_cout_bias[None, None, None, :, None, None] + self.hr_spatial_bias[None, None, None, None, :, :]
            hr_att_logit = hr_att_logit.permute(0, 1, 3, 2, 4, 5)
            if self.ksm_local_act == 'sigmoid':
                hr_att = hr_att_logit.sigmoid() * self.att_multi
            elif self.ksm_local_act == 'tanh':
                hr_att = 1 + hr_att_logit.tanh()
            else:
                raise NotImplementedError
        else:
            hr_att = 1
        b = x.size(0)
        batch_size, in_planes, height, width = x.size()
        DFT_map = torch.zeros((b, self.out_channels * self.kernel_size[0], self.in_channels * self.kernel_size[1] // 2 + 1, 2), device=x.device)
        kernel_attention = kernel_attention.reshape(b, self.param_ratio, self.kernel_num, -1)
        if hasattr(self, 'dft_weight'):
            dft_weight = self.dft_weight
        else:
            dft_weight = self.get_FDW()
            # print('get_FDW')

        # _t0 = time.perf_counter()
        for i in range(self.param_ratio):
            # print(i)
            # print(DFT_map.device)
            indices = self.indices[i]
            if self.param_reduction < 1:
                w = dft_weight[i].reshape(self.kernel_num, -1, 2)[None]
                DFT_map[:, indices[0, :, :], indices[1, :, :]] += torch.stack([w[..., 0] * kernel_attention[:, i], w[..., 1] * kernel_attention[:, i]], dim=-1)
            else:
                w = dft_weight[i][indices[0, :, :], indices[1, :, :]][None] * self.alpha # 1, kernel_num, -1, 2
                # print(w.shape)
                DFT_map[:, indices[0, :, :], indices[1, :, :]] += torch.stack([w[..., 0] * kernel_attention[:, i], w[..., 1] * kernel_attention[:, i]], dim=-1)
                pass
        # print(time.perf_counter() - _t0)
        adaptive_weights = torch.fft.irfft2(torch.view_as_complex(DFT_map), dim=(1, 2)).reshape(batch_size, 1, self.out_channels, self.kernel_size[0], self.in_channels, self.kernel_size[1])
        adaptive_weights = adaptive_weights.permute(0, 1, 2, 4, 3, 5)
        # print(spatial_attention, channel_attention, filter_attention)
        if hasattr(self, 'FBM'):
            x = self.FBM(x)
            # x = self.FBM(x, self.channel_comp(x))

        if self.out_channels * self.in_channels * self.kernel_size[0] * self.kernel_size[1] < (in_planes + self.out_channels) * height * width:
            # print(channel_attention.shape, filter_attention.shape, hr_att.shape)
            aggregate_weight = spatial_attention * channel_attention * filter_attention * adaptive_weights * hr_att
            # aggregate_weight = spatial_attention * channel_attention * adaptive_weights * hr_att
            aggregate_weight = torch.sum(aggregate_weight, dim=1)
            # print(aggregate_weight.abs().max())
            aggregate_weight = aggregate_weight.view(
                [-1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]])
            x = x.reshape(1, -1, height, width)
            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                            dilation=self.dilation, groups=self.groups * batch_size)
            if isinstance(filter_attention, float): 
                output = output.view(batch_size, self.out_channels, output.size(-2), output.size(-1))
            else:
                output = output.view(batch_size, self.out_channels, output.size(-2), output.size(-1)) # * filter_attention.reshape(b, -1, 1, 1)
        else:
            aggregate_weight = spatial_attention * adaptive_weights * hr_att
            aggregate_weight = torch.sum(aggregate_weight, dim=1)
            if not isinstance(channel_attention, float): 
                x = x * channel_attention.view(b, -1, 1, 1)
            aggregate_weight = aggregate_weight.view(
                [-1, self.in_channels // self.groups, self.kernel_size[0], self.kernel_size[1]])
            x = x.reshape(1, -1, height, width)
            output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                            dilation=self.dilation, groups=self.groups * batch_size)
            # if isinstance(filter_attention, torch.FloatTensor): 
            if isinstance(filter_attention, float): 
                output = output.view(batch_size, self.out_channels, output.size(-2), output.size(-1))
            else:
                output = output.view(batch_size, self.out_channels, output.size(-2), output.size(-1)) * filter_attention.view(b, -1, 1, 1)
        if self.bias is not None:
            output = output + self.bias.view(1, -1, 1, 1)
        return output

    def profile_module(self, input: Tensor, *args, **kwargs):
        # NOTE: profiling stub carried over from the original repository;
        # `hidden_size`, `hidden_size_factor` and `num_blocks` are not
        # attributes of FDConv, so this estimate only sketches the FFT cost
        # and must be adapted before use.
        b_sz, c, h, w = input.shape
        seq_len = h * w

        # FFT / iFFT cost: O(n log n) per channel.
        p_ff, m_ff = 0, 5 * b_sz * seq_len * int(math.log(seq_len)) * c
        # Other MACs (placeholder formula from the source module).
        params = macs = self.hidden_size * self.hidden_size_factor * self.hidden_size * 2 * 2 // self.num_blocks
        # The spectrum is roughly half-size after the rFFT.
        macs = macs * b_sz * seq_len

        return input, params, macs + m_ff
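As a quick smoke test of the class above (hypothetical snippet, assuming the repo's helper modules are importable; 64 input channels and kernel size 3 satisfy the use_fdconv_if_c_gt and use_fdconv_if_k_in gates, so the dynamic path is active):

# Hypothetical smoke test; channel sizes are illustrative only.
import torch

conv = FDConv(64, 128, kernel_size=3, padding=1)  # accepts the usual nn.Conv2d arguments
x = torch.randn(2, 64, 40, 40)
print(conv(x).shape)                              # expected: torch.Size([2, 128, 40, 40])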

# Experiments

## Training Script

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    # change this to your own model config path
    model = YOLO('./ultralytics/cfg/models/26/yolo26-C3k2_FDConv.yaml')
    # change this to your own dataset config path
    model.train(data='./ultralytics/cfg/datasets/coco8.yaml',
                cache=False,
                imgsz=640,
                epochs=10,
                single_cls=False,  # whether this is single-class detection
                batch=8,
                close_mosaic=10,
                workers=0,
                # optimizer='MuSGD',  
                optimizer='SGD',
                amp=False,
                project='runs/train',
                name='yolo26-C3k2_FDConv',
                )

## Results

(Figure: training results for the yolo26-C3k2_FDConv run.)

THE END