# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
"""Pruning specification."""
from __future__ import annotations
from typing import Annotated, Any
from pydantic import BeforeValidator, Field, PrivateAttr, field_validator, model_validator
from coreai_opt._utils.registry_utils import ClassRegistryMixin
from coreai_opt.common import CompressionType
from coreai_opt.config.spec import CompressionSpec
from .prune import PruneImplBase
from .scheme import PruningScheme, Unstructured
[docs]
class PruningSpec(CompressionSpec):
"""Specification for pruning tensors.
Attributes:
target_sparsity (float): Fraction of elements to prune, in ``[0, 1]``.
Default: 0.5.
pruning_scheme (PruningScheme): Structural pattern of sparsity.
Default: ``Unstructured()``.
pruning_algo (type[PruneImplBase]): Pruning implementation class.
Default: ``"default"`` (magnitude-based pruning).
Example:
>>> spec = PruningSpec()
>>> spec.target_sparsity
0.5
>>> spec = PruningSpec(target_sparsity=0.75)
"""
_compression_type: CompressionType = PrivateAttr(default=CompressionType.PRUNING)
target_sparsity: float = Field(default=0.5, ge=0.0, le=1.0)
pruning_scheme: Annotated[
PruningScheme,
BeforeValidator(PruningScheme.maybe_build_from_dict),
] = Field(default_factory=Unstructured)
pruning_algo: type[PruneImplBase] = Field(default="default", validate_default=True)
@staticmethod
def _convert_with_registry(data: str | type, registry_class: type[ClassRegistryMixin]) -> type:
"""Convert string or type to a registered class from the given registry.
Args:
data (str | type): Either a string key or a class type.
registry_class (type[ClassRegistryMixin]): The registry class to look up.
Returns:
type: The registered class type.
Raises:
ValueError: If the key is not found in registry or type is not registered.
TypeError: If *data* is neither string nor type.
"""
if isinstance(data, str):
try:
return registry_class.get_class(data) # type: ignore[no-any-return]
except KeyError as err:
available_keys = registry_class.list_registry_keys()
raise ValueError(
f"No class is registered with key: '{data}' "
f"in registry {registry_class.__name__}. "
f"Available keys: {sorted(available_keys)}"
) from err
elif isinstance(data, type):
if data in registry_class.list_registry_values():
return data
available_classes = [cls.__name__ for cls in registry_class.list_registry_values()]
raise ValueError(
f"Class {data.__name__} is not registered in "
f"{registry_class.__name__}. "
f"Available classes: {sorted(available_classes)}"
)
else:
raise TypeError(
f"Expected str or type for registry lookup, got {type(data).__name__}: {data}"
)
@field_validator("pruning_algo", mode="before")
@classmethod
def convert_pruning_algo(cls, data: Any) -> type[PruneImplBase]:
"""Resolve string keys to registered pruning implementation classes."""
return cls._convert_with_registry(data, PruneImplBase)
@model_validator(mode="before")
@classmethod
def _strip_computed_fields(cls, data: Any) -> Any:
"""Strip computed fields when deserializing from dict for round-trip support."""
if isinstance(data, dict):
declared = set(cls.model_fields.keys())
return {k: v for k, v in data.items() if k in declared}
return data
[docs]
def default_weight_pruning_spec() -> PruningSpec:
"""Return the default pruning spec for weight tensors."""
return PruningSpec(
target_sparsity=0.5,
pruning_scheme=Unstructured(),
pruning_algo="default",
)