coreai_opt.pruning.spec.PruneImplBase
- class coreai_opt.pruning.spec.PruneImplBase(target_sparsity, pruning_scheme, **kwargs)
Bases: CompressionSimulatorBase
Abstract base for pruning parametrizations that mask a layer’s weight.
Subclasses implement compute_mask(), a pure static function from (weight, sparsity, pruning_scheme) to a binary mask. The base class handles the mask buffer and optional schedule-driven sparsity updates.
- Parameters:
target_sparsity (float)
pruning_scheme (PruningScheme)
kwargs (Any)
- __init__(target_sparsity, pruning_scheme, **kwargs)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- Parameters:
target_sparsity (float)
pruning_scheme (PruningScheme)
kwargs (Any)
Methods
compute_mask(weight, sparsity, pruning_scheme) – Compute a binary pruning mask for the given weight tensor.
forward(weight) – Compute or recompute the mask if it is stale, then apply it to the weight.
get_class(key)
list_registry_keys()
list_registry_values()
register(key) – Register a virtual subclass of an ABC.
update_sparsity(step_count) – Update the sparsity based on the configured schedule and the provided step count.
with_args(**kwargs) – Create a partial constructor with pre-filled arguments.
- abstract static compute_mask(weight, sparsity, pruning_scheme)
Compute a binary pruning mask for the given weight tensor.
- Parameters:
weight (torch.Tensor) – The weight tensor to compute a mask for.
sparsity (float) – Fraction of elements to prune, in [0, 1].
pruning_scheme (PruningScheme) – Structural pattern of sparsity.
- Returns:
Binary mask with the same shape as weight (1 = keep, 0 = prune).
- Return type:
torch.Tensor
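As an illustration of the contract above, here is a minimal sketch of one possible compute_mask body: unstructured magnitude pruning, which zeroes the smallest-magnitude fraction of the weights. The function name magnitude_mask is hypothetical and not part of coreai_opt. The members of PruningScheme are not listed on this page, so the sketch accepts the argument only to mirror the documented signature and ignores it.

```python
import torch


def magnitude_mask(weight: torch.Tensor, sparsity: float, pruning_scheme=None) -> torch.Tensor:
    """Return a 1/0 mask that drops the smallest-|w| fraction of `weight`.

    Assumption: `pruning_scheme` is ignored here; only an unstructured
    pattern is sketched because the real PruningScheme values are unknown.
    """
    if not 0.0 <= sparsity <= 1.0:
        raise ValueError("sparsity must be in [0, 1]")

    num_prune = int(round(sparsity * weight.numel()))
    # Start from "keep everything" on a flat view of the weight.
    flat_mask = torch.ones(weight.numel(), dtype=weight.dtype, device=weight.device)
    if num_prune > 0:
        # Indices of the num_prune smallest magnitudes in the flattened weight.
        _, prune_idx = torch.topk(weight.abs().flatten(), num_prune, largest=False)
        flat_mask[prune_idx] = 0
    return flat_mask.reshape(weight.shape)


weight = torch.tensor([[0.1, -2.0, 0.3],
                       [4.0, -0.05, 1.0]])
mask = magnitude_mask(weight, sparsity=0.5)
```

With sparsity=0.5 the three smallest magnitudes (0.05, 0.1, 0.3) are pruned, so the mask keeps only -2.0, 4.0 and 1.0. The function reads nothing but its arguments, which matches the "pure static function" requirement stated for compute_mask.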
- forward(weight)
Compute / re-compute the mask if stale, and then apply it to the weight.
- Parameters:
weight (Tensor)
- Return type:
Tensor
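The "recompute if stale, then apply" behaviour can be sketched with a small stand-in module. Everything here is an assumption made for illustration: the class ToyMaskedWeight, its staleness flag, its set_sparsity helper, and the use of torch.nn.utils.parametrize to attach it to a layer are not taken from coreai_opt, whose internals are not shown on this page.

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


class ToyMaskedWeight(nn.Module):
    """Hypothetical parametrization that caches a mask and applies it."""

    def __init__(self, sparsity: float) -> None:
        super().__init__()
        self.sparsity = sparsity
        self.register_buffer("mask", None)  # filled in on the first forward call
        self.mask_is_stale = True
        self.mask_computations = 0  # counter used only to make caching visible

    def set_sparsity(self, sparsity: float) -> None:
        # Changing the sparsity invalidates the cached mask.
        self.sparsity = sparsity
        self.mask_is_stale = True

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if self.mask_is_stale or self.mask is None:
            num_prune = int(round(self.sparsity * weight.numel()))
            flat_mask = torch.ones(weight.numel(), dtype=weight.dtype, device=weight.device)
            if num_prune > 0:
                # Prune the smallest-magnitude entries (unstructured).
                _, idx = torch.topk(weight.detach().abs().flatten(), num_prune, largest=False)
                flat_mask[idx] = 0
            self.mask = flat_mask.reshape(weight.shape)
            self.mask_is_stale = False
            self.mask_computations += 1
        # Apply the cached mask; pruned entries become zero.
        return weight * self.mask


torch.manual_seed(0)
layer = nn.Linear(4, 4)
param = ToyMaskedWeight(sparsity=0.5)
parametrize.register_parametrization(layer, "weight", param)
masked_weight = layer.weight  # each access runs forward() on the original weight
```

In this sketch, reading layer.weight repeatedly reuses the cached mask, and only a call to set_sparsity forces the next forward pass to rebuild it.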
- update_sparsity(step_count)
Update the sparsity based on the configured schedule and the provided step count.
- Raises:
RuntimeError – If no schedule is attached. This method should be invoked only after setting the schedule property.
- Parameters:
step_count (int)
- Return type:
None
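The following sketch mirrors only the documented contract of update_sparsity: it raises RuntimeError when no schedule is attached, and otherwise sets the sparsity from the step count. The ToyPruner class and linear_schedule helper are hypothetical. The real SparsityScheduleBase interface is not shown on this page, so the schedule is modelled here as a plain callable from step count to sparsity, and starting at target_sparsity before any update is likewise an assumption.

```python
from typing import Callable, Optional


def linear_schedule(begin_step: int, end_step: int, final_sparsity: float) -> Callable[[int], float]:
    """Hypothetical schedule: ramp sparsity linearly from 0 to final_sparsity."""

    def sparsity_at(step_count: int) -> float:
        if step_count <= begin_step:
            return 0.0
        if step_count >= end_step:
            return final_sparsity
        progress = (step_count - begin_step) / (end_step - begin_step)
        return final_sparsity * progress

    return sparsity_at


class ToyPruner:
    """Hypothetical stand-in, not the coreai_opt class."""

    def __init__(self, target_sparsity: float) -> None:
        self.schedule: Optional[Callable[[int], float]] = None
        self._sparsity = target_sparsity  # assumption: start at the target
        self.mask_is_stale = True

    @property
    def sparsity(self) -> float:
        return self._sparsity

    def update_sparsity(self, step_count: int) -> None:
        if self.schedule is None:
            raise RuntimeError("no schedule attached; set `schedule` first")
        new_sparsity = self.schedule(step_count)
        if new_sparsity != self._sparsity:
            self._sparsity = new_sparsity
            self.mask_is_stale = True  # the mask must be rebuilt on next use


pruner = ToyPruner(target_sparsity=0.8)
pruner.schedule = linear_schedule(begin_step=0, end_step=100, final_sparsity=0.8)
pruner.update_sparsity(step_count=50)  # halfway through the ramp
```

A training loop would typically call update_sparsity once per optimizer step, so that the mask tightens gradually instead of jumping straight to the target.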
- classmethod with_args(**kwargs)
Create a partial constructor with pre-filled arguments.
- Parameters:
kwargs (Any)
- Return type:
PartialConstructor[PruneImplBase]
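The idea of a partial constructor can be illustrated with functools.partial from the standard library. This is only an analogy: the PartialConstructor type returned by the real with_args is not documented on this page and may behave differently. ToyPruneImpl and the string "unstructured" are hypothetical placeholders.

```python
import functools
from typing import Any


class ToyPruneImpl:
    """Hypothetical stand-in used to show the with_args pattern."""

    def __init__(self, target_sparsity: float, pruning_scheme: str, **kwargs: Any) -> None:
        self.target_sparsity = target_sparsity
        self.pruning_scheme = pruning_scheme
        self.extra = kwargs

    @classmethod
    def with_args(cls, **kwargs: Any) -> "functools.partial[ToyPruneImpl]":
        # Pre-fill some constructor arguments; the rest are supplied later.
        return functools.partial(cls, **kwargs)


# Fix the sparsity once, then build a fresh instance per layer.
factory = ToyPruneImpl.with_args(target_sparsity=0.5)
impl_a = factory(pruning_scheme="unstructured")
impl_b = factory(pruning_scheme="unstructured")
```

Each call to the factory produces an independent instance, which is useful when one configuration must be attached to many layers that each need their own mask state.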
- schedule: SparsityScheduleBase | None = None
- property sparsity: float
Sparsity that the current mask reflects. Use update_sparsity() to change it.