# Source code for coreai_opt.pruning.config.magnitude_pruner_config

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Pruning configuration classes."""

from typing import ClassVar

from coreai_opt.config import (
    CompressionConfig,
    ModuleCompressionConfig,
    OpCompressionConfig,
    WeightOnlyModuleValidationMixin,
    WeightOnlyOpValidationMixin,
)
from coreai_opt.pruning.spec import PruningSpec, default_weight_pruning_spec

from .sparsity_schedule import SparsityScheduleBase

_MAGNITUDE_PRUNING_CONFIG = "magnitude_pruning_config"
_PRUNING_SPEC = "pruning_spec"


class OpMagnitudePrunerConfig(
    WeightOnlyOpValidationMixin,
    OpCompressionConfig[PruningSpec],
):
    """Operation-level pruning configuration.

    Pruning is a weight-only compression technique. Only ``op_state_spec`` is
    used to configure which state tensors (e.g. weights) to prune.

    Attributes:
        op_state_spec (dict[str, PruningSpec | None]): Mapping of parameter
            names to their pruning specs. Default includes ``"weight"`` and
            ``"in_proj_weight"`` at 50 % sparsity.

    Example:
        >>> config = OpMagnitudePrunerConfig()
        >>> config = OpMagnitudePrunerConfig(
        ...     op_state_spec={"weight": PruningSpec(target_sparsity=0.75)}
        ... )
    """

    @classmethod
    def get_default_state_spec(cls) -> dict[str, PruningSpec | None]:
        """Provide the default state spec for pruning.

        Returns:
            dict[str, PruningSpec | None]: ``"weight"`` and
            ``"in_proj_weight"`` each mapped to a spec built by
            ``default_weight_pruning_spec()``.
        """
        # Build one spec per key rather than sharing a single object.
        # Otherwise mutating one entry would silently change the other.
        # (Assumes the factory returns a fresh instance per call; TODO
        # confirm.)
        return {
            "weight": default_weight_pruning_spec(),
            "in_proj_weight": default_weight_pruning_spec(),
        }
class ModuleMagnitudePrunerConfig(
    WeightOnlyModuleValidationMixin,
    ModuleCompressionConfig[OpMagnitudePrunerConfig, PruningSpec],
):
    """Pruning configuration scoped to a single module.

    Settings are resolved with the same hierarchical precedence used by the
    other compression configs, from most to least specific:

    1. ``op_name_config``
    2. ``op_type_config``
    3. ``op_state_spec``

    Attributes:
        op_state_spec (dict[str, PruningSpec | None] | None): Default pruning
            specs for the state tensors in this module.
        op_type_config (dict[str, OpMagnitudePrunerConfig]): Overrides keyed
            by op type.
        op_name_config (dict[str, OpMagnitudePrunerConfig]): Overrides keyed
            by op name.
        module_state_spec (dict[str, PruningSpec | None] | None): Specs that
            apply across all ops in the module.
        sparsity_schedule (SparsityScheduleBase | None): Optional sparsity
            schedule. When provided, the ``pruner.step()`` API drives sparsity
            over training steps. When left as ``None`` (the default), each
            spec's ``target_sparsity`` is applied immediately and statically.
    """

    # No schedule by default: target sparsity is applied statically.
    sparsity_schedule: SparsityScheduleBase | None = None
class MagnitudePrunerConfig(CompressionConfig[ModuleMagnitudePrunerConfig]):
    """Top-level magnitude pruning configuration.

    Attributes:
        global_config (ModuleMagnitudePrunerConfig | None): Pruning config
            applied to every module by default.
        module_type_configs (dict[str, ModuleMagnitudePrunerConfig | None]):
            Overrides keyed by module type.
        module_name_configs (dict[str, ModuleMagnitudePrunerConfig | None]):
            Overrides keyed by module name (highest priority).

    Example:
        >>> config = MagnitudePrunerConfig()  # 50 % sparsity everywhere
        >>> config = MagnitudePrunerConfig(
        ...     module_name_configs={"layer1": None}  # skip layer1
        ... )
    """

    # Class-level identifiers for this config kind and its spec kind.
    # Presumably consumed by the CompressionConfig base class; TODO confirm.
    _CONFIG_KEY: ClassVar[str] = _MAGNITUDE_PRUNING_CONFIG
    _SPEC_KEY: ClassVar[str] = _PRUNING_SPEC