#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import math

import numpy as np

from optim.scheduler import SCHEDULER_REGISTRY
from optim.scheduler.base_scheduler import BaseLRScheduler
from utils import logger

SUPPORTED_LAST_CYCLES = ["cosine", "linear"]


@SCHEDULER_REGISTRY.register("cyclic")
class CyclicLRScheduler(BaseLRScheduler):
    """
    Cyclic LR scheduler, as proposed in ESPNetv2: https://arxiv.org/abs/1811.11431

    The LR cycles between ``min_lr`` and ``max_lr = min_lr * epochs_per_cycle``
    for ``total_cycles - 1`` cycles (optionally decaying both bounds by ``gamma``
    at the given step epochs), and is then annealed from ``min_lr`` down to
    ``last_cycle_end_lr`` over the remaining epochs (the last cycle).
    """

    def __init__(self, opts, **kwargs) -> None:
cycle_steps = getattr(opts, "scheduler.cyclic.steps", [25])
if cycle_steps is not None and isinstance(cycle_steps, int):
cycle_steps = [cycle_steps]
gamma = getattr(opts, "scheduler.cyclic.gamma", 0.5)
anneal_type = getattr(opts, "scheduler.cyclic.last_cycle_type", "linear")
min_lr = getattr(opts, "scheduler.cyclic.min_lr", 0.1)
end_lr = getattr(opts, "scheduler.cyclic.last_cycle_end_lr", 1e-3)
ep_per_cycle = getattr(opts, "scheduler.cyclic.epochs_per_cycle", 5)
warmup_iterations = getattr(opts, "scheduler.warmup_iterations", 0)
n_cycles = getattr(opts, "scheduler.cyclic.total_cycles", 10) - 1
max_epochs = getattr(opts, "scheduler.max_epochs", 100)
if anneal_type not in SUPPORTED_LAST_CYCLES:
logger.error(
"Supported anneal types for {} are: {}".format(
self.__class__.__name__, SUPPORTED_LAST_CYCLES
)
)
if min_lr < end_lr:
logger.error(
"Min LR should be greater than end LR. Got: {} and {}".format(
min_lr, end_lr
)
)
super(CyclicLRScheduler, self).__init__(opts=opts)
self.min_lr = min_lr
self.cycle_length = ep_per_cycle
self.end_lr = end_lr
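        # The peak LR of each cycle is tied to min_lr and the cycle length
        # (0.1 * 5 = 0.5 with the defaults).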
self.max_lr = self.min_lr * self.cycle_length
self.last_cycle_anneal_type = anneal_type
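        # Linear warmup: per-iteration step to go from warmup_init_lr (set by the
        # base scheduler) up to min_lr over the first warmup_iterations iterations.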
if self.warmup_iterations > 0:
self.warmup_step = (
self.min_lr - self.warmup_init_lr
) / self.warmup_iterations
self.n_cycles = n_cycles
self.cyclic_epochs = self.cycle_length * self.n_cycles
self.max_epochs = max_epochs
self.last_cycle_epochs = self.max_epochs - self.cyclic_epochs
assert self.max_epochs == self.cyclic_epochs + self.last_cycle_epochs
self.steps = [self.max_epochs] if cycle_steps is None else cycle_steps
self.gamma = gamma if cycle_steps is not None else 1
self._lr_per_cycle()
self.epochs_lr_stepped = []

    def _lr_per_cycle(self) -> None:
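        """Build the per-epoch LR table for one cycle: linearly anneal from
        max_lr down to min_lr, then rotate the list so each cycle starts at
        min_lr before jumping to max_lr."""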
lrs = list(
np.linspace(self.max_lr, self.min_lr, self.cycle_length, dtype=np.float32)
)
lrs = [lrs[-1]] + lrs[:-1]
self.cycle_lrs = lrs

    @classmethod
def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
group = parser.add_argument_group(
title="Cyclic LR arguments", description="Cyclic LR arguments"
)
group.add_argument(
"--scheduler.cyclic.min-lr",
default=0.1,
type=float,
            help="Minimum LR within a cycle",
)
group.add_argument(
"--scheduler.cyclic.last-cycle-end-lr",
default=1e-3,
type=float,
help="End LR for the last cycle",
)
group.add_argument(
"--scheduler.cyclic.total-cycles",
default=11,
type=int,
            help="Total number of cycles, including the last annealing cycle. Default is 11",
)
group.add_argument(
"--scheduler.cyclic.epochs-per-cycle",
default=5,
type=int,
help="Number of epochs per cycle. Default is 5",
)
group.add_argument(
"--scheduler.cyclic.steps",
default=None,
type=int,
nargs="+",
            help="Epochs at which the cycle LR range should be decreased",
)
group.add_argument(
"--scheduler.cyclic.gamma",
default=0.5,
type=float,
help="Factor by which LR should be decreased",
)
group.add_argument(
"--scheduler.cyclic.last-cycle-type",
default="linear",
type=str,
choices=SUPPORTED_LAST_CYCLES,
help="Annealing in last cycle",
)
return parser

    def get_lr(self, epoch: int, curr_iter: int) -> float:
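        # During warmup, grow the LR linearly from warmup_init_lr towards min_lr.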
if curr_iter < self.warmup_iterations:
curr_lr = self.warmup_init_lr + curr_iter * self.warmup_step
else:
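            # Past warmup: pick the LR from the cyclic schedule or from the
            # last-cycle annealing.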
if epoch <= self.cyclic_epochs:
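                # Cyclic phase: on a configured step epoch (applied only once),
                # shrink both LR bounds by gamma ** (step index + 1) and rebuild
                # the per-cycle LR table.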
if epoch in self.steps and epoch not in self.epochs_lr_stepped:
self.min_lr *= self.gamma ** (self.steps.index(epoch) + 1)
self.max_lr *= self.gamma ** (self.steps.index(epoch) + 1)
self._lr_per_cycle()
self.epochs_lr_stepped.append(epoch)
idx = epoch % self.cycle_length
curr_lr = self.cycle_lrs[idx]
else:
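                # Last cycle: anneal from min_lr down to last_cycle_end_lr using
                # the configured method.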
base_lr = self.min_lr
if self.last_cycle_anneal_type == "linear":
lr_step = (base_lr - self.end_lr) / self.last_cycle_epochs
curr_lr = base_lr - (epoch - self.cyclic_epochs + 1) * lr_step
elif self.last_cycle_anneal_type == "cosine":
curr_epoch = epoch - self.cyclic_epochs
period = self.max_epochs - self.cyclic_epochs + 1
curr_lr = self.end_lr + 0.5 * (base_lr - self.end_lr) * (
1 + math.cos(math.pi * curr_epoch / period)
)
else:
raise NotImplementedError
return max(0.0, curr_lr)

    def __repr__(self):
repr_str = (
"{}(\n \t C={},\n \t C_length={},\n \t C_last={},\n \t Total_Epochs={}, "
"\n \t steps={},\n \t gamma={},\n \t last_cycle_anneal_method={} "
"\n \t min_lr={}, \n\t max_lr={}, \n\t end_lr={}\n)".format(
self.__class__.__name__,
self.n_cycles,
self.cycle_length,
self.last_cycle_epochs,
self.max_epochs,
self.steps,
self.gamma,
self.last_cycle_anneal_type,
self.min_lr,
self.min_lr * self.cycle_length,
self.end_lr,
)
)
return repr_str
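

# A minimal standalone sketch (not part of the scheduler API): it mirrors the
# `_lr_per_cycle` computation with the default values used above (min_lr=0.1,
# epochs_per_cycle=5) to show the LR pattern produced within a single cycle.
if __name__ == "__main__":
    _min_lr, _cycle_length = 0.1, 5
    _max_lr = _min_lr * _cycle_length  # 0.5 with the defaults
    _lrs = list(np.linspace(_max_lr, _min_lr, _cycle_length, dtype=np.float32))
    # Rotate so the cycle starts at min_lr before jumping to max_lr:
    # [0.1, 0.5, 0.4, 0.3, 0.2]
    _lrs = [_lrs[-1]] + _lrs[:-1]
    print("LRs within one cycle:", [round(float(lr), 4) for lr in _lrs])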