Source code for coreai_opt.pruning.magnitude_pruner

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Magnitude pruner implementation."""

import logging
from typing import NamedTuple

import torch

from coreai_opt._utils.config_utils import ConfigLevel as _ConfigLevel
from coreai_opt._utils.eager_utils import (
    EagerCompressionComponentBuilderMixin as _EagerCompressionComponentBuilderMixin,
)
from coreai_opt._utils.insertion.torch_function import (
    TorchFunctionEagerHandler as _TorchFunctionEagerHandler,
)
from coreai_opt._utils.spec_utils import PartialConstructor as _PartialConstructor
from coreai_opt._utils.torch_utils import (
    find_parametrization_matching_cls as _find_parametrization_matching_cls,
    move_model_to_eval as _move_model_to_eval,
)
from coreai_opt.common import ExportBackend
from coreai_opt.config.compression_config import ModuleCompressionConfig
from coreai_opt.config.spec import CompressionTargetTensor
from coreai_opt.config.spec.base import CompressionSpec
from coreai_opt.pruning.spec import PruneImplBase

from ._prepare_for_export import prepare_for_mil_export, prepare_for_mlir_export
from .base_pruner import _BasePruner
from .config import MagnitudePrunerConfig
from .supported_ops_registry import _PrunerSupportedOpsRegistry

logger = logging.getLogger(__name__)


class _ScheduledModule(NamedTuple):
    """Cached location of a scheduled PruneImplBase: the impl plus where to find its weight.

    One record is stored per pruning parametrization that has a sparsity
    schedule attached, so the schedule can be re-applied later by iterating
    the cached records instead of walking the model's modules again.
    """

    # Pruning parametrization instance the schedule was attached to.
    impl: PruneImplBase
    # Module that owns the parametrized parameter.
    module: torch.nn.Module
    # Key into ``module.parametrizations`` identifying the pruned parameter.
    param_name: str


[docs] class MagnitudePruner(_BasePruner, _EagerCompressionComponentBuilderMixin): """Apply magnitude-based pruning to a model. This pruner zeros out the smallest-magnitude weight elements to reach a configurable sparsity target. The model is parsed in an eager fashion and the pruner registers parametrizations for each candidate parameter to be pruned. The mask is applied on every forward pass while parametrizations are active. When a ``sparsity_schedule`` is configured on a module's config, ``step()`` advances the schedule and recomputes the mask for that module's parametrizations. Without a schedule, the spec's ``target_sparsity`` is applied statically. Args: model (torch.nn.Module): Model to prune. config (MagnitudePrunerConfig | None): Pruning configuration. When ``None``, a default config with 50 % sparsity is used. Example: >>> model = torch.nn.Linear(100, 50) >>> pruner = MagnitudePruner(model, MagnitudePrunerConfig()) >>> pruner.prepare((torch.randn(1, 100),)) >>> pruner.finalize() """ _step_count: int _scheduled_modules: list[_ScheduledModule]
[docs] def __init__(self, model: torch.nn.Module, config: MagnitudePrunerConfig | None = None): if config is None: config = MagnitudePrunerConfig() super().__init__(model, config) module_components_dict, module_priority_dict = ( self._get_module_compression_components_and_priority(model, config) ) self._handler = _TorchFunctionEagerHandler( compression_config=config, module_components_dict=module_components_dict, module_priority_dict=module_priority_dict, supported_ops_registry=_PrunerSupportedOpsRegistry, optimization_type_name="prune", ) self._step_count = 0 self._scheduled_modules = []
[docs] def prepare(self, example_inputs: tuple[torch.Tensor]) -> torch.nn.Module: """Prepare the model for pruning. Args: example_inputs (tuple[torch.Tensor]): Sample inputs to trace the model and configure pruning parametrizations. Returns: torch.nn.Module: The prepared model with pruning parametrizations. Raises: RuntimeError: If the model has already been prepared. """ if self._is_model_prepared(self._model): raise RuntimeError( "Model has already been prepared. Cannot re-prepare a prepared model." ) logger.info("Preparing model for pruning") prepared_model = self._handler.prepare(self._model, example_inputs=example_inputs) self._mark_model_as_prepared(prepared_model) self._model = prepared_model # Apply schedule before we run a forward pass to initialize the parameterizations self._build_scheduled_modules() with _move_model_to_eval(prepared_model), torch.no_grad(): prepared_model(*example_inputs) return self._model
[docs] def step(self) -> None: """Advance the sparsity schedule by one step. Increments the step count, then recomputes and applies the mask for every parametrization with a configured ``sparsity_schedule``. Safe to call when no schedule is configured (no-op). """ self._step_count += 1 self._apply_schedule()
[docs] def finalize( self, model: torch.nn.Module | None = None, backend: ExportBackend = ExportBackend.CoreAI, ) -> torch.nn.Module: """Finalize the model to be lowered to the target backend. Args: model (torch.nn.Module | None): Model to finalize. Uses the model passed at construction time when ``None``. backend (ExportBackend): Target export backend. Returns: torch.nn.Module: The finalized model ready for the target backend. Raises: RuntimeError: If the model has not been prepared. """ update_internal = model is None if model is None: model = self._model if not self._is_model_prepared(model): raise RuntimeError("Model must be prepared before finalization. Call prepare() first.") match backend: case ExportBackend._TORCH: pass case ExportBackend.CoreAI: model = prepare_for_mlir_export(model) case ExportBackend.CoreML: model = prepare_for_mil_export(model) case _: raise ValueError(f"Unsupported backend: {backend}") if update_internal: self._model = model return model
def _build_scheduled_modules(self) -> None: """Attach schedules to each scheduled PruneImplBase and apply the step-0 state.""" config_dict = self._config.build_module_config_dict(self._model) for name, module in self._model.named_modules(): if not hasattr(module, "parametrizations"): continue module_config = ( config_dict[_ConfigLevel.MODULE_NAME].get(name) or config_dict[_ConfigLevel.MODULE_TYPE].get(name) or config_dict[_ConfigLevel.GLOBAL].get(name) ) if module_config is None or module_config.sparsity_schedule is None: continue for param_name in module.parametrizations: impl = _find_parametrization_matching_cls(module, param_name, PruneImplBase) if impl is None: continue impl.schedule = module_config.sparsity_schedule self._scheduled_modules.append( _ScheduledModule(impl=impl, module=module, param_name=param_name) ) self._apply_schedule() def _apply_schedule(self) -> None: """For each scheduled impl, advance its sparsity and materialize the mask.""" for entry in self._scheduled_modules: entry.impl.update_sparsity(self._step_count) original = entry.module.parametrizations[entry.param_name].original with torch.no_grad(): entry.impl(original.detach()) @staticmethod def _spec_to_partial( spec: CompressionSpec | None, target: CompressionTargetTensor, module_config: ModuleCompressionConfig, ) -> _PartialConstructor | None: if spec is None: return None return spec.pruning_algo.with_args( target_sparsity=spec.target_sparsity, pruning_scheme=spec.pruning_scheme, )