Source code for optim

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Dict, List

import torch.nn

from optim.base_optim import BaseOptim
from utils import logger
from utils.common_utils import unwrap_model_fn
from utils.registry import Registry

OPTIM_REGISTRY = Registry(
    registry_name="optimizer_registry",
    base_class=BaseOptim,
    lazy_load_dirs=["optim"],
    internal_dirs=["internal", "internal/projects/*"],
)
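
# Example (illustrative, not part of this module): concrete optimizers become
# discoverable through OPTIM_REGISTRY, which lazily loads the "optim" directory
# and is indexed by name in build_optimizer below. Assuming the Registry class
# exposes a ``register`` decorator, a hypothetical optimizer named "my_optim"
# could be added as:
#
#     @OPTIM_REGISTRY.register(name="my_optim")
#     class MyOptim(BaseOptim):
#         ...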


def check_trainable_parameters(model: torch.nn.Module, model_params: List) -> None:
    """Helper function to check if any model parameters w/ gradients are not part of model_params"""
    # get model parameter names
    model_trainable_params = []
    for p_name, param in model.named_parameters():
        if param.requires_grad:
            model_trainable_params.append(p_name)

    initialized_params = []
    for param_info in model_params:
        if not isinstance(param_info, Dict):
            logger.error(
                "Expected format is a Dict with three keys: params, weight_decay, param_names"
            )
        if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()):
            logger.error(
                "Parameter dict should have three keys: params, weight_decay, param_names"
            )
        param_names = param_info.pop("param_names")
        if isinstance(param_names, List):
            initialized_params.extend(param_names)
        elif isinstance(param_names, str):
            initialized_params.append(param_names)
        else:
            raise NotImplementedError

    uninitialized_params = set(model_trainable_params) ^ set(initialized_params)
    if len(uninitialized_params) > 0:
        logger.error(
            "Following parameters are defined in the model, but won't be part of optimizer. "
            "Please check get_trainable_parameters function. "
            "Use --optim.bypass-parameters-check flag to bypass this check. "
            "Parameter list = {}".format(uninitialized_params)
        )
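
# Illustrative sketch of the ``model_params`` structure this check expects:
# a list of dicts, each with "params", "weight_decay", and "param_names" keys,
# where "param_names" may be a list of names or a single name. The tensors and
# names below are hypothetical.
#
#     model_params = [
#         {
#             "params": [layer.weight],
#             "weight_decay": 1e-4,
#             "param_names": ["layer1.weight"],
#         },
#         {"params": [layer.bias], "weight_decay": 0.0, "param_names": "layer1.bias"},
#     ]
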
def remove_param_name_key(model_params: List) -> None:
    """Helper function to remove param_names key from model_params"""
    for param_info in model_params:
        if not isinstance(param_info, Dict):
            logger.error(
                "Expected format is a Dict with three keys: params, weight_decay, param_names"
            )
        if not {"params", "weight_decay", "param_names"}.issubset(param_info.keys()):
            logger.error(
                "Parameter dict should have three keys: params, weight_decay, param_names"
            )
        param_info.pop("param_names")

def build_optimizer(model: torch.nn.Module, opts, *args, **kwargs) -> BaseOptim:
    """Helper function to build an optimizer

    Args:
        model: A model
        opts: command-line arguments

    Returns:
        An instance of BaseOptim
    """
    optim_name = getattr(opts, "optim.name")
    weight_decay = getattr(opts, "optim.weight_decay")
    no_decay_bn_filter_bias = getattr(opts, "optim.no_decay_bn_filter_bias")

    unwrapped_model = unwrap_model_fn(model)
    model_params, lr_mult = unwrapped_model.get_trainable_parameters(
        weight_decay=weight_decay,
        no_decay_bn_filter_bias=no_decay_bn_filter_bias,
        *args,
        **kwargs
    )

    # check to ensure that all trainable model parameters are passed to the optimizer
    if not getattr(opts, "optim.bypass_parameters_check", False):
        check_trainable_parameters(model=unwrapped_model, model_params=model_params)
    else:
        remove_param_name_key(model_params=model_params)

    # set the learning rate multiplier for each parameter
    setattr(opts, "optim.lr_multipliers", lr_mult)
    return OPTIM_REGISTRY[optim_name](opts, model_params, *args, **kwargs)
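
# Usage sketch (hypothetical): given a parsed argparse namespace ``opts`` whose
# ``optim.name`` matches a registered optimizer, and a ``model`` whose unwrapped
# module implements ``get_trainable_parameters``, an optimizer is built as:
#
#     optimizer = build_optimizer(model=model, opts=opts)
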

def arguments_optimizer(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Helper function to add optimizer arguments to the argument parser"""
    parser = BaseOptim.add_arguments(parser)
    parser = OPTIM_REGISTRY.all_arguments(parser)
    return parser
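
# Usage sketch (hypothetical): attach the optimizer arguments to a training
# argument parser before parsing the command line.
#
#     parser = argparse.ArgumentParser(description="Training arguments")
#     parser = arguments_optimizer(parser)
#     opts = parser.parse_args()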