# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
from __future__ import annotations
import copy
import re
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import IO, Annotated, Any, Generic, Self, TypeAlias, TypeVar, overload
import torch
import torch.nn as nn
import yaml
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
PrivateAttr,
field_validator,
model_validator,
)
from coreai_opt._utils.config_utils import ConfigLevel as _ConfigLevel
from coreai_opt._utils.python_utils import (
fqn as _fqn,
get_generic_type_arg as _get_generic_type_arg,
)
from coreai_opt.config.spec import CompressionSpec
def _convert_none_to_empty_dict(v: Any):
    """Map ``None`` to a fresh empty dict; hand every other value back untouched.

    Args:
        v: Any raw field value.

    Returns:
        A new ``{}`` when ``v`` is ``None``, otherwise ``v`` itself (same object).
    """
    return {} if v is None else v
def _convert_digit_str_keys_to_int(v: Any):
    """Rewrite ASCII digit-string dict keys (``"0"``, ``"12"``) as ``int`` keys.

    Non-dict inputs are returned unchanged. For dicts, a new dict is built in the
    original iteration order; keys that are not ASCII digit strings are kept as-is.

    Args:
        v: Raw field value, typically a dict keyed by names and/or indices.

    Returns:
        ``v`` itself when it is not a dict, otherwise a new dict with converted keys.

    Raises:
        ValueError: If two distinct input keys end up equal after conversion
            (for example ``"1"`` and ``1``).
    """
    if not isinstance(v, dict):
        return v

    converted: dict[Any, Any] = {}
    # Remembers which input key produced each output key, for the error message.
    source_key_of: dict[Any, Any] = {}
    for raw_key, item in v.items():
        # isascii() rules out non-ASCII characters that str.isdigit() accepts.
        is_digit_str = isinstance(raw_key, str) and raw_key.isascii() and raw_key.isdigit()
        key = int(raw_key) if is_digit_str else raw_key
        if key in converted:
            raise ValueError(
                f"Key collision detected: keys {source_key_of[key]!r} and {raw_key!r} "
                f"both convert to {key!r}"
            )
        converted[key] = item
        source_key_of[key] = raw_key
    return converted
# Type variable for the compression spec type; bound to CompressionSpec.
_SpecT = TypeVar("_SpecT", bound=CompressionSpec)
[docs]
class OpCompressionConfig(BaseModel, ABC, Generic[_SpecT]):
    """
    Abstract base configuration class for op-level compression settings.

    This generic class defines the structure for configuring compression at the
    operation level. Subclasses must implement the default spec providers to define
    compression-specific default values. Parameterized by ``_SpecT``, the compression
    spec type (e.g., QuantizationSpec, PalettizationSpec).

    Attributes:
        op_input_spec (dict[str | int, _SpecT | None] | None): Compression specifications
            for operation inputs. Keys can be either all indices or all string names,
            but not a mix of both. The special key "*" can be used in both cases to
            refer to all inputs. Values are compression spec objects or None to
            disable compression.
        op_output_spec (dict[str | int, _SpecT | None] | None): Compression
            specifications for operation outputs. Keys can be either all indices or all
            string names, but not a mix of both. The special key "*" can be used in both
            cases to refer to all outputs. Values are compression spec objects or None
            to disable compression.
        op_state_spec (dict[str, _SpecT | None] | None): Compression specifications for
            operation state tensors (parameters, buffers, constants). Keys are string
            names (e.g., "weight", "bias") or "*" to refer to all state inputs.
            Values are compression spec objects or None to disable compression.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", validate_assignment=True
    )

    # The before-validators turn an explicit None into {} and coerce digit-string
    # keys (e.g. "0") into ints.
    op_input_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ]
    op_output_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ]
    # State specs are keyed by name only, so no digit-key coercion is applied.
    op_state_spec: Annotated[
        dict[str, _SpecT | None] | None, BeforeValidator(_convert_none_to_empty_dict)
    ]

    @classmethod
    @abstractmethod
    def get_default_input_spec(cls) -> dict[str | int, _SpecT | None]:
        """
        Provide default input spec for this compression type.

        Override in subclasses to define compression-specific defaults.
        Return empty dict if this compression type doesn't apply to inputs.

        This hook is called by ``apply_defaults`` (and by
        ``ModuleCompressionConfig.apply_defaults``) whenever ``op_input_spec`` is
        omitted, so every concrete subclass has to provide it.

        Returns:
            Dictionary mapping input identifiers to compression specs
        """

    @classmethod
    @abstractmethod
    def get_default_output_spec(cls) -> dict[str | int, _SpecT | None]:
        """
        Provide default output spec for this compression type.

        Override in subclasses to define compression-specific defaults.
        Return empty dict if this compression type doesn't apply to outputs.

        Returns:
            Dictionary mapping output identifiers to compression specs
        """

    @classmethod
    @abstractmethod
    def get_default_state_spec(cls) -> dict[str, _SpecT | None]:
        """
        Provide default state spec for this compression type.

        Override in subclasses to define compression-specific defaults.
        Return empty dict if this compression type doesn't apply to state.

        Returns:
            Dictionary mapping state tensor names to compression specs
        """

    @model_validator(mode="before")
    @classmethod
    def apply_defaults(cls, data: Any) -> Any:
        """
        Apply class-specific defaults before validation.

        This validator runs before field validation and applies the default specs
        defined by subclasses if the specs are not provided in the input data.

        Note: Explicitly passing None will result in an empty dict (via BeforeValidator)
        while omitting the field entirely will apply defaults.

        Args:
            data: Raw constructor/validation input.

        Returns:
            ``data`` unchanged when it is not a dict; otherwise a shallow copy with
            any missing spec fields filled in from the class defaults.
        """
        if not isinstance(data, dict):
            return data

        # Work on a copy so a dict handed to model_validate() is not modified
        # behind the caller's back.
        data = dict(data)
        if "op_input_spec" not in data:
            data["op_input_spec"] = cls.get_default_input_spec()
        if "op_output_spec" not in data:
            data["op_output_spec"] = cls.get_default_output_spec()
        if "op_state_spec" not in data:
            data["op_state_spec"] = cls.get_default_state_spec()
        return data


class WeightOnlyOpValidationMixin:
    """
    Mixin that adds weight-only validation to OpCompressionConfig subclasses.

    This mixin is for compression types that only apply to weights/state tensors
    and don't compress activations (inputs/outputs). Examples include palettization,
    pruning, and low-rank decomposition.

    This mixin:
    1. Provides default empty implementations for get_default_input_spec and
       get_default_output_spec (satisfying abstract methods)
    2. Adds a model_validator that rejects any op_input_spec or op_output_spec

    Note:
        The mixin MUST come first in the inheritance list due to Python's
        MRO and abstract method resolution.

    Example:
        >>> class MyOpConfig(WeightOnlyOpValidationMixin, OpCompressionConfig[MySpec]):
        ...     @classmethod
        ...     def get_default_state_spec(cls):
        ...         return {"weight": MySpec()}
    """

    @classmethod
    def get_default_input_spec(cls) -> dict:
        """Return empty dict as weight-only compression doesn't apply to inputs."""
        return {}

    @classmethod
    def get_default_output_spec(cls) -> dict:
        """Return empty dict as weight-only compression doesn't apply to outputs."""
        return {}

    @model_validator(mode="after")
    def validate_weight_only_op_constraint(self):
        """Ensure no input/output specs are set (weight-only compression).

        Returns:
            ``self`` when both ``op_input_spec`` and ``op_output_spec`` are empty.

        Raises:
            ValueError: If ``op_input_spec`` or ``op_output_spec`` is non-empty.
        """
        # The doubled braces survive the f-string so that {key} can be filled in
        # by str.format() below.
        error_msg = (
            f"{self.__class__.__name__} does not support {{key}}. "
            f"This is a weight-only compression type that only supports op_state_spec."
        )
        if self.op_input_spec:
            raise ValueError(error_msg.format(key="op_input_spec"))
        if self.op_output_spec:
            raise ValueError(error_msg.format(key="op_output_spec"))
        return self
# Type variable for the op-level config type; bound to OpCompressionConfig.
_OpConfigT = TypeVar("_OpConfigT", bound=OpCompressionConfig)
[docs]
class ModuleCompressionConfig(BaseModel, Generic[_OpConfigT, _SpecT]):
    """
    Abstract base configuration class for module-level compression settings.

    This generic class defines the structure for configuring compression at the
    module level. Subclasses must implement the default spec providers to define
    compression-specific default values. Parameterized by ``_OpConfigT`` (the op-level
    config type, e.g., OpQuantizerConfig) and ``_SpecT`` (the compression spec type,
    e.g., QuantizationSpec).

    Attributes:
        op_input_spec (dict[str | int, _SpecT | None] | None): Compression specifications
            for operation inputs applied to all registered operations / patterns within
            this module that don't have a more specific configuration.
        op_output_spec (dict[str | int, _SpecT | None] | None): Compression
            specifications for operation outputs applied to all registered operations /
            patterns within this module that don't have a more specific configuration.
        op_state_spec (dict[str, _SpecT | None] | None): Compression specifications for
            operation state tensors applied to all registered operations / patterns
            within this module that don't have a more specific configuration.
        op_type_config (dict[str, _OpConfigT | None] | None): Operation type-specific
            configurations. Keys are operation type names, values are op-level config
            objects or None to disable compression for that type.
        op_name_config (dict[str, _OpConfigT | None] | None): Operation name-specific
            configurations. Keys are operation name patterns (supports regex), values
            are op-level config objects or None to disable compression.
        module_input_spec (dict[str | int, _SpecT | None] | None): Compression
            specifications for module inputs. Module input settings treat the module
            as an opaque entity.
        module_output_spec (dict[str | int, _SpecT | None] | None): Compression
            specifications for module outputs. Module output settings treat the module
            as an opaque entity.
        module_state_spec (dict[str, _SpecT | None] | None): Compression specifications
            for module state tensors (parameters, buffers, constants). Module state
            settings will override op state settings for the same state tensors.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True, extra="forbid", validate_assignment=True
    )

    # Op-level specs have no static default; apply_defaults() fills them in from
    # the op config class when they are omitted.
    op_input_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ]
    op_output_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ]
    op_state_spec: Annotated[
        dict[str, _SpecT | None] | None, BeforeValidator(_convert_none_to_empty_dict)
    ]

    # Per-op overrides; None values are replaced by disabled op configs in
    # _normalize_none_op_configs().
    op_type_config: Annotated[
        dict[str, _OpConfigT | None] | None, BeforeValidator(_convert_none_to_empty_dict)
    ] = Field(default_factory=dict)
    op_name_config: Annotated[
        dict[str, _OpConfigT | None] | None, BeforeValidator(_convert_none_to_empty_dict)
    ] = Field(default_factory=dict)

    # Module-boundary specs default to empty.
    module_input_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ] = Field(default_factory=dict)
    module_output_spec: Annotated[
        dict[str | int, _SpecT | None] | None,
        BeforeValidator(_convert_none_to_empty_dict),
        BeforeValidator(_convert_digit_str_keys_to_int),
    ] = Field(default_factory=dict)
    module_state_spec: Annotated[
        dict[str, _SpecT | None] | None, BeforeValidator(_convert_none_to_empty_dict)
    ] = Field(default_factory=dict)

    @classmethod
    def _get_op_config_class(cls) -> type[_OpConfigT] | None:
        """
        Get the op config class from generic type parameters.

        Assumes the class is a direct subclass of
        ModuleCompressionConfig[_OpConfigT, _SpecT].

        Returns:
            The class bound to ``_OpConfigT``, or ``None`` when it cannot be resolved.
        """
        return _get_generic_type_arg(cls, ModuleCompressionConfig, arg_index=0)

    @model_validator(mode="before")
    @classmethod
    def apply_defaults(cls, data: Any) -> Any:
        """
        Apply class-specific defaults before validation.

        Gets default specs from the _OpConfigT type parameter's default methods.

        Note: Explicitly passing None will result in an empty dict (via BeforeValidator)
        while omitting the field entirely will apply defaults.

        Args:
            data: Raw constructor/validation input.

        Returns:
            ``data`` unchanged when it is not a dict or when the op config class
            cannot be resolved; otherwise a shallow copy with missing op-level
            spec fields filled in.
        """
        if not isinstance(data, dict):
            return data

        op_config_cls = cls._get_op_config_class()
        if op_config_cls is None:
            return data

        # Work on a copy so a dict handed to model_validate() is not modified
        # behind the caller's back.
        data = dict(data)
        if "op_input_spec" not in data:
            data["op_input_spec"] = op_config_cls.get_default_input_spec()
        if "op_output_spec" not in data:
            data["op_output_spec"] = op_config_cls.get_default_output_spec()
        if "op_state_spec" not in data:
            data["op_state_spec"] = op_config_cls.get_default_state_spec()
        return data

    @model_validator(mode="after")
    def _normalize_none_op_configs(self) -> ModuleCompressionConfig:
        """Replace None values in op_type_config/op_name_config with disabled configs.

        A None value means "disable compression for this op type/name". This
        validator normalises it to an explicit OpConfig with empty specs.

        The disabled op config is only built when a ``None`` entry is actually
        present, so validation of configs without ``None`` entries neither
        constructs a throwaway op config nor needs a resolvable op config class.

        Returns:
            ``self``, with every ``None`` entry replaced.

        Raises:
            RuntimeError: If a ``None`` entry has to be replaced but the op config
                class cannot be resolved from the generic type parameters.
        """
        disabled_op_cfg: _OpConfigT | None = None
        for cfg_name in ("op_type_config", "op_name_config"):
            cfg_dict = getattr(self, cfg_name)
            if not any(v is None for v in cfg_dict.values()):
                continue

            if disabled_op_cfg is None:
                op_config_cls = self._get_op_config_class()
                if op_config_cls is None:
                    raise RuntimeError(
                        f"Unable to generate empty op_config_cls for config type {type(self).__name__}."
                    )
                # Passing None for every spec yields empty dicts via the field
                # before-validators, i.e. nothing is compressed.
                disabled_op_cfg = op_config_cls(
                    op_input_spec=None,
                    op_output_spec=None,
                    op_state_spec=None,
                )

            setattr(
                self,
                cfg_name,
                {k: disabled_op_cfg if v is None else v for k, v in cfg_dict.items()},
            )
        return self

    def _get_compressor_specific_settings(self) -> dict[str, Any]:
        """
        Get compressor-specific settings, excluding base ModuleCompressionConfig fields.

        Returns only fields defined by the concrete subclass (e.g.,
        enable_fast_kmeans_mode, rounding_precision for palettization), not the base
        spec and config fields (op_input_spec, op_output_spec, op_state_spec,
        op_type_config, op_name_config, module_input_spec, module_output_spec,
        module_state_spec).

        This is useful when constructing arguments for compression operations that need
        the compression-specific settings but handle the spec fields separately.

        Returns:
            Dictionary of compressor-specific field names to their values.
        """
        base_field_names = set(ModuleCompressionConfig.model_fields.keys())
        return {k: v for k, v in self.model_dump().items() if k not in base_field_names}
[docs]
class WeightOnlyModuleValidationMixin:
    """
    Weight-only validation mixin for ModuleCompressionConfig subclasses.

    Intended for compression types that act only on weights/state tensors and
    never on activations (inputs/outputs).

    The mixin contributes a model_validator that rejects activation specs:
    - op_input_spec, op_output_spec (op-level activations)
    - module_input_spec, module_output_spec (module-level activations)

    Only op_state_spec and module_state_spec are allowed.

    Note:
        The mixin should come first in the inheritance list for proper MRO resolution.

    Example:
        >>> class MyModuleConfig(
        ...     WeightOnlyModuleValidationMixin,
        ...     ModuleCompressionConfig[MyOpConfig, MySpec]
        ... ):
        ...     pass
    """

    @model_validator(mode="after")
    def validate_weight_only_module_constraint(self):
        """Reject any non-empty activation spec (weight-only compression).

        Returns:
            ``self`` when all four activation spec fields are empty.

        Raises:
            ValueError: Naming the first non-empty activation spec field, checked
                in the order op_input_spec, op_output_spec, module_input_spec,
                module_output_spec.
        """
        # {{key}} is left as a literal {key} placeholder for str.format() below.
        template = (
            f"{self.__class__.__name__} does not support {{key}}. "
            f"This is a weight-only compression type that only supports "
            f"op_state_spec and module_state_spec."
        )
        activation_fields = (
            "op_input_spec",
            "op_output_spec",
            "module_input_spec",
            "module_output_spec",
        )
        for field_name in activation_fields:
            if getattr(self, field_name):
                raise ValueError(template.format(key=field_name))
        return self
# Type variable for the module-level config type; bound to ModuleCompressionConfig.
_T = TypeVar("_T", bound=ModuleCompressionConfig)
# Two-level mapping: config level -> (module name -> module config).
ModuleConfigDict: TypeAlias = dict[_ConfigLevel, dict[str, _T]]
def _build_module_alias_map(
    model: torch.nn.Module,
) -> tuple[dict[str, list[str]], dict[int, str]]:
    """Collect, for every module object in ``model``, all names it is registered under.

    The "canonical" name of a module is the first name under which
    ``named_modules(remove_duplicate=False)`` yields it. Every further name that
    yields the same object (same ``id``) is an "alias"; aliases arise when a model
    registers one submodule under more than one attribute (e.g. HuggingFace
    wrappers that hoist backbone children to the top level).

    Args:
        model: The model to inspect.

    Returns:
        Tuple of (canonical_to_aliases, id_to_canonical):

        - canonical_to_aliases maps each canonical name to a list holding the
          canonical name first, followed by its aliases in iteration order.
          Modules without aliases get a single-element list.
        - id_to_canonical maps ``id(module)`` to the canonical name.
    """
    names_by_canonical: dict[str, list[str]] = {}
    canonical_by_id: dict[int, str] = {}

    for registered_name, submodule in model.named_modules(remove_duplicate=False):
        key = id(submodule)
        if key in canonical_by_id:
            # Object already seen: record this name as an alias of the first one.
            names_by_canonical[canonical_by_id[key]].append(registered_name)
            continue
        # First sighting: this name becomes the canonical one.
        canonical_by_id[key] = registered_name
        names_by_canonical[registered_name] = [registered_name]

    return names_by_canonical, canonical_by_id
[docs]
class CompressionConfig(BaseModel, Generic[_T]):
"""
Top level configuration class for model compression.
This class manages compression configurations at different scopes:
- Global configuration (applies to all modules by default)
- Module type configurations (applies to all modules of specific type)
- Module name configurations (applies to module instances identified by name)
The configuration lookup follows a hierarchical precedence, where more specific
configurations override more general ones.
Generic type _T must be a subclass of ModuleCompressionConfig.
"""
# Pydantic settings: allow non-pydantic field types, reject unknown fields, and
# re-validate whenever a field is assigned.
model_config = ConfigDict(
arbitrary_types_allowed=True, extra="forbid", validate_assignment=True
)
# Fallback returned by get_module_config() when no name or type entry matches.
global_config: _T | None = None
# Per-module-type overrides; keys are normalised to fully-qualified name strings
# by the _normalize_module_type_configs field validator.
module_type_configs: Annotated[
dict[str | type[nn.Module], _T | None], BeforeValidator(_convert_none_to_empty_dict)
] = {}
# Per-module-name overrides; keys are matched against module names with
# re.fullmatch, so they act as regex patterns.
module_name_configs: Annotated[
dict[str, _T | None], BeforeValidator(_convert_none_to_empty_dict)
] = {}
# Guards only_for() against double-redistribution.
_global_config_disabled: bool = PrivateAttr(default=False)
@model_validator(mode="before")
@classmethod
def _apply_default_global_config(cls, data: Any) -> Any:
    """Apply default global_config based on the generic type parameter _T.

    A default-constructed module config is filled in only when the
    ``global_config`` key is absent; an explicit ``None`` is left alone.

    Args:
        data: Raw constructor/validation input.

    Returns:
        ``data`` unchanged when it is not a dict, already has ``global_config``,
        or the module config class cannot be resolved; otherwise a shallow copy
        with ``global_config`` added.
    """
    # Non-dict input cannot be extended by key, so leave it to regular validation
    # (same guard as the other before-validators in this file).
    if not isinstance(data, dict) or "global_config" in data:
        return data

    module_config_cls = cls._get_module_config_class()
    if module_config_cls is None:
        return data

    # Build a new dict rather than writing into the caller's one.
    return {**data, "global_config": module_config_cls()}
@classmethod
def _get_module_config_class(cls) -> type[_T] | None:
    """
    Resolve the module config class bound to the generic parameter ``_T``.

    Assumes the class is a direct subclass of CompressionConfig[_T].

    Returns:
        Whatever ``_get_generic_type_arg`` yields for the first type argument;
        callers in this class treat ``None`` as "could not be resolved".
    """
    module_config_cls = _get_generic_type_arg(cls, CompressionConfig, arg_index=0)
    return module_config_cls
@model_validator(mode="after")
def _normalize_none_module_configs(self) -> CompressionConfig:
    """Swap ``None`` in global/module_type/module_name configs for disabled configs.

    ``None`` means "disable compression for this scope". After this validator
    every such entry is an explicit ModuleCompressionConfig with empty specs, so
    downstream consumers never need to handle None.

    Returns:
        ``self`` with all ``None`` configs replaced.
    """
    # Record that the global config was disabled before the None is replaced.
    if self.global_config is None:
        self._global_config_disabled = True

    # Writes go through __dict__ on purpose: plain attribute assignment would run
    # validate_assignment, re-trigger the model validators and recurse forever.
    state = self.__dict__
    state["global_config"] = self._normalize_none_config(self.global_config)

    for field_name in ("module_type_configs", "module_name_configs"):
        configs = getattr(self, field_name)
        if all(cfg is not None for cfg in configs.values()):
            continue
        state[field_name] = {
            key: self._normalize_none_config(cfg) for key, cfg in configs.items()
        }
    return self
@model_validator(mode="after")
def _validate_global_config_restrictions(self) -> CompressionConfig:
    """Reject a global_config that carries any module-level spec.

    Returns:
        ``self`` when ``global_config`` is ``None`` or has no module-level specs.

    Raises:
        ValueError: If ``global_config`` has a non-empty module_input_spec,
            module_output_spec, or module_state_spec.
    """
    global_cfg = self.global_config
    if global_cfg is None:
        return self

    has_module_level_spec = (
        global_cfg.module_input_spec
        or global_cfg.module_output_spec
        or global_cfg.module_state_spec
    )
    if has_module_level_spec:
        raise ValueError(
            "global_config cannot have module_input_spec, "
            "module_output_spec, or module_state_spec. "
            "These are only allowed in module_type_configs and "
            "module_name_configs."
        )
    return self
@field_validator("module_type_configs", mode="after")
@classmethod
def _normalize_module_type_configs(cls, value):
    """Re-key ``module_type_configs`` by normalized module type name.

    Each key is passed through ``_normalize_module_type``; if two keys normalize
    to the same string, the later entry wins.

    Args:
        value: The validated ``module_type_configs`` mapping.

    Returns:
        A new dict keyed by the normalized type-name strings.
    """
    return {
        cls._normalize_module_type(module_type): config
        for module_type, config in value.items()
    }
@staticmethod
def _normalize_module_type(module_type: str | type[nn.Module]) -> str:
    """
    Turn a module type key into its fully-qualified name string.

    Args:
        module_type: Either a fully-qualified name string or a PyTorch module type

    Returns:
        str: The string itself when one is given, otherwise ``_fqn(module_type)``.

    Raises:
        ValueError: If a string without any "." is given.
        TypeError: If the key is neither a string nor an ``nn.Module`` subclass.
    """
    if isinstance(module_type, str):
        # A dotted path is the only check applied to string keys.
        if "." in module_type:
            return module_type
        raise ValueError(f"Expected fully-qualified name, got {module_type}")

    is_module_class = isinstance(module_type, type) and issubclass(module_type, nn.Module)
    if not is_module_class:
        raise TypeError(
            f"Keys in module_type_configs must be either fully-qualified "
            f"strings or nn.Module types, got {type(module_type)}"
        )
    return _fqn(module_type)
def _normalize_none_config(self, config: _T | None) -> _T:
    """Return ``config`` as-is, or a disabled ModuleCompressionConfig when it is None.

    Args:
        config: A module config, or ``None`` meaning "compression disabled".

    Returns:
        ``config`` itself when not ``None``; otherwise a new instance of the
        module config class with every spec field passed as ``None``.

    Raises:
        RuntimeError: If ``config`` is ``None`` and the module config class
            cannot be resolved from the generic type parameters.
    """
    if config is None:
        module_config_cls = self._get_module_config_class()
        if module_config_cls is None:
            raise RuntimeError(
                f"Unable to generate empty module_config_cls for config type {type(self).__name__}."
            )
        spec_fields = (
            "op_input_spec",
            "op_output_spec",
            "op_state_spec",
            "module_input_spec",
            "module_output_spec",
            "module_state_spec",
        )
        return module_config_cls(**dict.fromkeys(spec_fields, None))
    return config
[docs]
def set_global(self, config: _T | None) -> Self:
    """Set the global config.

    Accepts a ``ModuleCompressionConfig`` (the canonical form) or ``None``
    to disable compression globally.

    Args:
        config: The new global config, or ``None`` to disable it.

    Returns:
        Self: ``self``, for chaining.

    Raises:
        ValueError: If the assignment is rejected by validation, e.g. because
            ``config`` carries module-level specs. In that case neither
            ``global_config`` nor the disabled flag is changed.
    """
    # Assign first: the assignment is validated (validate_assignment=True) and may
    # raise. Updating the flag only afterwards keeps it in sync with the config
    # that is actually stored.
    self.global_config = self._normalize_none_config(config)
    self._global_config_disabled = config is None
    return self
[docs]
def set_module_type(
    self,
    module_type: str | type[nn.Module],
    config: _T | None,
) -> Self:
    """Register the module-level compression config for a module type.

    An existing entry for the same (normalized) module type is overwritten.
    ``None`` is stored as a disabled config.

    Args:
        module_type: Fully-qualified type name or ``nn.Module`` subclass.
        config: Config to use for that type, or ``None`` to disable it.

    Returns:
        Self: ``self``, for chaining.
    """
    type_key = self._normalize_module_type(module_type)
    resolved_config = self._normalize_none_config(config)
    # The dict is updated in place rather than re-assigned as a field.
    self.module_type_configs[type_key] = resolved_config
    return self
[docs]
def set_module_name(self, module_name: str, config: _T | None) -> Self:
    """Register the module-level compression config for a module instance name.

    An existing entry for the same name is overwritten. ``None`` is stored as a
    disabled config.

    Args:
        module_name: Key to store the config under in ``module_name_configs``.
        config: Config to use for that name, or ``None`` to disable it.

    Returns:
        Self: ``self``, for chaining.
    """
    resolved_config = self._normalize_none_config(config)
    # The dict is updated in place rather than re-assigned as a field.
    self.module_name_configs[module_name] = resolved_config
    return self
@overload
def only_for(
    self,
    targets: list[str | type[nn.Module]] | tuple[str | type[nn.Module], ...],
    /,
) -> Self: ...

@overload
def only_for(self, *targets: str | type[nn.Module]) -> Self: ...

def only_for(self, *targets: Any) -> Self:
    """Restrict this config to apply only to the given module types/names.

    Disables ``global_config`` and re-applies it as a deep-copied per-module
    override on each listed target. Targets may be ``nn.Module`` subclasses
    or module-name strings, mixed in the same call and passed either as
    varargs or as a single list/tuple. All targets are validated before any
    mutation happens.

    Args:
        *targets: One or more ``nn.Module`` subclasses or module name
            strings, passed as varargs or a single list/tuple.

    Returns:
        Self: ``self``, for chaining.

    Raises:
        ValueError: If no targets are provided, or if ``global_config`` is
            already disabled. Pass all targets in one call instead of
            chaining ``only_for``.
        TypeError: If a target is neither an ``nn.Module`` subclass nor a
            string.

    Note:
        If a target already has an explicit override (via ``set_module_type``
        or ``set_module_name``), ``only_for`` overwrites it with the former
        global config. To keep per-target customizations, call ``only_for``
        first and ``set_module_type`` / ``set_module_name`` after.

        The ``ValueError`` raised when ``only_for`` is called twice (see
        ``Raises``) uses a private attribute that is excluded from
        ``model_dump`` / ``to_yaml``, so a round-tripped config will
        accept ``only_for`` again (functionally a no-op).

    Example:
        >>> config = QuantizerConfig.presets.w8().only_for(nn.Linear, nn.Conv2d)
        >>> config = QuantizerConfig.presets.w8().only_for([nn.Linear, "lm_head"])
    """
    resolved_targets = self._unpack_targets(targets)
    if not resolved_targets:
        raise ValueError("only_for requires at least one target")

    if self._is_global_config_disabled():
        raise ValueError(
            "only_for requires a non-disabled global_config to "
            "redistribute as per-module overrides. If you've already "
            "called only_for or set_global(None), pass all targets in "
            "one only_for(...) call instead of chaining."
        )

    # Validate everything up front so a bad target cannot leave the config
    # half-updated.
    for candidate in resolved_targets:
        self._validate_target(candidate)

    # Keep the current global config, then disable the global slot.
    former_global = self.global_config
    self._global_config_disabled = True
    self.global_config = self._normalize_none_config(None)

    # Each target gets its own deep copy so the overrides do not share state.
    for candidate in resolved_targets:
        self._apply_target_config(candidate, copy.deepcopy(former_global))
    return self
@overload
def without(
    self,
    targets: list[str | type[nn.Module]] | tuple[str | type[nn.Module], ...],
    /,
) -> Self: ...

@overload
def without(self, *targets: str | type[nn.Module]) -> Self: ...

def without(self, *targets: Any) -> Self:
    """Exclude the given module types/names from this config.

    Each target gets a per-module override of ``None`` (disabled). The
    global config and other overrides are unchanged. Targets may be
    ``nn.Module`` subclasses or module-name strings, mixed in the same call
    and passed either as varargs or as a single list/tuple. All targets are
    validated before any mutation happens. Passing no targets (or an empty
    list) is a no-op.

    Args:
        *targets: Zero or more ``nn.Module`` subclasses or module name
            strings, passed as varargs or a single list/tuple.

    Returns:
        Self: ``self``, for chaining.

    Raises:
        TypeError: If a target is neither an ``nn.Module`` subclass nor a
            string.

    Example:
        >>> config = QuantizerConfig.presets.w4().without(nn.LayerNorm)
        >>> config = QuantizerConfig.presets.w4().without([nn.LayerNorm, "lm_head"])
    """
    excluded = self._unpack_targets(targets)

    # First pass only validates, so a bad target cannot leave earlier targets
    # already disabled.
    for candidate in excluded:
        self._validate_target(candidate)

    # Second pass applies the disabling override.
    for candidate in excluded:
        self._apply_target_config(candidate, None)
    return self
def _apply_target_config(self, target: str | type[nn.Module], config: _T | None) -> None:
    """Store ``config`` for ``target`` via the setter matching the target's kind.

    Classes go to ``set_module_type``; everything else goes to ``set_module_name``.

    Args:
        target: Module class or module name.
        config: Config to store, or ``None`` to disable compression for the target.
    """
    setter = self.set_module_type if isinstance(target, type) else self.set_module_name
    setter(target, config)
@staticmethod
def _unpack_targets(targets: tuple[Any, ...]) -> tuple[Any, ...]:
    """Flatten a lone list/tuple argument into a tuple of its elements.

    Only a single ``list`` or ``tuple`` argument is unwrapped. A bare string is
    left alone even though it is iterable, because it is a valid single target.

    Args:
        targets: The varargs tuple received by the caller.

    Returns:
        ``tuple(targets[0])`` when ``targets`` holds exactly one list/tuple,
        otherwise ``targets`` unchanged.
    """
    if len(targets) != 1:
        return targets
    (only_arg,) = targets
    return tuple(only_arg) if isinstance(only_arg, (list, tuple)) else targets
@staticmethod
def _validate_target(target: Any) -> None:
    """Check that ``target`` is a string or an ``nn.Module`` subclass.

    Args:
        target: Candidate target passed to ``only_for`` / ``without``.

    Raises:
        TypeError: If ``target`` is a class that does not derive from
            ``nn.Module``, or is neither a class nor a string.
    """
    if isinstance(target, str):
        return
    if not isinstance(target, type):
        raise TypeError(
            f"targets must be module types or name strings, got {type(target).__name__}"
        )
    if not issubclass(target, nn.Module):
        raise TypeError(
            f"targets must be nn.Module subclasses or name strings, got {target.__name__}"
        )
def _is_global_config_disabled(self) -> bool:
    """Report whether ``global_config`` is currently flagged as disabled.

    Returns:
        The value of the private ``_global_config_disabled`` flag.
    """
    disabled = self._global_config_disabled
    return disabled
[docs]
def get_module_config(self, name: str, module: torch.nn.Module) -> _T:
    """
    Look up the compression config for one module, most specific scope first.

    Lookup order:
    1. Module name match (keys are regex patterns, matched with ``re.fullmatch``)
    2. Module type match (fully-qualified name of ``type(module)``)
    3. Global config

    NOTE(review): among several matching name patterns the first one in
    insertion order is returned, whereas ``build_module_config_dict`` documents
    that the later-defined config wins. Confirm this difference is intended.

    Args:
        name (str): Name of module to get config for
        module(torch.nn.Module): Module to get config for

    Returns:
        _T: Module config for the given module.
    """
    # 1. name patterns
    for pattern, named_config in self.module_name_configs.items():
        if re.fullmatch(pattern, name):
            return named_config

    # 2. exact module type, 3. global fallback
    type_key = _fqn(type(module))
    return self.module_type_configs.get(type_key, self.global_config)
[docs]
def build_module_config_dict(self, model: torch.nn.Module) -> ModuleConfigDict[_T]:
    """
    Build a mapping of module names to their quantization configurations,
    separating modules by config level.

    The modules are associated with configs according to the following rules:

    - If a module is already associated with a config of a higher priority, the
      lower priority config is ignored (module_name > module_type > global)
    - A module may match with more than one config within the same config level.
      For example, a nested module named model.outer.inner could be associated
      with module name level configs for "model.outer.inner" as well as
      "model.outer.*".
      Whichever config is defined later in the config list is
      the one which ends up being used.
    - When a module is matched with a config, the config is applied to all of
      its child modules recursively, subject to the above priority constraint.

    For example, assume a nested module model.outer.inner.module and the
    following config::

        module_name: {"model.outer.inner": config1},
        module_type: {type(model): config2}

    Then we would have the following associations:

    - "model": config2 (set with module_type config)
    - "model.outer": config2 (recursively set when setting "model")
    - "model.outer.inner": config1 (higher priority module_name match)
    - "model.outer.inner.module": config1 (set as part of recursively
      processing "model.outer.inner"'s child modules)

    Args:
        model: The model with modules to get configs for

    Returns:
        Dictionary with nested dictionary mapping modules to configs for each config
        level
    """
    configs_by_level: ModuleConfigDict[_T] = {
        _ConfigLevel.MODULE_NAME: {},
        _ConfigLevel.MODULE_TYPE: {},
        _ConfigLevel.GLOBAL: {},
    }
    modules_by_name = dict(model.named_modules())
    # One pass yields both the alias map and the id -> canonical-name lookup.
    canonical_to_aliases, id_to_canonical = _build_module_alias_map(model)
    # Shared across all levels and handed to _set_config_for_module.
    seen_module_ids: set[int] = set()

    def _name_matches(pattern: str, candidate: str, _module: nn.Module):
        """Match a module name against a user-supplied regex pattern."""
        return re.fullmatch(pattern, candidate)

    def _type_matches(type_fqn: str, _candidate: str, module: nn.Module):
        """Match a module by the fully-qualified name of its type."""
        return _fqn(type(module)) == type_fqn

    def _always_matches(*_):
        """Match every module (used for the global config)."""
        return True

    def _apply_configs(
        identifier_and_config: Iterable[tuple[str, _T]],
        config_level: _ConfigLevel,
        matcher_fn: Callable[[str, str, nn.Module], bool],
    ) -> None:
        """Apply configs to modules matching the given criteria."""
        for identifier, config in identifier_and_config:
            # Walk the modules in reverse so nested children are handled before
            # their parents. That only matters for module-level settings, which
            # are not passed on to children; op-level settings are inherited, so
            # for them the order would not matter.
            for name, module in reversed(modules_by_name.items()):
                # Test the canonical name and every alias, so a regex or type
                # spec aimed at an alias path still lands on the canonical name.
                candidate_names = canonical_to_aliases.get(name, [name])
                if not any(matcher_fn(identifier, n, module) for n in candidate_names):
                    continue
                self._set_config_for_module(
                    name,
                    module,
                    configs_by_level,
                    config,
                    config_level,
                    seen_module_ids,
                    id_to_canonical,
                )

    # Highest priority first. Entries are reversed so later-defined configs are
    # processed before earlier ones.
    _apply_configs(
        reversed(self.module_name_configs.items()),
        _ConfigLevel.MODULE_NAME,
        _name_matches,
    )
    _apply_configs(
        reversed(self.module_type_configs.items()),
        _ConfigLevel.MODULE_TYPE,
        _type_matches,
    )
    # The global config is offered to every module; the identifier is unused.
    _apply_configs([("", self.global_config)], _ConfigLevel.GLOBAL, _always_matches)
    return configs_by_level
def _prepare_config_for_child(self, parent_config: _T) -> _T:
    """
    Derive the config that child modules inherit from ``parent_config``.

    Everything is propagated to children EXCEPT module-level settings:

    - Propagated: op-level settings (op_input_spec, op_output_spec,
      op_state_spec, op_type_config, op_name_config) and compression-specific
      properties.
    - Not propagated: module_input_spec, module_output_spec and
      module_state_spec, as they are specific to the parent module boundary.

    Args:
        parent_config: The config from the parent module

    Returns:
        A deep copy of ``parent_config`` with the three module-level spec
        fields assigned ``None``.
    """
    inherited = copy.deepcopy(parent_config)
    # Each field is assigned on its own, in this order. Presumably the field
    # validators normalise the None to an empty dict on assignment; verify
    # against the concrete config class.
    for field_name in ("module_input_spec", "module_output_spec", "module_state_spec"):
        setattr(inherited, field_name, None)
    return inherited
def _set_config_for_module(
    self,
    module_name: str,
    module: torch.nn.Module,
    module_config_dict: ModuleConfigDict[_T],
    config: _T,
    config_level: _ConfigLevel,
    visited: set[int],
    id_to_canonical: dict[int, str],
):
    """
    Record ``config`` for a module and, recursively, for its children.

    The entry is written to ``module_config_dict[config_level]`` under the
    module's canonical name. Nothing is written (and the children are not
    descended into) when either:

    - the module object was already processed, i.e. ``id(module)`` is in
      ``visited``; or
    - the canonical name is already present in ``module_config_dict`` for
      any level yielded by ``_ConfigLevel.priority_order()`` up to and
      including ``config_level``.

    Children receive the config returned by ``_prepare_config_for_child``
    rather than ``config`` itself.

    Args:
        module_name: Name under which the module was reached; used only
            when ``id_to_canonical`` has no entry for the module.
        module: The module being configured.
        module_config_dict: Per-level mapping of module name to config;
            updated in place.
        config: The config to assign to ``module``.
        config_level: The level under which the config is stored.
        visited: Ids of module objects already processed; updated in place
            to prevent infinite recursion and duplicate entries.
        id_to_canonical: Maps ``id(module)`` to the module's canonical name
            so that keys stay canonical even when a shared module is
            reached through an alias path via ``named_children()``.
    """
    key = id(module)
    if key in visited:
        return
    visited.add(key)

    # Prefer the canonical name; fall back to the name we were called with.
    canonical_name = id_to_canonical.get(key, module_name)

    # Walk the levels in priority order and stop once the current level has
    # been inspected: an existing entry at any of those levels wins.
    for level in _ConfigLevel.priority_order():
        if canonical_name in module_config_dict[level]:
            return
        if level == config_level:
            break

    module_config_dict[config_level][canonical_name] = config

    inherited = self._prepare_config_for_child(config)
    # Child paths are rooted at the canonical name; a module with an empty
    # name contributes no prefix.
    prefix = f"{canonical_name}." if canonical_name else ""
    for child_name, child_module in module.named_children():
        self._set_config_for_module(
            module_name=prefix + child_name,
            module=child_module,
            module_config_dict=module_config_dict,
            config=inherited,
            config_level=config_level,
            visited=visited,
            id_to_canonical=id_to_canonical,
        )
[docs]
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> CompressionConfig | None:
    """
    Create configuration from a dictionary.

    The dictionary must contain the ``_CONFIG_KEY`` key whose value defines
    the hierarchical configuration structure specifying which compression
    specs to apply at different scopes (global, module type, module name)
    and levels (op-level or module-level).

    All compression specifications should be inline dictionaries (not
    references). If this dictionary was created from YAML using
    ``from_yaml()``, any YAML anchor/alias references have already been
    resolved by the YAML parser and substituted with the actual spec
    dictionaries.

    This method requires subclasses to define the ``_CONFIG_KEY`` and
    ``_SPEC_KEY`` class attributes. It will:

    1. Validate that only expected keys are present
    2. Extract the nested config from ``config_dict[_CONFIG_KEY]``
    3. Create the config from the extracted data

    Args:
        config_dict (:obj:`dict` of :obj:`str` and values): A nested
            dictionary of strings and values. Must have the key specified
            by ``_CONFIG_KEY`` with the actual config nested inside.

    Returns:
        CompressionConfig: The created configuration object, or ``None``
        if ``config_dict`` is ``None``.

    Raises:
        RuntimeError: If:
            - Subclass doesn't define ``_CONFIG_KEY`` or ``_SPEC_KEY``
            - Unexpected keys are found in ``config_dict``
            - The required ``_CONFIG_KEY`` is not present in ``config_dict``

    Example:
        >>> # Define a sample spec class for this compression type
        >>> class MySpec(CompressionSpec):
        ...     _compression_type = CompressionType.QUANTIZATION
        ...     foo: str = "default"  # String property
        ...     bar: float = 1.0  # Float property
        >>>
        >>> # Subclass with config key pattern
        >>> class MyConfig(CompressionConfig):
        ...     _CONFIG_KEY: ClassVar[str] = "my_config"
        ...     _SPEC_KEY: ClassVar[str] = "my_spec"
        >>>
        >>> # Create configuration dictionary
        >>> # (Equivalent to a parsed YAML file with all anchors/aliases resolved)
        >>> # Note: Specs are inlined as dictionaries, not as MySpec objects
        >>> config_dict = {
        ...     'my_config': {
        ...         'global_config': {
        ...             'op_input_spec': {
        ...                 "*": {
        ...                     'foo': 'custom_1',
        ...                     'bar': 1.0
        ...                 }
        ...             },
        ...             'op_state_spec': {
        ...                 'weight': {
        ...                     'foo': 'custom_2',
        ...                     'bar': 0.5
        ...                 }
        ...             }
        ...         },
        ...         'module_type_configs': {
        ...             'torch.nn.modules.linear.Linear': {
        ...                 'op_state_spec': {
        ...                     'weight': {
        ...                         'foo': 'custom_3',
        ...                         'bar': 0.25
        ...                     }
        ...                 }
        ...             }
        ...         },
        ...         'module_name_configs': {
        ...             'model.encoder.layer1': {
        ...                 'module_output_spec': {
        ...                     "*": {
        ...                         'foo': 'default',
        ...                         'bar': 2.0
        ...                     }
        ...                 }
        ...             }
        ...         }
        ...     }
        ... }
        >>>
        >>> # Load the configuration
        >>> config = MyConfig.from_dict(config_dict)
        >>> # Access global config and its specs
        >>> print(config.global_config.op_state_spec["weight"].bar)
        0.5
    """
    # A subclass that forgot its key attributes is a programming error and is
    # reported first, regardless of the input.
    config_key = getattr(cls, "_CONFIG_KEY", None)
    spec_key = getattr(cls, "_SPEC_KEY", None)
    if config_key is None:
        raise RuntimeError(f"{cls.__name__} must define _CONFIG_KEY class attribute. ")
    if spec_key is None:
        raise RuntimeError(f"{cls.__name__} must define _SPEC_KEY class attribute. ")

    # Honor the documented contract: no input dictionary means no config
    # object, instead of failing with a TypeError while iterating None.
    if config_dict is None:
        return None

    expected_keys = {config_key, spec_key}

    # Reject the first top-level key that is neither the config nor the spec key.
    for key in config_dict:
        if key not in expected_keys:
            error_msg = (
                f"Found unexpected key '{key}' in config dict. "
                f"Supported keys are {expected_keys}."
            )
            raise RuntimeError(error_msg)

    # The spec key is optional, the config key is mandatory.
    if config_key not in config_dict:
        error_msg = (
            f"Did not find '{config_key}' in config dict. Expected keys: {expected_keys}."
        )
        raise RuntimeError(error_msg)

    # Unwrap the nested config and hand its entries to the constructor.
    return cls(**config_dict[config_key])
[docs]
@classmethod
def from_yaml(cls, yml: IO | str | Path) -> CompressionConfig | None:
    """
    Create configuration from a YAML file or stream.

    A ``str`` or ``Path`` argument is treated as a file path whose bytes are
    read and parsed; any other argument is passed to the YAML parser as a
    stream. The parsed mapping is then handed to ``from_dict``.

    Args:
        yml: File path or IO stream containing YAML data

    Returns:
        A CompressionConfig instance, or None (after emitting a warning) if
        the YAML content was empty

    Raises:
        ValueError: If the YAML content is not a dictionary
    """
    # Paths are read eagerly; streams are parsed as-is.
    yaml_source = Path(yml).read_bytes() if isinstance(yml, (str, Path)) else yml
    config_data = yaml.safe_load(yaml_source)

    if config_data is None:
        # An empty document parses to None; report it instead of failing.
        warnings.warn(
            "Empty YAML content detected, returning None instead of a configuration object",
            stacklevel=2,
        )
        return None

    if isinstance(config_data, dict):
        return cls.from_dict(config_data)

    # Any non-mapping top-level value (list, scalar, ...) is rejected.
    raise ValueError(f"Invalid YAML: expected dict, got {type(config_data)}.")
[docs]
def to_dict(self) -> dict[str, Any]:
    """
    Convert configuration to dictionary.

    Returns:
        A single-entry dictionary mapping the ``_CONFIG_KEY`` value to the
        result of ``model_dump()``.

    Raises:
        RuntimeError: If ``_CONFIG_KEY`` is missing or ``None``.
    """
    wrapper_key = getattr(self, "_CONFIG_KEY", None)
    if wrapper_key is not None:
        # Nest the dumped fields under the config key.
        return {wrapper_key: self.model_dump()}
    raise RuntimeError(
        f"{self.__class__.__name__} must define _CONFIG_KEY class attribute. "
    )