飞道的博客

【零基础讲论文源码】CVT:Introducing Convolutions to Vision Transformers

541人阅读  评论(0)

目前这个系列会开两个方向, cv transformer 和OCR方向。

Transformer方向

OCR方向

  • DBnet解读【链接】(正在制作中。。。)
  • PP_OCR【链接】(待续)
  • SRN【链接】
  • read like human【链接】

整体介绍:

CvT: Introducing Convolutions to Vision Transformers,刚发不久的一篇文章,最近Transformer很多,之所有现在选这一篇是因为方法简洁高效,性能在现在大神云集的Transformer算法里非常有竞争力。另感觉swin-trans源码非常碎,操作繁琐,所以更喜欢这一篇。

Cvt论文原文【链接】
Cvt 解读代码【链接】(论文里的代码链接无效,找了个star很多的)

整体流程图:


整体改进非常简单,

  • 通过卷积7*7获得conv embedding。
  • 通过深度卷积进行conv proj,即将特征转化成query ,value,key向量。这种转化方式可以见下

CVT配置


可以对比途中的CVT13配置信息来看给CVT初始化的各个参数。整体结构应该分为三个阶段。刚看的可以跳过这一些配置。

num_classes,
            s1_emb_dim = 64,
            s1_emb_kernel = 7,
            s1_emb_stride = 4,
            s1_proj_kernel = 3,
            s1_kv_proj_stride = 2,
            s1_heads = 1,
            s1_depth = 1,
            s1_mlp_mult = 4,
            s2_emb_dim = 192,
            s2_emb_kernel = 3,
            s2_emb_stride = 2,
            s2_proj_kernel = 3,
            s2_kv_proj_stride = 2,
            s2_heads = 3,
            s2_depth = 2,
            s2_mlp_mult = 4,
            s3_emb_dim = 384,
            s3_emb_kernel = 3,
            s3_emb_stride = 2,
            s3_proj_kernel = 3,
            s3_kv_proj_stride = 2,
            s3_heads = 6,
            s3_depth = 10,
            s3_mlp_mult = 4,
            dropout = 0.

CVT代码主函数

class CvT(nn.Module):
    def __init__(
            #见上代码
    ):
        super().__init__()
        kwargs = dict(locals())

        for prefix in ('s1', 's2', 's3'):
            config, kwargs = group_by_key_prefix_and_remove_prefix(f'{prefix}_', kwargs)
            #主要用于处理命令参数和配置参数

            layers.append(nn.Sequential(
                nn.Conv2d(dim, config['emb_dim'], kernel_size = config['emb_kernel'], padding = (config['emb_kernel'] // 2), stride = config['emb_stride']),
                Transformer()
            ))

        self.layers = nn.Sequential(
            *layers,
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...'),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        #b c h w
        return self.layers(x)

可以看到整体的结构为
stage1(conv2d-> transformer)-> stage2(conv2d-> transformer)->
stage3(conv2d-> transformer)->AdaptiveAvgPool2d -> Linear,

  • conv2d并不是深度卷积,可以理解为一个patch embedding操作,在大部分cv transformer中均会进行。这里是每个阶段会使用一次Patch embedding,来增加通道和降低分辨率。
  • transformer函数是整体的Transformer结构
  • 池化和全连接用于进行最后的预测。

Transformer函数

class Transformer(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, depth, heads, dim_head = 64, mlp_mult = 4, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, proj_kernel = proj_kernel, kv_proj_stride = kv_proj_stride, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout))
            ]))
    def forward(self, x):
    	# x: b c h w
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

Transformer函数是一个典型的transformer结构:
x->attention(x)+残差-> MLP(全连接)+残差

Attention

这是CVT方法的核心代码部分:

class Attention(nn.Module):
    def __init__(self, dim, proj_kernel, kv_proj_stride, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        padding = proj_kernel // 2
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)

        self.to_q = DepthWiseConv2d(dim, inner_dim, proj_kernel, padding = padding, stride = 1, bias = False)
        self.to_kv = DepthWiseConv2d(dim, inner_dim * 2, proj_kernel, padding = padding, stride = kv_proj_stride, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        shape = x.shape
        
        # b=b, n=c, _=h, y=w 
        b, n, _, y, h = *shape, self.heads
        
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
        #chunk可以把元素切开,to_kv(x) -> b in_dim*2 h w -> b in_dim h w
        
        # b in_dim h w -> (b head) (w*h) in_dim/head 
        q, k, v = map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))
        
        # self-attention op
        dots = einsum('b i d, b j d -> b i j', q, k) * self.scale
        attn = self.attend(dots)
        out = einsum('b i j, b j d -> b i d', attn, v)
        
        # reshape b c h w
        out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, y = y)
        # to_out downsample
        return self.to_out(out)

这里用空洞卷积 DepthWiseConv2d作为Projection for Attention给q,k,v。kv均为步长2,q为步长1。
其中map(lambda t: rearrange(t, 'b (h d) x y -> (b h) (x y) d', h = h), (q, k, v))

  • map:第一个参数 function 以参数序列中的每一个元素调用 function 函数,返回包含每次 function 函数返回值的新列表
  • lambda :匿名函数
    所以代码的意思是将q,k,v 分别代入函数 lambda中作为t,执行'b (h d) x y -> (b h) (x y) d'

整个attention操作为:原始的x为4维特征,通过空洞卷积进行映射,变成3维特征等,执行self-attention操作,之后reshape回4维特征,通过卷积(普通卷积to_out)将维度升回去。


卷积位置: 看上面这张图,从特征中进行卷积(b)位置(conv2d),然后接空洞卷积3个(图中表示为3个),最后出去的时候仍会接卷积(to_out).

前置知识:深度卷积

在这个算法中,提到了一个卷积方式, DepthWiseConv2d:深度卷积(懂的各位可以直接跳过)
优势:减少参数数量

正常卷积核是对3个通道同时做卷积。也就是说,3个通道,在一次卷积后,输出一个数。深度可分离卷积分为两步:

  • 第一步用三个卷积对三个通道分别做卷积,这样在一次卷积后,输出3个数。
  • 这输出的三个数,再通过一个1x1xout_dim的卷积核(pointwise核),得到一个数。

对每个通道做一次卷积并且不相加,最后在用一个卷积处理成想要的维度

CVT中的实现

class DepthWiseConv2d(nn.Module):
    def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True):
        super().__init__()
        self.net = nn.Sequential(
            # 对每个通道做分组卷积,groups=dim_in
            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
            nn.BatchNorm2d(dim_in),
            # 实现点乘操作,即所有通道融合并升维
            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
        )
    def forward(self, x):
        return self.net(x)

深度卷积详细介绍可看【链接】


转载:https://blog.csdn.net/qq_35307005/article/details/116204958
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场