Source code for coreai_opt.pruning.spec.scheme

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Pruning scheme specifications."""

from __future__ import annotations

from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_serializer

from coreai_opt._utils.registry_utils import ConfigRegistryMixin


class PruningScheme(BaseModel, ConfigRegistryMixin):
    """Common parent for all pruning scheme specifications.

    A scheme describes the structural sparsity pattern imposed on a tensor;
    each subclass models one particular way of structuring the pruning.

    Attributes:
        axis (int | None): Axis along which structured pruning is applied,
            or ``None`` when pruning is unstructured (element-wise).
    """

    # Instances are frozen and fields not declared on the model are rejected.
    model_config = ConfigDict(frozen=True, extra="forbid")

    axis: int | None = Field(
        default=None,
        description="Axis along which structured pruning is applied. None for unstructured.",
    )

    @model_serializer
    def _serialize_model(self) -> dict[str, Any]:
        """Serialize declared fields and tag the result with the registry key.

        Returns:
            A dict mapping every declared field name to its current value.
            When the concrete class is found in ``PruningScheme.REGISTRY``,
            the key it is registered under is added as ``"type"``.
        """
        scheme_cls = type(self)
        payload: dict[str, Any] = {
            name: getattr(self, name) for name in scheme_cls.model_fields
        }

        # Reverse lookup: take the first registry key whose class is exactly
        # this instance's class (identity match, subclasses do not count).
        type_tag = next(
            (
                key
                for key, candidate in PruningScheme.REGISTRY.items()
                if candidate is scheme_cls
            ),
            None,
        )
        if type_tag is not None:
            payload["type"] = type_tag
        return payload
@PruningScheme.register("unstructured")
class Unstructured(PruningScheme):
    """Pruning scheme without any structural constraint.

    Every element is considered on its own, so any entry of the tensor may
    be zeroed independently of where it sits.
    """

    # Narrow the inherited field: the only accepted value for ``axis`` is None.
    axis: Literal[None] = None
@PruningScheme.register("channel_structured")
class ChannelStructured(PruningScheme):
    """Pruning scheme that operates on whole channels.

    A channel is a slice taken along ``axis``; each one is either pruned or
    kept as a unit. How channels are ranked is left to the pruning algorithm
    (for example, magnitude-based pruning may use each channel's L1 norm).
    """

    # Unlike the base class, ``axis`` is a required-int field defaulting to 0.
    axis: int = Field(
        default=0,
        description="Axis along which channels are pruned.",
    )