# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
from abc import abstractmethod
from pydantic import BaseModel, ConfigDict, Field, NonNegativeInt, PositiveFloat, PositiveInt
from coreai_opt._utils.registry_utils import ConfigRegistryMixin as _ConfigRegistryMixin
[docs]
class SparsityScheduleBase(BaseModel, _ConfigRegistryMixin):
"""Abstract base for sparsity schedules used by ``MagnitudePruner``.
A sparsity schedule defines how the sparsity applied during pruning
evolves over training steps. Instead of applying the full target sparsity
immediately, a schedule lets sparsity rise gradually so the model can
adapt to it during training. Each schedule is a pure function of the
pruner's step count and the spec's target sparsity.
"""
model_config = ConfigDict(frozen=True, extra="forbid")
[docs]
@abstractmethod
def compute_sparsity(
self,
step_count: int,
target_sparsity: float,
prev_sparsity: float | None = None,
) -> float:
"""Return the sparsity that should be applied at *step_count*.
Args:
step_count (int): The current step count of the pruner
(monotonically increasing).
target_sparsity (float): The final sparsity we want to reach at the
end of the pruning schedule.
prev_sparsity (float | None): Sparsity from the previous invocation.
Schedules that don't need this can ignore it; schedules that do
(e.g. ``PolynomialDecaySchedule`` with an ``update_frequency``
gap) raise ``ValueError`` when omitted.
Returns:
float: The sparsity level to apply at the current step.
"""
[docs]
@SparsityScheduleBase.register("constant")
class ConstantSparsitySchedule(SparsityScheduleBase):
"""Step function: zero before ``begin_step``, ``target_sparsity`` at and after.
Attributes:
begin_step (int): Step at which to switch from 0 to ``target_sparsity``.
Default: 0.
"""
begin_step: NonNegativeInt = 0
[docs]
def compute_sparsity(
self,
step_count: int,
target_sparsity: float,
prev_sparsity: float | None = None,
) -> float:
return target_sparsity if step_count >= self.begin_step else 0.0
[docs]
@SparsityScheduleBase.register("polynomial_decay")
class PolynomialDecaySchedule(SparsityScheduleBase):
r"""Polynomial schedule from ``initial_sparsity`` to ``target_sparsity``.
Inspired by PyTorch's ``torch.optim.lr_scheduler.PolynomialLR`` and the paper
`"To prune or not to prune" <https://arxiv.org/pdf/1710.01878.pdf>`_.
Behavior by step:
- ``step < begin_step`` → ``initial_sparsity``
- ``begin_step <= step < begin_step + total_iters`` → scheduled value
- ``step >= begin_step + total_iters`` → ``target_sparsity``
Formula at update index :math:`i \in [0, n\_updates - 1]`:
.. math::
t = i / \max(n\_updates - 1, 1)
sparsity = target + (initial - target) \cdot (1 - t)^{power}
Attributes:
begin_step (int): Step at which the schedule starts. Default: 0.
total_iters (int): Length of the schedule in steps. Must be positive.
power (float): Polynomial exponent. ``1.0`` is linear; higher values
keep sparsity low for longer before climbing. Default: 3.0.
initial_sparsity (float): Sparsity before and at the start of the
schedule, in ``[0, 1]``. Default: 0.0.
update_frequency (int): Steps between sparsity updates within the
schedule. Must be >= 1. Default: 1 (update every step).
"""
begin_step: int = Field(default=0, ge=0)
total_iters: PositiveInt
power: PositiveFloat = 3.0
initial_sparsity: float = Field(default=0.0, ge=0.0, le=1.0)
update_frequency: PositiveInt = 1
[docs]
def compute_sparsity(
self,
step_count: int,
target_sparsity: float,
prev_sparsity: float | None = None,
) -> float:
if step_count < self.begin_step:
return self.initial_sparsity
if step_count >= self.begin_step + self.total_iters:
return target_sparsity
offset = step_count - self.begin_step
if offset % self.update_frequency != 0:
if prev_sparsity is None:
raise ValueError(
"prev_sparsity is required for off-boundary steps when "
f"update_frequency={self.update_frequency} > 1."
)
return prev_sparsity
n_updates = max((self.total_iters - 1) // self.update_frequency + 1, 1)
i = offset // self.update_frequency
t = i / max(n_updates - 1, 1)
return target_sparsity + (self.initial_sparsity - target_sparsity) * (1 - t) ** self.power