CSWin Transfomer:超越Swin Transformer的网络来了
点蓝色字关注“机器学习算法工程师”
设为星标,干货直达!
近期,微软亚研院继Swin Transformer之后又推出了CSWin Transformer。和Swin Transformer一样,CSWin Transformer也是一种local self-attention网络,相比Swin的方形window self-attention,CSWin采用的是十字形(cross-shaped)window self-attention,这使得CSWin Transformer的建模能力更强,在分类和检测等任务上也超过Swin Transformer,其中CSWin-L在语义分割数据集ADE20K上达到了SOTA:55.7 mIoU,超过Swin-L的53.5(不过目前微软提出的无监督训练模型BEiT-L已经再次刷新了榜单:57.0 mIoU)。CSWin Transformer和Swin Transformer一样采用金字塔结构,共包括4个stage,各个stage的特征图大小分别是原图的1/4,1/8,1/16和1/32。CSWin Transformer主要有三个重要的改进:Overlapping Patch Embedding,Cross-Shaped Window Self-Attention和Locally-Enhanced Positional Encoding。
Overlapping Patch EmbeddingPVT和Swin Transformer等较早的金字塔模型中patch embedding是没有overlap的,patch size为的patch embedding操作上等价于stride和kernel size均为
#patchembeddingstage1_conv_embed=nn.Sequential(nn.Conv2d(in_chans,embed_dim,7,4,2),Rearrange("bchw->b(hw)c",h=img_size//4,w=img_size//4),nn.LayerNorm(embed_dim))#patchmergingclassMerge_Block(nn.Module):def__init__(self,dim,dim_out,norm_layer=nn.LayerNorm):super().__init__()self.conv=nn.Conv2d(dim,dim_out,3,2,1)self.norm=norm_layer(dim_out)defforward(self,x):B,new_HW,C=x.shapeH=W=int(np.sqrt(new_HW))x=x.transpose(-2,-1).contiguous().view(B,C,H,W)x=self.conv(x)B,C=x.shape[:2]x=x.view(B,C,-1).transpose(-2,-1).contiguous()x=self.norm(x)returnx
注意,这里的卷积都需要包含zero padding来保持和原来一样的输出大小。
Cross-Shaped Window Self-AttentionCSWin Transformer最核心的部分就是cross-shaped window self-attention,如下所示,首先将self-attention的mutil-heads均分成两组,一组做horizontal stripes self-attention,另外一组做vertical stripes self-attention。所谓horizontal stripes self-attention就是沿着H维度将tokens分成水平条状windows,对于输入为HxW的tokens,记每个水平条状window的宽度为
#对于水平attention,H_sp=sw,W_sp=W#对于竖直attention,H_sp=H,W_sp=swdefimg2windows(img,H_sp,W_sp):"""img:BCHW"""B,C,H,W=img.shapeimg_reshape=img.view(B,C,H//H_sp,H_sp,W//W_sp,W_sp)img_perm=img_reshape.permute(0,2,4,3,5,1).contiguous().reshape(-1,H_sp*W_sp,C)returnimg_permdefwindows2img(img_splits_hw,H_sp,W_sp,H,W):"""img_splits_hw:B"HWC"""B=int(img_splits_hw.shape[0]/(H*W/H_sp/W_sp))img=img_splits_hw.view(B,H//H_sp,W//W_sp,H_sp,W_sp,-1)img=img.permute(0,1,3,2,4,5).contiguous().view(B,H,W,-1)returnimg
两组self-attention是并行的,完成后将tokens的特征concat在一起,这样就构成了CSW self-attention,最终效果就是在十字形窗口内做attention,CSW self-attention的感受野要比常规的window attention的感受野更大。用公式表示的话就是:
相关阅读
-
世界热推荐:今晚7:00直播丨下一个突破...
今晚19:00,Cocos视频号直播马上点击【预约】啦↓↓↓在运营了三年... -
NFT周刊|Magic Eden宣布支持Polygon网...
Block-986在NFT这样的市场,每周都会有相当多项目起起伏伏。在过去... -
环球今亮点!头条观察 | DeFi的兴衰与...
在比特币得到机构关注之后,许多财务专家预测世界将因为加密货币的... -
重新审视合作,体育Crypto的可靠关系才能双赢
Block-987即使在体育Crypto领域,人们的目光仍然集中在FTX上。随着... -
简讯:前端单元测试,更进一步
前端测试@2022如果从2014年Jest的第一个版本发布开始计算,前端开发... -
焦点热讯:刘强东这波操作秀
近日,刘强东发布京东全员信,信中提到:自2023年1月1日起,逐步为...