# Source code for optim.base_optim

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse

from utils import logger


class BaseOptim(object):
    """Base class for optimizers.

    Reads the hyper-parameters shared by all optimizers (epsilon, learning
    rate, weight decay) from the parsed command-line arguments. Note that
    option names contain dots (e.g. ``optim.weight_decay``), so they are
    accessed with :func:`getattr` rather than attribute syntax.

    Args:
        opts: Command-line arguments (an ``argparse.Namespace``).
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # Honor the value registered via --optim.eps instead of hard-coding
        # 1e-8; fall back to the CLI default when the attribute is absent,
        # which preserves the previous behavior.
        self.eps = getattr(opts, "optim.eps", 1e-8)
        # No defaults here: these options are expected to be present in opts.
        self.lr = getattr(opts, "scheduler.lr")
        self.weight_decay = getattr(opts, "optim.weight_decay")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add optimizer arguments to ``parser`` and return it.

        Args:
            parser: Argument parser to extend.

        Returns:
            The same parser, with an argument group for this class added
            (only when called on :class:`BaseOptim` itself).
        """
        if cls != BaseOptim:
            # Don't re-register arguments in subclasses that don't override
            # `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--optim.name",
            type=str,
            default="sgd",
            help="Name of the optimizer. Defaults to SGD.",
        )
        group.add_argument(
            "--optim.eps",
            type=float,
            default=1e-8,
            help="Optimizer epsilon value. Defaults to 1.e-8.",
        )
        group.add_argument(
            "--optim.weight-decay",
            default=4e-5,
            type=float,
            help="Weight decay (or L2 penalty). Defaults to 4.e-5.",
        )
        group.add_argument(
            "--optim.no-decay-bn-filter-bias",
            action="store_true",
            default=False,
            help="When enabled, the weight in normalization layers and biases in the model are not decayed."
            "Defaults to False.",
        )
        group.add_argument(
            "--optim.bypass-parameters-check",
            action="store_true",
            default=False,
            help="Bypass parameter check when creating optimizer. Defaults to False",
        )
        return parser

    def __repr__(self) -> str:
        """Return a readable summary of ``self.param_groups``.

        Collects, per key (except ``params``), the value from each parameter
        group into a list, and renders them one key per line.
        """
        group_dict = dict()
        for group in self.param_groups:
            for key in sorted(group.keys()):
                # The raw parameter tensors are not useful in a repr.
                if key == "params":
                    continue
                group_dict.setdefault(key, []).append(group[key])
        format_string = self.__class__.__name__ + " ("
        format_string += "\n"
        for k, v in group_dict.items():
            format_string += "\t {0}: {1}\n".format(k, v)
        format_string += ")"
        return format_string