Source code for optim.scheduler.multi_step

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse

from optim.scheduler import SCHEDULER_REGISTRY
from optim.scheduler.base_scheduler import BaseLRScheduler


@SCHEDULER_REGISTRY.register("multi_step")
class MultiStepLRScheduler(BaseLRScheduler):
    """
    Multi-step learning rate scheduler with an optional linear warm-up strategy.
    """
    def __init__(self, opts, **kwargs) -> None:
        is_iter_based = getattr(opts, "scheduler.is_iteration_based", True)
        super().__init__(opts=opts)

        max_iterations = getattr(opts, "scheduler.max_iterations", 150000)

        # base LR is required; it is decayed in place as milestones are hit
        self.lr = getattr(opts, "scheduler.multi_step.lr", None)
        assert self.lr is not None

        if self.warmup_iterations > 0:
            # per-iteration increment for the linear warm-up phase
            self.warmup_step = (
                self.lr - self.warmup_init_lr
            ) / self.warmup_iterations

        milestones = getattr(opts, "scheduler.multi_step.milestones", None)
        if milestones is None:
            milestones = [-1]
        elif isinstance(milestones, int):
            milestones = [milestones]

        self.milestones = sorted(
            list(set(milestones))
        )  # remove duplicates and sort them
        self.gamma = getattr(opts, "scheduler.multi_step.gamma", 1.0)

        self.period = (
            max_iterations - self.warmup_iterations + 1
            if is_iter_based
            else getattr(opts, "scheduler.max_epochs", 350)
        )

        self.is_iter_based = is_iter_based
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(
            title="{} arguments".format(cls.__name__),
            description="{} arguments".format(cls.__name__),
        )
        group.add_argument(
            "--scheduler.multi-step.lr", type=float, default=0.1, help="LR value"
        )
        group.add_argument(
            "--scheduler.multi-step.gamma",
            type=float,
            default=None,
            help="Decay LR by this factor",
        )
        group.add_argument(
            "--scheduler.multi-step.milestones",
            type=int,
            nargs="+",
            default=None,
            help="Decay LR at these epochs",
        )
        return parser
    def get_lr(self, epoch: int, curr_iter: int) -> float:
        if curr_iter < self.warmup_iterations:
            # linear warm-up: ramp from warmup_init_lr towards the base LR
            return max(0.0, self.warmup_init_lr + curr_iter * self.warmup_step)
        else:
            if epoch in self.milestones:
                # decay once per milestone, then drop it so the decay is not reapplied
                self.lr = self.lr * self.gamma
                self.milestones.remove(epoch)
            return max(0.0, self.lr)
    def __repr__(self) -> str:
        repr_str = "{}(".format(self.__class__.__name__)
        repr_str += "\n\tlr={}\n\tmilestones={}\n\tgamma={}".format(
            self.lr, self.milestones, self.gamma
        )
        if self.warmup_iterations > 0:
            repr_str += "\n\twarmup_init_lr={}\n\twarmup_iters={}".format(
                self.warmup_init_lr, self.warmup_iterations
            )
        repr_str += "\n )"
        return repr_str
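
A brief usage sketch, assuming the module above is importable; it is illustrative and not part of the file. It shows how the CLI flags registered by add_arguments map onto the dotted attribute names that __init__ reads: argparse replaces "-" with "_" after the leading dashes, so "--scheduler.multi-step.lr" is stored under the attribute "scheduler.multi_step.lr".

# --- Usage sketch (illustrative; not part of this module) ---
import argparse

parser = argparse.ArgumentParser()
parser = MultiStepLRScheduler.add_arguments(parser)
opts = parser.parse_args(
    [
        "--scheduler.multi-step.lr", "0.5",
        "--scheduler.multi-step.gamma", "0.1",
        "--scheduler.multi-step.milestones", "30", "60",
    ]
)
assert getattr(opts, "scheduler.multi_step.lr") == 0.5
# Constructing the scheduler itself additionally needs the warm-up options
# handled by BaseLRScheduler, which are not registered here.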
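
For intuition, a minimal standalone sketch of the schedule that get_lr produces: a linear ramp for the first warmup_iterations iterations, then the base LR multiplied by gamma once for every milestone epoch already passed. The function name and defaults below are hypothetical; unlike the class, which mutates self.lr and self.milestones, this version is stateless and assumes get_lr is evaluated at every epoch.

# --- Hypothetical stateless equivalent, for illustration only ---
def sketch_multi_step_lr(
    epoch: int,
    curr_iter: int,
    lr: float = 0.5,
    gamma: float = 0.1,
    milestones=(30, 60),
    warmup_iterations: int = 500,
    warmup_init_lr: float = 1e-4,
) -> float:
    if curr_iter < warmup_iterations:
        # linear warm-up from warmup_init_lr towards lr
        warmup_step = (lr - warmup_init_lr) / warmup_iterations
        return max(0.0, warmup_init_lr + curr_iter * warmup_step)
    # one multiplicative decay per milestone reached so far
    num_decays = sum(1 for m in milestones if epoch >= m)
    return max(0.0, lr * gamma**num_decays)

# After warm-up: epochs 0-29 -> 0.5, epochs 30-59 -> 0.05, epochs 60+ -> 0.005
assert sketch_multi_step_lr(epoch=45, curr_iter=10_000) == 0.5 * 0.1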