浙大 | FcaNet:频域通道注意力机制(keras实现)

FcaNet: Frequency Channel Attention Networks 



这篇论文,将GAP推广到一种更为一般的2维的离散余弦变换(DCT)形式,通过引入更多的frequency analysis重新考虑通道的注意力。




作者认为通道注意力中的GAP抑制了通道之间的多样性,而GAP本质上只是离散余弦变换(DCT)的最低频分量。若从频域角度分析,还可以利用其它有用的频率分量,从而形成多谱通道注意力(Multi-Spectral Channel Attention)。重要的是,文中设计了一种启发式的两步准则来选择多谱注意力模块的频率分量,其主要思想是先得到每个频率分量的重要性,再确定使用不同数目频率分量时的效果。具体而言,先分别计算通道注意力中单独采用各个频率分量的结果,然后根据结果筛选出top-k个性能最好的分量。



def fca(inputs, name, ratio=8):
    """Apply a frequency channel attention block to ``inputs``.

    The input is squeezed to one value per channel by
    ``MultiSpectralAttentionLayer``, passed through a two-layer Dense
    bottleneck (relu, then sigmoid) and the resulting per-channel gate is
    multiplied back onto ``inputs``.

    Args:
        inputs: tensor whose ``shape[1:]`` has exactly three static entries;
            the last one is used as the channel count.
            NOTE(review): presumably a channels-last (batch, H, W, C)
            feature map -- confirm against callers.
        name: string prefix for the names of the Dense, Reshape and Multiply
            layers created here.
        ratio: bottleneck divisor; the hidden Dense width is
            ``max(channels // ratio, ratio)``.

    Returns:
        ``inputs`` multiplied by the (1, 1, channels) attention gate.
    """
    dim_first, dim_second, channels = (int(size) for size in inputs.shape[1:])
    hidden_units = max(channels // ratio, ratio)

    # The two spatial sizes are handed over in swapped order (second, first),
    # exactly as the layer's constructor is called in the original code.
    squeezed = MultiSpectralAttentionLayer(channels, dim_second, dim_first)(inputs)

    gate = Dense(
        hidden_units,
        activation='relu',
        use_bias=False,
        name=name + '_Dense_1',
    )(squeezed)
    gate = Dense(
        channels,
        activation='sigmoid',
        use_bias=False,
        name=name + '_Dense_2',
    )(gate)
    # Reshape so the gate broadcasts over the two spatial axes.
    gate = Reshape((1, 1, channels), name=name + '_Reshape')(gate)

    return Multiply(name=name + '_Multiply')([inputs, gate])
import math
import tensorflow as tf
from keras.layers import Layer
import numpy as np

def get_freq_indices(method):
    """Return the (x, y) frequency index lists for a selection method.

    ``method`` is a three-letter family (``'top'``, ``'low'`` or ``'bot'``)
    followed by the number of frequencies to keep (1, 2, 4, 8, 16 or 32),
    e.g. ``'top16'``. Each family has a fixed table of 32 (x, y) index pairs
    with values in 0..6; the first ``N`` pairs of the table are returned.

    Args:
        method: name of the selection method, e.g. ``'top16'`` or ``'bot4'``.

    Returns:
        A tuple ``(mapper_x, mapper_y)`` of two equally long lists of ints.

    Raises:
        AssertionError: if ``method`` is not one of the supported names.
        NotImplementedError: if the family prefix is unknown (only reachable
            when assertions are disabled).
    """
    assert method in ['top1', 'top2', 'top4', 'top8', 'top16', 'top32',
                      'bot1', 'bot2', 'bot4', 'bot8', 'bot16', 'bot32',
                      'low1', 'low2', 'low4', 'low8', 'low16', 'low32']
    num_freq = int(method[3:])

    if 'top' in method:
        all_indices_x = [0, 0, 6, 0, 0, 1, 1, 4, 5, 1, 3, 0, 0, 0, 3, 2,
                         4, 6, 3, 5, 5, 2, 6, 5, 5, 3, 3, 4, 2, 2, 6, 1]
        all_indices_y = [0, 1, 0, 5, 2, 0, 2, 0, 0, 6, 0, 4, 6, 3, 5, 2,
                         6, 3, 3, 3, 5, 1, 1, 2, 4, 2, 1, 1, 3, 0, 5, 3]
    elif 'low' in method:
        all_indices_x = [0, 0, 1, 1, 0, 2, 2, 1, 2, 0, 3, 4, 0, 1, 3, 0,
                         1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4]
        all_indices_y = [0, 1, 0, 1, 2, 0, 1, 2, 2, 3, 0, 0, 4, 3, 1, 5,
                         4, 3, 2, 1, 0, 6, 5, 4, 3, 2, 1, 0, 6, 5, 4, 3]
    elif 'bot' in method:
        all_indices_x = [6, 1, 3, 3, 2, 4, 1, 2, 4, 4, 5, 1, 4, 6, 2, 5,
                         6, 1, 6, 2, 2, 4, 3, 3, 5, 5, 6, 2, 5, 5, 3, 6]
        all_indices_y = [6, 4, 4, 6, 6, 3, 1, 4, 4, 5, 6, 5, 2, 2, 5, 1,
                         4, 3, 5, 0, 3, 1, 1, 2, 4, 2, 1, 1, 5, 3, 3, 3]
    else:
        # Previously this raise sat inside the 'bot' branch, so every
        # 'botN' request failed after computing its result.
        raise NotImplementedError

    mapper_x = all_indices_x[:num_freq]
    mapper_y = all_indices_y[:num_freq]
    return mapper_x, mapper_y

class MultiSpectralAttentionLayer(Layer):
    """Squeeze a feature map to one value per channel using fixed DCT filters.

    The channels are split into ``num_freq`` equal groups; every group is
    weighted by one 2-D DCT basis function (selected through
    ``freq_sel_method``) and then summed over axes 1 and 2 of the input.

    Args:
        channel: number of channels; must be divisible by the number of
            selected frequencies.
        dct_h: first tile size used to build the DCT filter.
        dct_w: second tile size used to build the DCT filter.
        reduction: stored on the instance; not used by this layer itself.
        freq_sel_method: name understood by ``get_freq_indices``.
        **kwargs: forwarded to the base ``Layer`` constructor (e.g. ``name``).

    NOTE(review): the filter is built as (channel, dct_h, dct_w) and then
    transposed to (dct_w, dct_h, channel), so ``call`` presumably expects an
    input of shape (batch, dct_w, dct_h, channel) -- confirm against callers.
    """

    def __init__(self, channel, dct_h, dct_w, reduction=16, freq_sel_method='top16', **kwargs):
        super(MultiSpectralAttentionLayer, self).__init__(**kwargs)
        self.channel = channel
        self.reduction = reduction
        self.dct_h = dct_h
        self.dct_w = dct_w
        self.freq_sel_method = freq_sel_method

        mapper_x, mapper_y = get_freq_indices(freq_sel_method)
        self.num_split = len(mapper_x)
        # make the frequencies in different sizes are identical to a 7x7 frequency space
        # eg, (2,2) in 14x14 is identical to (1,1) in 7x7
        # (for tile sizes below 7 the scale factor is 0, so every index becomes 0)
        mapper_x = [temp_x * (dct_h // 7) for temp_x in mapper_x]
        mapper_y = [temp_y * (dct_w // 7) for temp_y in mapper_y]

        assert len(mapper_x) == len(mapper_y)
        assert channel % len(mapper_x) == 0

        self.num_freq = len(mapper_x)

        # fixed DCT init: (channel, dct_h, dct_w) -> (dct_w, dct_h, channel)
        self.weight = self.get_dct_filter(dct_h, dct_w, mapper_x, mapper_y, channel).transpose([2, 1, 0])

    def call(self, x):
        """Weight ``x`` by the fixed DCT filter and sum over axes 1 and 2."""
        x = x * self.weight
        result = tf.reduce_sum(x, axis=[1, 2])
        return result

    def compute_output_shape(self, input_shape):
        """Return (first dim, last dim) of ``input_shape``."""
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super(MultiSpectralAttentionLayer, self).get_config()
        config.update({
            'channel': self.channel,
            'dct_h': self.dct_h,
            'dct_w': self.dct_w,
            'reduction': self.reduction,
            'freq_sel_method': self.freq_sel_method,
        })
        return config

    def build_filter(self, pos, freq, POS):
        """Return one sample of the 1-D DCT basis function.

        Computes ``cos(pi * freq * (pos + 0.5) / POS) / sqrt(POS)`` and, for
        every non-zero frequency, scales it by ``sqrt(2)``.

        Args:
            pos: sample position.
            freq: frequency index.
            POS: number of positions (tile size).

        Returns:
            The basis value as a float.
        """
        result = math.cos(math.pi * freq * (pos + 0.5) / POS) / math.sqrt(POS)
        if freq == 0:
            return result
        # This return used to be nested under the ``if`` above, which made it
        # unreachable and left every non-zero frequency returning None.
        return result * math.sqrt(2)

    def get_dct_filter(self, tile_size_x, tile_size_y, mapper_x, mapper_y, channel):
        """Build the (channel, tile_size_x, tile_size_y) DCT filter bank.

        The channels are divided into ``len(mapper_x)`` consecutive groups of
        ``channel // len(mapper_x)`` channels; group ``i`` is filled with the
        2-D basis for frequency pair ``(mapper_x[i], mapper_y[i])``.

        Args:
            tile_size_x: size of the filter's second axis.
            tile_size_y: size of the filter's third axis.
            mapper_x: frequency indices along the x axis.
            mapper_y: frequency indices along the y axis.
            channel: total number of channels.

        Returns:
            A numpy array of shape (channel, tile_size_x, tile_size_y).
        """
        dct_filter = np.zeros((channel, tile_size_x, tile_size_y))
        c_part = channel // len(mapper_x)
        for i, (u_x, v_y) in enumerate(zip(mapper_x, mapper_y)):
            # The 2-D basis is separable, so build each 1-D basis once and
            # take the outer product instead of recomputing per pixel.
            basis_x = [self.build_filter(t_x, u_x, tile_size_x) for t_x in range(tile_size_x)]
            basis_y = [self.build_filter(t_y, v_y, tile_size_y) for t_y in range(tile_size_y)]
            dct_filter[i * c_part: (i + 1) * c_part] = np.outer(basis_x, basis_y)

        return dct_filter



