【论文阅读】CrossFormer: A Versatile Vision Transformer Based on Cross-scale Attention

发布时间 2023-07-06 13:12:31作者: ninisong

来自CVPR 2021

论文地址:https://link.zhihu.com/?target=https%3A//arxiv.org/pdf/2108.00154.pdf

代码地址:https://link.zhihu.com/?target=https%3A//github.com/cheerss/CrossFormer

一、Motivation

 主要还是ViT的历史遗留问题

ViT在处理输入时,将图片划分为了相等大小的图像块(Patch),然后通过linear操作生成token序列,这种操作导致ViT各层的输入嵌入是等尺度的,没有跨尺度特征,缺少不同尺度的交互能力,而这种能力对于视觉感知非常重要。一幅图像通常包含许多不同尺度的对象,建立它们之间的关系需要跨尺度的注意机制。此外,一些任务,如实例分割,需要大规模(粗粒度)特征和小规模(细粒度)特征之间的交互。

现有的vision transformers无法处理这些情况的原因有两个:

(1)嵌入序列是由大小相等的块生成的,因此同一层中的嵌入只具有单一尺度的特征。虽然理论上只要这些patch的大小足够大,可以有机会提取任何尺度的特征,但是patch大小变大后,输出的特征图的分辨率就会变低,难以学习到高分辨率的表征,这对密集预测类的任务很重要;

(2)在Self-Attention模块内部,部分方法放弃了K和V的部分表达,相邻嵌入的键/值经常被合并,以降低成本。因此,即使嵌入同时具有小尺度和大尺度特征,合并操作也会丢失每个单个嵌入的小尺度(细粒度)特征,从而使跨尺度注意力失效。例如,Swin-Transformer将self-attention操作的范围限制在每个window内,这一定程度上放弃了全局尺度的长距离关系。

二、Contribution

这篇文章主要是解决以往架构在建立跨尺度注意力方面的问题,从两个角度提出了跨尺度嵌入层(CEL)和长短距离注意力机制(LDA和SDA)来弥补这方面的空白。

最后附赠一个动态位置偏差(DPB),这与跨尺度无关,通过将偏差添加到注意力机制中表示嵌入的相对位置,动态得到位置偏差可以使相对位置偏差更加灵活。

三、CrossFormer

和PVT、Swin-Transformer一样,CrossFormer也采用了金字塔式的层级的结构,这样的好处是可以迁移到dense任务上去,做检测,分割等。模型分为四个阶段。每个阶段由一个跨尺度嵌入层(CEL)和几个CrossFormer block组成。CEL接收上一阶段的输出(或图像)作为输入,并生成跨尺度嵌入。在这个过程中,CEL(第一阶段除外)将金字塔结构的嵌入次数减少到四分之一,而将其维数增加了一倍。然后,在CEL之后放置几个CrossFormer块(包含LSDA和DPB)。在特定任务的最后阶段之后,紧随其后的是专门的head函数做分类。

1. Cross-scale Embedding Layer (CEL)

CrossFormer是层级的结构,既然是层级的结构那就一定会包含着一定尺度的下采样,这一点PVT和Swin中都有提到,CEL采用不同大小的卷积核(4×4,8×8)对图片做卷积,得到卷积后的结果,直接concat作为Patch Embedding。通过这种方式,强迫一些维度(例如4×4的卷积得到的部分)关注更细粒度、更小尺度的信息,而其他的维度(例如8×8的卷积得到的部分)有机会学习到更大尺度的信息。通过不同大小的卷积核获得不同尺度的信息,对变化尺度的物体是比较友好的。

从图中可以看出,CEL接受一幅影像作为输入,然后使用四个不同大小的卷积核进行采样,四个卷积核的步长大小相同,保证他们生成的token数目相同,四个对应的Patch的中心是相同的,但是尺度不同,将得到的四个尺度的Patch进行concat,就完成了patch embedding。

对于跨尺度的嵌入,有一个问题要注意:怎么设置每个尺度的嵌入维度?

先在这里推理一下卷积的计算量(FLOPs )和参数量(Parametaers):

FLOPs:

卷积层计算量 = 卷积矩阵操作 + 融合操作 + bias操作(注意:其中矩阵操作包括:先乘法,再加法)

假设我们输入一张7*7*3的图像,使用大小为5*5*3的84个卷积核,进行stride=1,padding=0的卷积操作,输出3*3*64的feature map

输出图像大小=【(7-5+2*0)/1】+1=3

对于单个像素点来说:需要进行(5*5*3)次卷积矩阵乘法操作,需要进行3*(5*5-1)次卷积矩阵加法操作,(3-1)次通道融合操作,归纳偏置项为1,FLOPs=所有项相加=150

对于一整张feature map有3*3*64个这样的像素点:FLOPs=150*3*3*64=86400

Parameters:
卷积层的参数量只与卷积核有关,Parameters=卷积核计算量+bias

单个卷积核的参数量为5*5*3=75,bias=1,单个卷积核的参数量=75+1=76,一次卷积的Parametars=76*64=4800

从上方推导可以看出,卷积层的计算预算与卷积核大小核输入输出维度有关,假设将每个尺度都设置成相同的维度,那么大的卷积核的计算成本将远大于小的卷积核,为了控制计算成本,所以将较小卷积核设置为较大的维度。

核心代码如下:

  • 初始化几个不同kernel,不同padding,相同stride的conv
  • 对输入进行卷积操作后得到的feature,做concat

 

 

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=[4], in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        ...

        self.projs = nn.ModuleList()
        for i, ps in enumerate(patch_size):
            if i == len(patch_size) - 1:
                dim = embed_dim // 2 ** i
            else:
                dim = embed_dim // 2 ** (i + 1)
            stride = patch_size[0]
            padding = (ps - patch_size[0]) // 2
            self.projs.append(nn.Conv2d(in_chans, dim, kernel_size=ps, stride=stride, padding=padding))
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        xs = []
        for i in range(len(self.projs)):
            tx = self.projs[i](x).flatten(2).transpose(1, 2)
            xs.append(tx)  # B Ph*Pw C
        x = torch.cat(xs, dim=2)
        if self.norm is not None:
            x = self.norm(x)
        return x

 

2.CrossFormer Block

每个CrossFormerBlock都由短距离注意(SDA)或长距离注意(LDA)模块和多层感知器(MLP)组成。特别是,SDA和LDA交替出现在不同的块中,并将 DPB 模块嵌入其中,以获得位置信息。在块中使用残差连接。

2.1 Long Short Distance Attention (LSDA)

将自注意力的计算分为两各部分,短距离注意力机制(SDA)和长距离注意力机制(LDA)。

SDA:将相邻的(G*G)个相邻的嵌入作为一个计算单元-组,这一步考虑的是局部内的信息;

LDA:将输入(S*S)进行间隔为I的采样,采样到位置的嵌入作为一个计算单元。每个组的大小为S/I。这一步则是考虑的是全局的信息。

对嵌入进行分组后,在每个组内进行self-attention的计算。

可以看出SDA和LDA都是基于window的注意力,那么是怎么保证全局信息获取的呢?

LDA中对feature map做了步长为I的采样,获得多个全局的嵌入,最大程度的利用feature map的全局性。类似于空洞卷积,将相隔I的像素作为一组,进行attention 操作,这样就引入了长距离的信息。

 

 核心代码如下:

  • 实现很简单,只需要两个reshape
class LSDA():
    def forward(x, type):
        # group the embeddings
        if type == "SDA":
            x = x.reshaspe(H // G, G, W // G, G, D).permute(0, 2, 1, 3, 4)
        elif type == "LDA":
            x = x.reshaspe(G, H // G, G, W // G, D).permute(1, 3, 0, 2, 4)
        x = x.reshape(H * W // (G ** 2), G ** 2, D)
        # the vanilla self-attention module
        x = Attention(x)
        # un-group the embeddings
        x = x.reshaspe(H // G, W // G, G, G, D)
        if type == "SDA":
            x = x.permute(0, 2, 1, 3, 4).reshaspe(H, W, D)
        elif type == "LDA":
            x = x.permute(2, 0, 3, 1, 4).reshaspe(H, W, D)
        return x

2.2 Dynamic Position Bias   

随着位置编码技术的不断发展,相对位置编码偏差逐渐的应用到了transformers中,很多的vision transformers均采用RPB来替换原始的APE,好处是可以直接插入到我们的attention中,不需要很繁琐的公式计算,并且可学习性高,鲁棒性强,公式如下:

 

以Swin-Transformer为例,位置偏差矩阵B是一个固定大小的矩阵,使用第i、j两个嵌入之间的坐标距离来表达相对位置,这里有一个问题,就是我么输入的图像的大小不能超过一定的范围,否则位置编码就失去作用。举个例子,加入我们设置的window大小为7*7,相对位置范围是【-6,6】,但是如果我们把window的大小扩大,扩大为9*9,相对位置的范围不变,但是失去了对外层数据的访问,这样位置编码就失去了作用。

有一些常用的方法,例如插值,但是插值产生的信息在没有微调的情况下,会降低性能。因为插值是通过原始的位置信息来模拟出来信息,实际上还是原始的信息,没有信息收益。

所以作者在这里提出了一种可学习的动态位置偏差,DBP的结构是基于MLP的,由3个Linear层+LayerNorm+ReLU组成的block堆叠而成,最后接一个输出为1的线性层做bias的表征。DPB的思想则是,我们不希望通过用实际的相对位置来做embeeidng,而是希望通过隐空间先对位置偏差进行学习。输入是(N,2),由于self-attention是由多个head组成的,所以输出为(N,1×head),核心代码如下:

1.首先,构建一个相对位置矩阵,假设窗口大小为7,那么输入就是大小就是【(2*7-1)*(2*7-1),2】=(169,2)

self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
 
# generate mother-set
position_bias_h = torch.arange(1 - self.group_size[0], self.group_size[0])
position_bias_w = torch.arange(1 - self.group_size[1], self.group_size[1])
biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))  # 2, 2Wh-1, 2Wh-1
biases = biases.flatten(1).transpose(0, 1).float()
self.register_buffer("biases", biases)

2. 构建索引矩阵, 得到了一个49×49的一个索引,从右上角为0开始,向左和向下递增。

coords_h = torch.arange(self.group_size[0])
coords_w = torch.arange(self.group_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.group_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.group_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.group_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

DPB是一个与整个模型一起优化的可培训模块。它可以处理任何图像/组大小。

三、变体

  变体很多,还有Cross Former++,加入了渐进式的组大小和冷却层。

四、Rethinking

这种局部和全局信息的串联结构并不新颖了,和swin的想法有点类似,还待深入。主要解决的是跨尺度的问题,非常值得借鉴。