Source code for cvnets.misc.averaging_utils

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from copy import deepcopy
from typing import Optional

import torch
from torch import nn


class EMA(object):
    """
    For a given model, this class computes the exponential moving average of weights

    A deep copy of the model is kept in ``self.ema_model``; each call to
    ``update_parameters`` blends the current weights of the training model into
    that copy.

    Args:
        model (torch.nn.Module): Model
        ema_momentum (Optional[float]): Momentum value shows the contribution of weights at
            current iteration. Default: 0.0005
        device (Optional[str]): Device (CPU or GPU) on which model resides. Default: cpu
    """

    def __init__(
        self,
        model: nn.Module,
        ema_momentum: Optional[float] = 0.0005,
        device: Optional[str] = "cpu",
        *args,
        **kwargs
    ) -> None:
        # make a deep copy of the model for accumulating moving average of parameters
        # and set to eval mode
        self.ema_model = deepcopy(model)
        self.ema_model.eval()
        self.momentum = ema_momentum
        self.device = device
        if device:
            self.ema_model.to(device=device)
        # remember whether the copy itself exposes a `.module` attribute so that
        # state dict keys can be aligned in `update_parameters`
        self.ema_has_module = hasattr(self.ema_model, "module")
        # the EMA copy is never trained directly
        for param in self.ema_model.parameters():
            param.requires_grad = False

    def update_parameters(self, model: nn.Module) -> None:
        """
        Blend the current state of ``model`` into the EMA copy, in place.

        Floating-point (and complex) entries of the state dict are updated as
        ``ema = ema * (1 - momentum) + momentum * current``. All other entries
        (e.g. integer counters such as ``num_batches_tracked`` in batch-norm
        layers) are copied verbatim, because averaging them in floating point
        and casting back to an integer dtype truncates the result.

        Args:
            model (torch.nn.Module): Model whose current state is folded into the average.

        Raises:
            KeyError: If an entry of the EMA state dict has no counterpart in
                the state dict of ``model``.
        """
        # correct a mismatch in state dict keys
        has_module = hasattr(model, "module") and not self.ema_has_module

        with torch.no_grad():
            msd = model.state_dict()
            # state_dict() tensors share storage with the EMA model, so the
            # in-place copy_ below updates the EMA model directly
            for k, ema_v in self.ema_model.state_dict().items():
                if has_module:
                    # .module is added if we use DistributedDataParallel or DataParallel
                    # wrappers around model
                    k = "module." + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)

                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(
                        (ema_v * (1.0 - self.momentum)) + (self.momentum * model_v)
                    )
                else:
                    # integer/bool buffers cannot hold a fractional average
                    ema_v.copy_(model_v)
def arguments_ema(parser: argparse.ArgumentParser):
    """
    Register the exponential-moving-average command-line options.

    An argument group titled "EMA" is added to ``parser`` with three options:
    ``--ema.enable`` (flag), ``--ema.momentum`` (float, default 0.0001) and
    ``--ema.copy-at-epoch`` (int, default -1).

    Args:
        parser (argparse.ArgumentParser): Parser to extend.

    Returns:
        The same ``parser`` instance, with the EMA options registered.
    """
    ema_group = parser.add_argument_group(
        title="EMA", description="Exponential moving average arguments"
    )

    # (flag, keyword arguments) pairs, registered in declaration order
    option_specs = (
        (
            "--ema.enable",
            {"action": "store_true", "help": "Use exponential moving average"},
        ),
        (
            "--ema.momentum",
            {
                "type": float,
                "default": 0.0001,
                "help": "EMA momentum. Defaults to 0.0001.",
            },
        ),
        (
            "--ema.copy-at-epoch",
            {
                "type": int,
                "default": -1,
                "help": "Update model weights with EMA model at this epoch",
            },
        ),
    )
    for flag, options in option_specs:
        ema_group.add_argument(flag, **options)

    return parser