您的位置:首页 >聚焦 > AugMix:我比MixUp要强! 2022-03-03 07:33:31 来源:程序员客栈 点蓝色字关注“机器学习算法工程师”设为星标,干货直达!AugMix是DeepMind和谷歌在2019年提出的一种数据增强方法,它可以提升图像分类模型的鲁棒性(robustness)和不确定性估计(uncertainty estimates),而且非常容易嵌入目前的训练流程中。AugMix的原理很简单:随机对图像进行不同的数据增强(Aug),然后混合(Mix)多个数据增强后的图像;同时在分类器上施加对同一图像的不同增强后的一致性约束。本文将简单介绍AugMix的原理以及具体的代码实现。数据偏移(data shift)即训练和测试数据分布不一致是深度学习模型面临的挑战之一,因为实际中训练数据往往只是有限的子集,模型在未见过的数据上往往表现较差。常见的提高模型的鲁棒性和泛化性的方法就是通过数据增强产生更丰富的训练样本,如上图所示的CutOut,MixUp和CutMix,这里要介绍的AugMix也属于一种数据增强,其算法的伪代码如下所示:对于AugMix,首先需要定义一系列的数据增强操作,这里采用AutoAugment中的数据增强(autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y,这里去除了contrast, color, brightness, sharpness,因为它们可能和ImageNet-C测试集的变换重合),每种数据增强可以采用不同的增强幅度(即论文中的所说的severity)。下图为9种数据增强在同一张图像的应用效果:AugMix随机选择不同的数据增强,然后混合不同增强后的图像,具体的步骤如下:用zero初始化一个和原始图像一样大小的图像;并根据Dirichlet分布随机生成个权重,用来混合不同的图像;从定义的数据增强中随机选择3个:,用这3个op组合出不同深度的操作,其中为,而,这样为深度分别是1,2,3的增强组合链,随机选择其中一个组合链来增强图像,并混合增强后的图像:;多步骤2重复次,这相当于混合了个不同增强后的图像;根据beta分布随机生成权重,然后将上述得到的图像和原始图像进行混合:。这个过程涉及到两个参数:和。其中为要混合的增强图像次数,默认为3;而为分布的控制参数,默认为1。下图为AugMix的一个具体应用实例:这里选择的三个增强组合链分别是translate_x+shear_y, rotate, posterize+equalize+posterize,它们的深度分别是2,1,3,混合权重分别为0.12,0.2和0.68;而最后混合原图的权重为0.2。可以看出AugMix其实是混合同一个图像的经过不同数据增强得到的图像,而CutMix和MixUp是混合两个不同的图像。AugMix由于是混合同一个图像,相比CutMix和MixUp,其得到的图像更自然一些;如果只是级联不同的数据增强,最后得到的图像往往偏离原图,变得失真(如下图所示),而AugMix通过混合不同增强的图像可以减轻这个问题,而且不损失图像多样性。除了对单纯的数据增强,论文还提出了额外增加一个JS散度一致性损失(Jensen-Shannon Divergence Consistency Loss)来进一步提升模型的稳定性。具体地,对图像做两次AugMix,得到两个不同的增强图像和,然后最小化它们和原始图像的概率分布(模型预测的分类概率)的的JS散度,这样分类模型的训练损失变为:其中第一项是原始的分类损失,即交叉熵;第二项为一致性约束损失,为损失权重。关于JS散度,可以参考这里,JS散度是基于KL散度的,不过JS散度是对称的,而且存在上限。要计算上述三个分布的JS散度,首先要计算三个分布的平均值:,那么JS散度为:这里我们是取两个AugMix来计算一致性约束,论文发现直接做一次即计算效果并不好;而取三次AugMix没有更大的收益,即计算。一致性约束在半监督学习中常见的方法:对一张未标注的图像做两次不同的数据增强,然后施加一致性约束,使得模型的输出尽量一致。对于模型的鲁棒性评估,可以通过评估模型在其它分布的测试集上的分类误差,下图为AugMix与其它方法在CIFAR-10-C测试集上(通过对标准测试集图像做15种corruptions)的分类误差对比,可以看到,AugMix相比其它方法可取得更小的分类误差,也更接近标准测试集上结果(clean error)。对于模型的不确定性评估,可以计算模型的校准误差(Calibration Error),一个理想的分类器其准确度应该和置信度(预测概率)是一致的,比如模型预测的置信度为0.7的样本,它们的准确度应该为70%,这说明模型能够输出比较可靠的置信度,但实际的分类器往往出现overconfident或者underconfident,所以需要校准。对于校准误差,一般计算RMS Calibration Error,即模型在不同的置信度下实际准确度和预期准确度的平方根误差,误差越小,说明模型预测越可靠。下图为AugMix与其它方法在CIFAR-10-C测试集上的校准误差对比,可以看到AugMix有更小的校准误差。目前谷歌已经开源了AugMix的代码,而且torchvision最近也增加对AugMix的支持,这里给出torchvision的实现:classAugMix(torch.nn.Module):r"""AugMixdataaugmentationmethodbasedon`"AugMix:ASimpleDataProcessingMethodtoImproveRobustnessandUncertainty"`_.IftheimageistorchTensor,itshouldbeoftypetorch.uint8,anditisexpectedtohave[...,1or3,H,W]shape,where...meansanarbitrarynumberofleadingdimensions.IfimgisPILImage,itisexpectedtobeinmode"L"or"RGB".Args:severity(int):Theseverityofbaseaugmentationoperators.Defaultis``3``.mixture_width(int):Thenumberofaugmentationchains.Defaultis``3``.chain_depth(int):Thedepthofaugmentationchains.Anegativevaluedenotesstochasticdepthsampledfromtheinterval[1,3].Defaultis``-1``.alpha(float):Thehyperparameterfortheprobabilitydistributions.Defaultis``1.0``.all_ops(bool):Usealloperations(includingbrightness,contrast,colorandsharpness).Defaultis``True``.interpolation(InterpolationMode):Desiredinterpolationenumdefinedby:class:`torchvision.transforms.InterpolationMode`.Defaultis``InterpolationMode.NEAREST``.IfinputisTensor,only``InterpolationMode.NEAREST``,``InterpolationMode.BILINEAR``aresupported.fill(sequenceornumber,optional):Pixelfillvaluefortheareaoutsidethetransformedimage.Ifgivenanumber,thevalueisusedforallbandsrespectively."""def__init__(self,severity:int=3,#数据增强的最大幅度mixture_width:int=3,#k,即混合的增强图像数量,默认是3chain_depth:int=-1,#增强组合链的深度,设置为-1表示从[1,3]随机选择alpha:float=1.0,#分布的参数all_ops:bool=True,interpolation:InterpolationMode=InterpolationMode.BILINEAR,fill:Optional[List[float]]=None,)->None:super().__init__()self._PARAMETER_MAX=10ifnot(1<=severity<=self._PARAMETER_MAX):raiseValueError(f"Theseveritymustbebetween[1,{self._PARAMETER_MAX}].Got{severity}instead.")self.severity=severityself.mixture_width=mixture_widthself.chain_depth=chain_depthself.alpha=alphaself.all_ops=all_opsself.interpolation=interpolationself.fill=fill#数据增强空间def_augmentation_space(self,num_bins:int,image_size:List[int])->Dict[str,Tuple[Tensor,bool]]:s={#op_name:(magnitudes,signed)"ShearX":(torch.linspace(0.0,0.3,num_bins),True),"ShearY":(torch.linspace(0.0,0.3,num_bins),True),"TranslateX":(torch.linspace(0.0,image_size[0]/3.0,num_bins),True),"TranslateY":(torch.linspace(0.0,image_size[1]/3.0,num_bins),True),"Rotate":(torch.linspace(0.0,30.0,num_bins),True),"Posterize":(4-(torch.arange(num_bins)/((num_bins-1)/4)).round().int(),False),"Solarize":(torch.linspace(255.0,0.0,num_bins),False),"AutoContrast":(torch.tensor(0.0),False),"Equalize":(torch.tensor(0.0),False),}ifself.all_ops:s.update({"Brightness":(torch.linspace(0.0,0.9,num_bins),True),"Color":(torch.linspace(0.0,0.9,num_bins),True),"Contrast":(torch.linspace(0.0,0.9,num_bins),True),"Sharpness":(torch.linspace(0.0,0.9,num_bins),True),})returns@torch.jit.unuseddef_pil_to_tensor(self,img)->Tensor:returnF.pil_to_tensor(img)@torch.jit.unuseddef_tensor_to_pil(self,img:Tensor):returnF.to_pil_image(img)def_sample_dirichlet(self,params:Tensor)->Tensor:#Mustbeonaseparatemethodsothatwecanoverwriteitintests.returntorch._sample_dirichlet(params)defforward(self,orig_img:Tensor)->Tensor:"""img(PILImageorTensor):Imagetobetransformed.Returns:PILImageorTensor:Transformedimage."""fill=self.fillifisinstance(orig_img,Tensor):img=orig_imgifisinstance(fill,(int,float)):fill=[float(fill)]*F.get_image_num_channels(img)eliffillisnotNone:fill=[float(f)forfinfill]else:img=self._pil_to_tensor(orig_img)op_meta=self._augmentation_space(self._PARAMETER_MAX,F.get_image_size(img))orig_dims=list(img.shape)batch=img.view([1]*max(4-img.ndim,0)+orig_dims)batch_dims=[batch.size(0)]+[1]*(batch.ndim-1)#随机生成m,即原始x_ori和x_aug的混合权重m=self._sample_dirichlet(torch.tensor([self.alpha,self.alpha],device=batch.device).expand(batch_dims[0],-1))#随机生成w,即混合不同的增强图像的权重,这里乘以m[:,1]以直接考虑了最后的混合combined_weights=self._sample_dirichlet(torch.tensor([self.alpha]*self.mixture_width,device=batch.device).expand(batch_dims[0],-1))*m[:,1].view([batch_dims[0],-1])#初始化mix,这里直接用原始图像乘以混合权重mix=m[:,0].view(batch_dims)*batch#随机生成不同的增强foriinrange(self.mixture_width):aug=batch#随机采样深度depth=self.chain_depthifself.chain_depth>0elseint(torch.randint(low=1,high=4,size=(1,)).item())for_inrange(depth):op_index=int(torch.randint(len(op_meta),(1,)).item())op_name=list(op_meta.keys())[op_index]magnitudes,signed=op_meta[op_name]magnitude=(float(magnitudes[torch.randint(self.severity,(1,),dtype=torch.long)].item())ifmagnitudes.ndim>0else0.0)ifsignedandtorch.randint(2,(1,)):magnitude*=-1.0aug=_apply_op(aug,op_name,magnitude,interpolation=self.interpolation,fill=fill)#混合得到的增强图像mix.add_(combined_weights[:,i].view(batch_dims)*aug)mix=mix.view(orig_dims).to(dtype=img.dtype)ifnotisinstance(orig_img,Tensor):returnself._tensor_to_pil(mix)returnmix对于一致性约束损失,那就比较容易了,按照公式计算即可:#原始图像的交叉熵loss=F.cross_entropy(logits_clean,targets)p_clean,p_aug1,p_aug2=F.softmax(logits_clean,dim=1),F.softmax(logits_aug1,dim=1),F.softmax(logits_aug2,dim=1)#计算平均值p_mixture=torch.clamp((p_clean+p_aug1+p_aug2)/3.,1e-7,1).log()loss+=12*(F.kl_div(p_mixture,p_clean,reduction="batchmean")+F.kl_div(p_mixture,p_aug1,reduction="batchmean")+F.kl_div(p_mixture,p_aug2,reduction="batchmean"))/3.参考AugMix: A Simple Data Processing Method to Improve Robustness and Uncertaintyhttps://github.com/google-research/augmixhttps://github.com/pytorch/vision/blob/main/torchvision/transforms/autoaugment.py 关键词: 一致性约束 校准误差 随机选择 相关阅读 世界热推荐:今晚7:00直播丨下一个突破... 今晚19:00,Cocos视频号直播马上点击【预约】啦↓↓↓在运营了三年... NFT周刊|Magic Eden宣布支持Polygon网... Block-986在NFT这样的市场,每周都会有相当多项目起起伏伏。在过去... 环球今亮点!头条观察 | DeFi的兴衰与... 在比特币得到机构关注之后,许多财务专家预测世界将因为加密货币的... 重新审视合作,体育Crypto的可靠关系才能双赢 Block-987即使在体育Crypto领域,人们的目光仍然集中在FTX上。随着... 简讯:前端单元测试,更进一步 前端测试@2022如果从2014年Jest的第一个版本发布开始计算,前端开发... 焦点热讯:刘强东这波操作秀 近日,刘强东发布京东全员信,信中提到:自2023年1月1日起,逐步为...