您的位置:首页 >聚焦 >

EMA在detectron2中的实现

2022-03-11 05:24:38    来源:程序员客栈

点蓝色字关注“机器学习算法工程师”

设为星标,干货直达!

近期很流行的一些检测模型如YOLOv5和YOLOX都包含了很多的tricks,如数据增强(MixUp, Mosaic)等,其中EMA也是一种常采用的trick。EMA全称为Exponential Moving Average,最早是在TensorFlow中出现(具体实现为tf.train.ExponentialMovingAverage),简单来说,在模型训练过程中对模型参数计算指数移动平均,得到的模型参数要比最后训练得到的模型参数在效果上可能要好一点。从某种意义上来看,EMA有点像模型集成,但是它在测试时不需要额外的负担,在训练过程只是多消耗一份显存(多一份模型参数)以及训练过程稍多一点开销(对参数进行移动平均,耗时很小)。

EMA的实现也很简单,对模型参数params只需要多维护一份参数ema_params就好,然后在每个训练step后,对每一个模型参数进行移动平均:

这里的decay是一个超参数,一般取值接近1,比如设置为0.999。可以看到EMA比较通用,几乎适用于任何模型训练中。

目前商汤开源的mmdet框架已经复现了YOLOX,里面也包含了EMA的实现。而目前Facebook AI的detectron2还没有包含EMA的实现,但是其移动端版本D2Go已经实现了EMA,两个版本其实是互通的,只有略微的差别。这里就讲一下如何将D2Go的EMA应用到detectron2中,这主要包括三个部分:模型中添加EMA参数、训练过程中进行更新以及测试时使用EMA参数。

EMA需要多维护一份模型参数,就是EMA参数,这里定义一个EMAState类来存储EMA参数,这个类里面的state字典存储EMA参数。这里的get_model_state_iterator方法是获得模型的参数,包括训练参数params以及buffers,BN的一些参数moving_mean和moving_var属于buffers,一般情况下对BN的moving_mean和moving_var也进行EMA效果会更好一点。

classEMAState(object):def__init__(self):self.state={}@classmethoddefFromModel(cls,model:torch.nn.Module,device:str=""):ret=cls()ret.save_from(model,device)returnretdefsave_from(self,model:torch.nn.Module,device:str=""):"""Savemodelstatefrom`model`tothisobject"""forname,valinself.get_model_state_iterator(model):val=val.detach().clone()self.state[name]=val.to(device)ifdeviceelsevaldefapply_to(self,model:torch.nn.Module):"""Applystateto`model`fromthisobject"""withtorch.no_grad():forname,valinself.get_model_state_iterator(model):assert(nameinself.state),f"Name{name}notexisted,availablenames{self.state.keys()}"val.copy_(self.state[name])defget_ema_model(self,model):ret=copy.deepcopy(model)self.apply_to(ret)returnret@propertydefdevice(self):ifnotself.has_inited():returnNonereturnnext(iter(self.state.values())).devicedefto(self,device):fornameinself.state:self.state[name]=self.state[name].to(device)returnselfdefhas_inited(self):returnself.statedefclear(self):self.state.clear()returnselfdefget_model_state_iterator(self,model):param_iter=model.named_parameters()buffer_iter=model.named_buffers()returnitertools.chain(param_iter,buffer_iter)defstate_dict(self):returnself.statedefload_state_dict(self,state_dict,strict:bool=True):self.clear()forx,yinstate_dict.items():self.state[x]=yreturntorch.nn.modules.module._IncompatibleKeys(missing_keys=[],unexpected_keys=[])def__repr__(self):ret=f"EMAState(state=[{",".join(self.state.keys())}])"returnret

这样在d2的Trainer中,创建model的同时也定义EMA,添加后model会多一个model_ema属性,它是EMAState的一个实例:

defmay_build_model_ema(cfg,model):ifnotcfg.MODEL_EMA.ENABLED:returnmodel=_remove_ddp(model)assertnothasattr(model,"ema_state"),"Name`ema_state`isreservedformodelema."model.ema_state=EMAState()#添加到model的属性中logger.info("UsingModelEMA.")classTrainer(DefaultTrainer):#overridebuild_model,在里面添加ema@classmethoddefbuild_model(cls,cfg):"""Returns:torch.nn.Module:Itnowcalls:func:`detectron2.modeling.build_model`.Overwriteitifyou"dlikeadifferentmodel."""model=build_model(cfg)logger=logging.getLogger(__name__)logger.info("Model:\n{}".format(model))#addmodelEMAifenabledmodel_ema.may_build_model_ema(cfg,model)returnmodel

上面实现了ema的添加,但是在训练后还需要保存ema参数,这可以通过d2的DetectionCheckpointer来实现,DetectionCheckpointer在创建时可以传入额外的checkpointable objects,在save和load时除了模型参数也会同步对这些objects进行保存和加载。checkpointable objects需要实现两个方法:state_dict()和load_state_dict(),而前面定义的EMAState类也包含了这两个方法,用于save和load对应的ema参数。具体的实现代码如下:

classTrainer(DefaultTrainer):def__init__(self,cfg):#addmodelEMAkwargs={"trainer":weakref.proxy(self),}kwargs.update(model_ema.may_get_ema_checkpointer(cfg,model))#添加ema到checkpointablesself.checkpointer=DetectionCheckpointer(#Assumeyouwanttosavecheckpointstogetherwithlogs/statisticsmodel,cfg.OUTPUT_DIR,**kwargs,)

上面完成了第一个部分,就是在模型中添加ema参数,第二个要做的工作就是实现ema参数在训练过程的更新,首先定义一个EMAUpdater,其中update方法用来进行一次ema更新:

classEMAUpdater(object):"""ModelExponentialMovingAverageKeepamovingaverageofeverythinginthemodelstate_dict(parametersandbuffers).Thisisintendedtoallowfunctionalitylikehttps://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverageNote:It"sveryimportanttosetEMAforALLnetworkparameters(insteadofparametersthatrequiregradient),includingbatch-normmovingaveragemeanandvariance.Thisleadstosignificantimprovementinaccuracy.Forexample,forEfficientNetB3,withdefaultsetting(nomixup,lrexponentialdecay)withoutbn_sync,theEMAaccuracywithEMAonparamsthatrequiresgradientis79.87%,whilethecorrespondingaccuracywithEMAonallparamsis80.61%.Also,bnsyncshouldbeswitchedonforEMA."""def__init__(self,state:EMAState,decay:float=0.999,device:str=""):self.decay=decayself.device=deviceself.state=statedefinit_state(self,model):self.state.clear()self.state.save_from(model,self.device)defupdate(self,model):withtorch.no_grad():forname,valinself.state.get_model_state_iterator(model):ema_val=self.state.state[name]ifself.device:val=val.to(self.device)#指数移动平均ema_val.copy_(ema_val*self.decay+val*(1.0-self.decay))

要实现训练过程中的更新,可以采用hook的方式,这里定义一个EMAHook,这里主要是在after_step方法中加入ema的update:

classEMAHook(HookBase):def__init__(self,cfg,model):model=_remove_ddp(model)assertcfg.MODEL_EMA.ENABLEDasserthasattr(model,"ema_state"),"Call`may_build_model_ema`firsttoinitilaizethemodelema"self.model=modelself.ema=self.model.ema_stateself.device=cfg.MODEL_EMA.DEVICEorcfg.MODEL.DEVICEself.ema_updater=EMAUpdater(self.model.ema_state,decay=cfg.MODEL_EMA.DECAY,device=self.device)defbefore_train(self):ifself.ema.has_inited():self.ema.to(self.device)else:self.ema_updater.init_state(self.model)defafter_train(self):passdefbefore_step(self):passdefafter_step(self):ifnotself.model.train:returnself.ema_updater.update(self.model)

然后把EMAHook加到trainer中的hooks里:

defbuild_hooks(self):"""Buildalistofdefaulthooks,includingtiming,evaluation,checkpointing,lrscheduling,preciseBN,writingevents.Returns:list[HookBase]:"""cfg=self.cfg.clone()cfg.defrost()cfg.DATALOADER.NUM_WORKERS=0#savesomememoryandtimeforPreciseBNret=[hooks.IterationTimer(),model_ema.EMAHook(self.cfg,self.model)ifcfg.MODEL_EMA.ENABLEDelseNone,#addEMAhookhooks.LRScheduler(),hooks.PreciseBN(#Runatthesamefreqas(butbefore)evaluation.cfg.TEST.EVAL_PERIOD,self.model,#Buildanewdataloadertonotaffecttrainingself.build_train_loader(cfg),cfg.TEST.PRECISE_BN.NUM_ITER,)ifcfg.TEST.PRECISE_BN.ENABLEDandget_bn_modules(self.model)elseNone,]

最后一个要实现的就是如何在测试时采用ema参数,这里采用的方法是每次进行test时,先将model参数保存一个副本,然后用ema参数替换,完成测试后再用保存的副本复原回来,在实现上,可以采用python的上下文管理器来巧妙地实现:

@contextmanagerdefapply_model_ema_and_restore(model,state=None):"""Applyemastoredin`model`tomodelandreturnsafunctiontorestoretheweightsareapplied"""model=_remove_ddp(model)ifstateisNone:state=get_model_ema_state(model)old_state=EMAState.FromModel(model,state.device)#创建当前模型参数副本state.apply_to(model)#用ema替换模型参数yieldold_stateold_state.apply_to(model)#恢复模型参数

用这个上下文管理器对test进行包装,就可以实现想要的效果了:

@classmethoddefdo_test(cls,cfg,model,evaluators=None):#modelwithemaweightslogger=logging.getLogger("detectron2")ifcfg.MODEL_EMA.ENABLED:logger.info("RunevaluationwithEMA.")withmodel_ema.apply_model_ema_and_restore(model):results=cls.test(cfg,model,evaluators=evaluators)else:results=cls.test(cfg,model,evaluators=evaluators)returnresults

完整的代码放在了github上,欢迎试用和star(https://github.com/xiaohu2015/detectron2_ema)。我初步用RetinaNet_R_50_FPN_1x测试的话,采用ema比原始效果要好一点(37.23 vs 37.18),而YOLOv5采用ema能提升1~2个点的。在YOLOv5中,ema的实现有一个额外的trick,那就是在训练前期,采用较小的decay,然后逐步增到默认值,因为前期模型训练速度快,应该对ema参数更新更激进一些,具体的实现如下:

self.decay=lambdax:decay*(1-math.exp(-x/2000))#decayexponentialramp(tohelpearlyepochs)

这个实现应该很容易在d2的EMA中添加,有时间再更新(mmdet的ema已经实现这个功能了)。

参考fvcored2goyolov5

关键词: 移动平均 有时间再

相关阅读