您的位置:首页 >聚焦 >

何恺明一作MAE收录CVPR2022 oral!(附源码实现)

2022-04-26 19:49:52    来源:程序员客栈

点蓝色字关注“机器学习算法工程师”

设为星标,干货直达!

近日,FAIR的最新论文Masked Autoencoders Are Scalable Vision Learners(何恺明一作)提出了一种更简单有效的用于ViT无监督训练的方法MAE,并在ImageNet-1K数据集上的top-1 acc达到新的SOTA:87.8%(无额外训练数据)。自从ViT火了之后,一些研究者就开始尝试研究ViT的无监督学习,比如Mocov3用对比学习的方法无监督训练ViT,此外也有一些研究开始借鉴BERT中的MLM(masked language modeling)方法,比如BEiT提出了用于图像的无监督学习方法:MIM(masked image modeling)。无疑,MAE方法也落在MIM的范畴,但整个论文会给人更震撼之感,因为MAE方法更简单有效。

NLP领域的BERT提出的预训练方法本质上也是一种masked autoencoding:去除数据的一部分然后学习恢复。这种masked autoencoding方法也很早就在图像领域应用,比如Stacked Denoising Autoencoders。但是NLP领域已经在BERT之后采用这种方法在无监督学习上取得非常大的进展,比如目前已经可以训练超过1000亿参数的大模型,但是图像领域却远远落后,而且目前主流的无监督训练还是对比学习。那么究竟是什么造成了masked autoencoding方法在NLP和CV上的差异呢?MAE论文从三个方面做了分析,这也是MAE方法的立意:

图像的主流模型是CNN,而NLP的主流模型是transformer,CNN和transformer的架构不同导致NLP的BERT很难直接迁移到CV。但是vision transformer的出现已经解决这个问题;图像和文本的信息密度不同,文本是高语义的人工创造的符号,而图像是一种自然信号,两者采用masked autoencoding建模任务难度就不一样,从句子中预测丢失的词本身就是一种复杂的语言理解任务,但是图像存在很大的信息冗余,一个丢失的图像块很容易利用周边的图像区域进行恢复;用于重建的decoder在图像和文本任务发挥的角色有区别,从句子中预测单词属于高语义任务,encoder和decoder的gap小,所以BERT的decoder部分微不足道(只需要一个MLP),而对图像重建像素属于低语义任务(相比图像分类),encoder需要发挥更大作用:将高语义的中间表征恢复成低语义的像素值。

基于这三个的分析,论文提出了一种用于图像领域(ViT模型)的更简单有效的无监督训练方法:MAE(masked autoencoder),随机mask掉部分patchs然后进行重建,其整体架构如下所示。MAE采用encoder-decoder结构(分析3,需要单独的decoder),但属于非对称结构,一方面decoder采用比encoder更轻量级设计,另外一方面encoder只处理一部分patchs(visible patchs,除了masked patchs之外的patchs),而decoder处理所有的patchs。一个很重要的点,MAE采用很高的masking ratio(比如75%甚至更高),这契合分析2,这样构建的学习任务大大降低了信息冗余,也使得encoder能学习到更高级的特征。由于encoder只处理visible patchs,所以很高的masking ratio可以大大降低计算量。

MAE采用的masking策略是简单的随机mask:基于均匀分布从图像的patchs随机抽样一部分patchs进行mask。每个被mask的patch采用mask token来替代,mask token是一个共享且可学习的向量。MAE的encoder采用ViT模型,只处理visible patchs,visible patchs通过linear projection得到patch embedding输入到ViT的transformer blocks进行处理;而decoder是一个轻量级模块,主体包含几个transformer blocks,而最后一层是一个linear层(输出是和一个patch像素数一致),用来直接预测masked patch的像素值。decoder的输入是所有的tokens:encoded visible patchs和mask tokens,它们要加上对应的positional embeddings。训练的loss采用简单的MSE:计算预测像素值和原始像素值的均方误差,不过loss只计算masked patchs。MAE的实现非常简单:首先对输入的patch进行linear projection得到patch embeddings,并加上positional embeddings(采用sine-cosine版本);然后对tokens列表进行random shuffle,根据masking ratio去掉列表中后面的一部分tokens,然后送入encoder中,这里注意ViT中需要一个class token来做图像分类,所以这里的输入也要增加一个dummy token(如果最后分类采用global avg pooling就不需要这个);encoder处理后,在tokens列表后面补足mask tokens,然后通过unshuffle来恢复tokens列表中tokens的原始位置,然后再加上positional embeddings(mask tokens本身并无位置信息,所以还要此操作)送入decoder中进行处理。具体的代码实现如下:

classMaskedAutoencoderViT(nn.Module):"""MaskedAutoencoderwithVisionTransformerbackbone"""def__init__(self,img_size=224,patch_size=16,in_chans=3,embed_dim=1024,depth=24,num_heads=16,decoder_embed_dim=512,decoder_depth=8,decoder_num_heads=16,mlp_ratio=4.,norm_layer=nn.LayerNorm,norm_pix_loss=False):super().__init__()#--------------------------------------------------------------------------#MAEencoderspecificsself.patch_embed=PatchEmbed(img_size,patch_size,in_chans,embed_dim)num_patches=self.patch_embed.num_patchesself.cls_token=nn.Parameter(torch.zeros(1,1,embed_dim))self.pos_embed=nn.Parameter(torch.zeros(1,num_patches+1,embed_dim),requires_grad=False)#fixedsin-cosembeddingself.blocks=nn.ModuleList([Block(embed_dim,num_heads,mlp_ratio,qkv_bias=True,qk_scale=None,norm_layer=norm_layer)foriinrange(depth)])self.norm=norm_layer(embed_dim)#--------------------------------------------------------------------------#--------------------------------------------------------------------------#MAEdecoderspecificsself.decoder_embed=nn.Linear(embed_dim,decoder_embed_dim,bias=True)self.mask_token=nn.Parameter(torch.zeros(1,1,decoder_embed_dim))self.decoder_pos_embed=nn.Parameter(torch.zeros(1,num_patches+1,decoder_embed_dim),requires_grad=False)#fixedsin-cosembeddingself.decoder_blocks=nn.ModuleList([Block(decoder_embed_dim,decoder_num_heads,mlp_ratio,qkv_bias=True,qk_scale=None,norm_layer=norm_layer)foriinrange(decoder_depth)])self.decoder_norm=norm_layer(decoder_embed_dim)self.decoder_pred=nn.Linear(decoder_embed_dim,patch_size**2*in_chans,bias=True)#encodertodecoder#--------------------------------------------------------------------------self.norm_pix_loss=norm_pix_lossself.initialize_weights()defpatchify(self,imgs):"""imgs:(N,3,H,W)x:(N,L,patch_size**2*3)"""p=self.patch_embed.patch_size[0]assertimgs.shape[2]==imgs.shape[3]andimgs.shape[2]%p==0h=w=imgs.shape[2]//px=imgs.reshape(shape=(imgs.shape[0],3,h,p,w,p))x=torch.einsum("nchpwq->nhwpqc",x)x=x.reshape(shape=(imgs.shape[0],h*w,p**2*3))returnxdefunpatchify(self,x):"""x:(N,L,patch_size**2*3)imgs:(N,3,H,W)"""p=self.patch_embed.patch_size[0]h=w=int(x.shape[1]**.5)asserth*w==x.shape[1]x=x.reshape(shape=(x.shape[0],h,w,p,p,3))x=torch.einsum("nhwpqc->nchpwq",x)imgs=x.reshape(shape=(x.shape[0],3,h*p,h*p))returnimgsdefrandom_masking(self,x,mask_ratio):"""Performper-samplerandommaskingbyper-sampleshuffling.Per-sampleshufflingisdonebyargsortrandomnoise.x:[N,L,D],sequence"""N,L,D=x.shape#batch,length,dimlen_keep=int(L*(1-mask_ratio))noise=torch.rand(N,L,device=x.device)#noisein[0,1]#sortnoiseforeachsampleids_shuffle=torch.argsort(noise,dim=1)#ascend:smalliskeep,largeisremoveids_restore=torch.argsort(ids_shuffle,dim=1)#keepthefirstsubsetids_keep=ids_shuffle[:,:len_keep]x_masked=torch.gather(x,dim=1,index=ids_keep.unsqueeze(-1).repeat(1,1,D))#generatethebinarymask:0iskeep,1isremovemask=torch.ones([N,L],device=x.device)mask[:,:len_keep]=0#unshuffletogetthebinarymaskmask=torch.gather(mask,dim=1,index=ids_restore)returnx_masked,mask,ids_restoredefforward_encoder(self,x,mask_ratio):#embedpatchesx=self.patch_embed(x)#addposembedw/oclstokenx=x+self.pos_embed[:,1:,:]#masking:length->length*mask_ratiox,mask,ids_restore=self.random_masking(x,mask_ratio)#appendclstokencls_token=self.cls_token+self.pos_embed[:,:1,:]cls_tokens=cls_token.expand(x.shape[0],-1,-1)x=torch.cat((cls_tokens,x),dim=1)#applyTransformerblocksforblkinself.blocks:x=blk(x)x=self.norm(x)returnx,mask,ids_restoredefforward_decoder(self,x,ids_restore):#embedtokensx=self.decoder_embed(x)#appendmasktokenstosequencemask_tokens=self.mask_token.repeat(x.shape[0],ids_restore.shape[1]+1-x.shape[1],1)x_=torch.cat([x[:,1:,:],mask_tokens],dim=1)#noclstokenx_=torch.gather(x_,dim=1,index=ids_restore.unsqueeze(-1).repeat(1,1,x.shape[2]))#unshufflex=torch.cat([x[:,:1,:],x_],dim=1)#appendclstoken#addposembedx=x+self.decoder_pos_embed#applyTransformerblocksforblkinself.decoder_blocks:x=blk(x)x=self.decoder_norm(x)#predictorprojectionx=self.decoder_pred(x)#removeclstokenx=x[:,1:,:]returnxdefforward_loss(self,imgs,pred,mask):"""imgs:[N,3,H,W]pred:[N,L,p*p*3]mask:[N,L],0iskeep,1isremove,"""target=self.patchify(imgs)ifself.norm_pix_loss:mean=target.mean(dim=-1,keepdim=True)var=target.var(dim=-1,keepdim=True)target=(target-mean)/(var+1.e-6)**.5loss=(pred-target)**2loss=loss.mean(dim=-1)#[N,L],meanlossperpatchloss=(loss*mask).sum()/mask.sum()#meanlossonremovedpatchesreturnlossdefforward(self,imgs,mask_ratio=0.75):latent,mask,ids_restore=self.forward_encoder(imgs,mask_ratio)pred=self.forward_decoder(latent,ids_restore)#[N,L,p*p*3]loss=self.forward_loss(imgs,pred,mask)returnloss,pred,mask

论文选择ViT-Large(ViT-L/16)作为encoder在ImageNet-1K上实验,首先进行无监督预训练,然后进行监督训练以评估encoder的表征能力,包括常用linear probing和finetune两个实验结果。下表是baseline MAE方法的实验结果,可以看到经过MAE预训练后finetune的效果要超过直接从头训练(84.9 vs 82.5):更重要的是,论文做了MAE各个部分的不同设置对比实验,这些实验能够揭示MAE更多的特性。首先是masking ratio,从下图可以看到,最优的设置是75%的masking ratio,此时linear probing和finetune效果最好,这比之前的研究要高很多,比如BEiT的masking ratio是40%。另外也可以看到linear probing和finetune的表现不一样,linear probing效果随着masking ratio的增加逐渐提高直至一个峰值后出现下降,而finetune效果在不同making ratio下差异小,masking ratio在40%~80%范围内均能表现较好。

这么高的masking ratio,模型到底能学习到什么?这里采用预训练好的模型在验证集进行重建,效果如下所示,可以看到decoder重建出来的图像还是比较让人惊艳的(95%的masking ratio竟然也能work!),这或许说明模型已经学习到比较好的特征。第二个是encoder的设计,这里主要探讨decoder的深度(transformer blocks数量)和宽度(channels数量)对效果的影响,实验结果如下表所示。首先,要想得到比较好的linear probing效果,就需要一个比较深的decoder,这不难理解,前面说过重建图像和图像识别两个任务的gap较大,如果decoder比较深,那么decoder就有足够的容量学习到重建能力,这样encoder可以更专注于提取特征。但是不同的深度对finetune效果影响较小,只用一个transformer block就可以work。相比之下,网络宽度对linear probing影响比网络深度要小一点。论文选择的默认设置是:8个blocks,width为512,一个token的FLOPs只有encoder的9%。第三个是mask token,这里探讨的是encoder是否处理mask tokens带来的影响,从对比实验来看,encoder不处理mask tokens不仅效果更好而且训练更高效,首先linear probing的效果差异非常大,如果encoder也处理mask tokens,此时linear probing的效果较差,这主要是训练和测试的不一致带来的,因为测试时都是正常的图像,但经过finetune后也能得到较好的效果。最重要的是,不处理mask tokens模型的FLOPs大大降低(3.3x),而且训练也能加速2.8倍,这里也可以看到采用较小的decoder可以进一步加速训练。

第四个是探讨不同的重建目标对效果的影响,从对比实验看,如果对像素值做归一化处理(用patch所有像素点的MAEn和std),效果有一定提升,采用PCA处理效果无提升。这里也实验了BEiT采用的dVAE tokenizer,此时训练loss是交叉熵,从效果上看比baseline有一定提升(finetune有提升,但是linear probing下降),但不如归一化处理的结果。注意的是dVAE tokenizer需要非常大的数据来单独训练,这是非常不方便的。第五个是数据增强的影响,这里让人惊奇的是MAE在无数据增强下(center crop)依然可以表现出好的效果,如果采用random crop(固定size或随机size)+random horizontal flipping(其实也属于轻量级)效果有微弱的提升,但加上color jit效果反而有所下降。相比之下,对比学习往往需要非常heavy的数据增强。这差异的背后主要是因为MAE采用的random mask patch已经起到了数据增强的效果。

第六个是mask sampling策略的影响,相比BEiT采用的block-wise或grid-wise方式,random sampling效果最好。

另外,论文也发现MAE和对比学习方法在training schedule上也存在差异,之前的实验都是基于800 epoch的训练时长,而实验发现训练到更长的epoch(1600 epoch+),模型的linear probing性能依然还在上升,而MoCoV3在300 epoch后就饱和了。不过,MAE在75%的masking ratio下每个epoch其实只相当于见了25%的数据,而对比学习往往学习two-crop和multi-crop,每个epoch见到的数据在200%以上,这也意味着MAE可以训练更多的epoch。虽然MAE训练更长,但是由于其特殊的设置,基于ViT-L的MAE训练1600 epoch的时长比MoCoV3训练300 epoch还要短(31h vs 36h)。

image.png

MAE与其它无监督方法的对比如下所示,可以看到在同样条件下MAE要比BEiT更好,而且也超过有监督训练,其中ViT-H在448大小finetune后在ImageNet上达到了87.8%的top1 acc。不过MAE的效果还是比谷歌采用JFT300M训练的ViT要差一些,这说明训练数据量可能是一个瓶颈。在linear probing方面,MAE要比其它的MIMI方法要好很多,前面已经说过,这主要归功于encoder不处理mask tokens。在鲁棒性方面,论文测试了几种ImageNet数据集的变种,从下表可以看到,相比直接有监督训练模型,基于MAE先预训练再finetune的模型鲁棒性更好。比如在ImageNet-A数据集上,基于MAE的ViT-H模型的top1-acc远高于有监督模型(68.2% vs 33.1%)。同时,论文也对比了MAE训练的encoder在下游任务(检测和分割)的迁移能力,同等条件下,MAE均能超过有监督训练或者其它无监督训练方法:这里要注意的一点是检测和分割模型需要多尺度的特征(即FPN),而ViT模型只输出一种尺度的特征(比如1/16大小特征),这里采用XCiT所提出的一种简单策略来产生多尺度特征,即对ViT的中间特征进行上采样和下采样。这里以Mask R-CNN模型为例,它需要提出backbone的1/4,1/8,1/16和1/32共4个level的特征,而ViT16只输出1/16的特征,这里将ViT的transformer blocks均分成4个部分,假定d为ViT的blocks数量,那么分别用位置为d/4,2d/4,3d/4和d的block的输出来提取特征,这里位置为d/4的block的输出需要上采样4x才能得到1/4大小的特征,可以通过两个stride=2的2x2反卷积操作来实现(第一个反卷积后接GN和GeLU),而位置为2d/4的block的输出只需要一个stride=2的2x2反卷积就能得到1/8大小的特征,对于位置为3d/4的block的输出则不需要任何操作,最后一个block的输出可以通过stride=2的2x2 max-pooling来产生1/32特征。(具体见论文Benchmarking Detection Transfer Learning with Vision Transformers)

论文最后还有一个额外的部分,那就是对linear probing评估方式的讨论。从前面的实验我们看到,虽然MAE训练的encoder在finetune下能取得比较SOTA的结果,但是其linear probing和finetune效果存在不小的差异,单从linear probing效果来看,MAE并不比MoCoV3要好(ViT-L:73.5 vs 77.6)。虽然linear probing一直是无监督训练的最常用的评估方法,但是它追求的是encoder提取特征的线性可分能力,这不并能成为唯一的一个评价指标,而且linear probing也不能很好地和下游任务迁移能力关联起来。所以论文额外做了partial fine-tuning的实验,这里可以看到如果仅对encoder的最后一个block进行finetune的话,MAE就能达到和MoCoV3一样的效果,如果finetune更多的blocks,MAE就会超过MoCoV3。这说明虽然MAE得到的特征线性可分能力差了点,但是它其实是更强的非线性特征。

最后谈一点自己对MAE的认识:首先MAE并不是第一个基于MIM方法做无监督训练,之前微软的BEiT基于MIM也取得了很好的效果,还有MST和iBOT等工作。但是MAE让人看起来更简单有效,比如BEiT需要单独训练的tokenizer,而其它的一些工作往往引入了对比学习的类似设计。对于MAE的成功,我觉得是一些突破常规的设计,比如很高的masking ratio,这是很难想象会work的,但MAE却证明了这是成功的关键。

参考Mocov3: An Empirical Study of Training Self-Supervised Vision TransformersDINO: Emerging Properties in Self-Supervised Vision TransformersMST: Masked Self-Supervised Transformer for Visual RepresentationBEiT: BERT Pre-Training of Image TransformersEsViT: Efficient Self-supervised Vision Transformers for Representation LearningImage BERT Pre-training with Online TokenizerMasked Autoencoders Are Scalable Vision Learnershttps://github.com/facebookresearch/mae

关键词: 无监督学习 实验结果 也可以看到

相关阅读