Source code for coreai_opt.config.spec.compression_simulator

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Base class for differentiable compression simulators."""

from abc import abstractmethod

import torch
import torch.nn as nn

from coreai_opt._utils.registry_utils import ClassRegistryMixin


class CompressionSimulatorBase(ClassRegistryMixin, nn.Module):
    """
    Common interface shared by every compression simulator.

    A compression simulator receives a tensor and applies its compression
    technique to that tensor while still allowing the model to be evaluated.
    The interface is the same regardless of which compression technique a
    concrete simulator implements.

    Subclasses implement :meth:`forward` to define how the compression
    simulation is performed during training.
    """

    # NOTE(review): ``@abstractmethod`` only blocks instantiation when the
    # class's metaclass derives from ``abc.ABCMeta``. Whether that holds here
    # depends on ``ClassRegistryMixin``, which is defined elsewhere -- confirm.
    @abstractmethod
    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Simulate compression of ``tensor``.

        Implementations provide the differentiable approximation of the
        compression operation; the exact behavior depends on the specific
        compression technique. The base implementation has no body.

        Args:
            tensor: Input tensor to compress.

        Returns:
            The compressed tensor (or an approximation of it), with gradients
            flowing through.
        """