coreai_opt.inspection.ModelInspector
- class coreai_opt.inspection.ModelInspector(model, example_inputs, execution_mode, compressor=None, dynamic_shapes=None, export_with_no_grad=True)
Bases: object
Inspect operations in a PyTorch model for compression configuration.
Accepts an nn.Module with example inputs, auto-exports the model (for graph mode), and provides query methods for discovering operation names, types, and module hierarchy.
- summary
The underlying operation summary.
- Type:
ModelSummary
- Parameters:
model (torch.fx.GraphModule | torch.nn.Module) – The model to inspect.
example_inputs (tuple[Any, ...] | None) – Example inputs for tracing.
execution_mode (ExecutionMode) – Execution mode to use for model inspection.
compressor (type[_BaseModelCompressor] | None) – A compressor class (e.g., Quantizer) to filter ops to only those supported by that compression algorithm. When None, all ops in the model are included.
dynamic_shapes (dict[str, Any] | tuple[Any] | list[Any] | None) – Only relevant for graph execution mode. Optional dynamic shapes specification for torch.export.
export_with_no_grad (bool) – Only relevant for “graph” execution mode. Whether to call torch.export.export within a torch.no_grad() context. Defaults to True.
- Raises:
TypeError – If model is not an nn.Module.
NotImplementedError – If execution_mode is "eager".
RuntimeError – If model export fails (graph mode).
ValueError – If example_inputs is None without the right model/execution_mode combination, or if execution_mode is neither "eager" nor "graph".
Example
>>> import torch
>>> import torch.nn as nn
>>> from coreai_opt.inspection import ModelInspector
>>> from coreai_opt.quantization import Quantizer
>>> model = nn.Sequential(nn.Linear(10, 5))
Inspect all compressible ops for the Quantizer compressor:
>>> inspector = ModelInspector(model, (torch.randn(1, 10),),
...                            execution_mode="graph", compressor=Quantizer)
Query all ops in the model:
>>> ops = inspector.summary.model.all_ops()
Pretty-print a color-coded summary of the model inspection:
>>> print(inspector.format_summary())
Navigate the module hierarchy:
>>> root = inspector.summary.model
>>> for name, child in root.named_children():
...     print(f"{name}: {child.module_type}, {len(child.ops)} ops")
Look up a specific submodule by fully-qualified name:
>>> linear_mod = root.get_submodule("0")
>>> print(linear_mod.module_type)  # torch.nn.modules.linear.Linear
>>> print(linear_mod.ops)  # ops directly owned by this module
Get all ops under a subtree (the module and all its descendants):
>>> subtree_ops = linear_mod.all_ops()
Filter ops by type, name pattern, or module, using the same filtering logic that Quantizer uses:
>>> inspector.get_matched_ops_for_op_type("linear")
>>> inspector.get_matched_ops_for_op_name(".*linear.*")
>>> inspector.get_matched_ops_for_module_type(nn.Linear)
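The difference between ops (operations owned directly by a module) and all_ops() (the module plus all its descendants) can be illustrated with a small stand-in tree. The Node class below is hypothetical and is not the library's summary type; only the attribute and method names mirror those used in the example above.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    """Hypothetical stand-in for a module-hierarchy node (not the coreai_opt type)."""

    module_type: str
    ops: tuple = ()
    children: dict = field(default_factory=dict)

    def named_children(self):
        # Yield (name, child) pairs for direct children only.
        return iter(self.children.items())

    def get_submodule(self, fqn):
        # Walk a dotted fully-qualified name one component at a time.
        node = self
        for part in fqn.split("."):
            node = node.children[part]
        return node

    def all_ops(self):
        # Ops owned by this node, followed by the ops of every descendant.
        collected = list(self.ops)
        for child in self.children.values():
            collected.extend(child.all_ops())
        return tuple(collected)


# Toy hierarchy shaped like nn.Sequential(nn.Linear(10, 5)); op names are made up.
root = Node(
    module_type="torch.nn.modules.container.Sequential",
    children={"0": Node("torch.nn.modules.linear.Linear", ops=("linear",))},
)

print(root.ops)                      # ops owned directly by the root
print(root.all_ops())                # ops in the whole subtree
print(root.get_submodule("0").ops)   # ops owned by the child named "0"
```

In this toy tree the root owns no ops itself, so root.ops is empty while root.all_ops() still reports the child's op.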
- __init__(model, example_inputs, execution_mode, compressor=None, dynamic_shapes=None, export_with_no_grad=True)
- Parameters:
model (GraphModule | Module)
example_inputs (tuple[Any, ...] | None)
execution_mode (ExecutionMode)
compressor (type[_BaseModelCompressor] | None)
dynamic_shapes (dict[str, Any] | tuple[Any] | list[Any] | None)
export_with_no_grad (bool)
- Return type:
None
Methods
format_summary([colorize]) – Format discovered operations as a module-hierarchy tree string.
get_matched_ops_for_module_name(module_name) – Return operations whose module stack contains the given module name.
get_matched_ops_for_module_type(module_type) – Return operations whose module stack contains the given type.
get_matched_ops_for_op_name(pattern) – Return operations whose name matches the given regex pattern.
get_matched_ops_for_op_type(op_type) – Return operations matching the given op type.
- format_summary(colorize=None)
Format discovered operations as a module-hierarchy tree string.
- Parameters:
colorize (bool | None) – Whether to include ANSI color codes in the output. None (default) auto-detects based on terminal capabilities. Pass False when writing to files or logs.
- Returns:
The formatted tree.
- Return type:
str
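One common way to implement this kind of auto-detection is to check whether the output stream is a terminal. The sketch below is an assumption about typical behavior, not the library's actual logic: resolve_colorize is a hypothetical helper, and honoring the NO_COLOR environment variable is a widespread convention rather than something this page documents.

```python
import io
import os
import sys


def resolve_colorize(colorize=None, stream=None):
    """Hypothetical helper: decide whether to emit ANSI color codes."""
    if colorize is not None:
        # An explicit True/False always wins over auto-detection.
        return colorize
    if os.environ.get("NO_COLOR"):
        # Widely used opt-out convention (assumed, not documented here).
        return False
    stream = sys.stdout if stream is None else stream
    isatty = getattr(stream, "isatty", None)
    return bool(isatty and isatty())


# A StringIO buffer (like a file or log sink) is not a terminal.
buffer = io.StringIO()
print(resolve_colorize(stream=buffer))         # False
print(resolve_colorize(True, stream=buffer))   # True
```

Passing colorize=False explicitly, as the parameter description recommends for files and logs, avoids relying on any detection at all.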
- get_matched_ops_for_module_name(module_name)
Return operations whose module stack contains the given module name.
Uses re.fullmatch against each module FQN in the op’s module stack, consistent with how module_name_configs patterns are matched in graph mode.
- Parameters:
module_name (str) – A regex pattern to match against module FQNs (e.g., "encoder.layer1", "encoder\..*").
- Returns:
Matching operations.
- Return type:
tuple[OpInfo, …]
- Raises:
ValueError – If module_name is not a valid regex.
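Because matching uses re.fullmatch, a pattern must cover an entire module FQN, and an unescaped "." matches any character. The sketch below uses only the standard re module on made-up FQNs to show the consequences; it does not call the library.

```python
import re

# Made-up module FQNs for illustration.
fqns = ["encoder", "encoder.layer1", "encoder.layer1.attn", "encoderXlayer1"]


def matching(pattern):
    # Keep the FQNs that the pattern matches in full.
    return [name for name in fqns if re.fullmatch(pattern, name)]


print(matching("encoder"))           # ['encoder']
print(matching(r"encoder\..*"))      # ['encoder.layer1', 'encoder.layer1.attn']
print(matching("encoder.layer1"))    # ['encoder.layer1', 'encoderXlayer1']
print(matching(r"encoder\.layer1"))  # ['encoder.layer1']
```

Escaping the dot (and using a raw string) is the safer choice when a literal FQN is intended.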
- get_matched_ops_for_module_type(module_type)
Return operations whose module stack contains the given type.
Matches using exact string equality against the fully-qualified type name, consistent with how module_type_configs keys are resolved in the quantizer. Accepts either a class (converted via fqn()) or a fully-qualified type string (e.g., "torch.nn.modules.conv.Conv2d").
- Parameters:
module_type (type | str) – Module type to filter by.
- Returns:
Matching operations.
- Return type:
tuple[OpInfo, …]
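Exact string equality on the fully-qualified type name implies that a subclass does not match its parent type. The helper below is a hypothetical approximation of how a class might be converted to such a string (the library's fqn() may differ); it uses a standard-library class so that it runs without PyTorch.

```python
from collections import OrderedDict


def type_fqn(module_type):
    """Hypothetical approximation: accept a class or a fully-qualified type string."""
    if isinstance(module_type, str):
        return module_type
    return f"{module_type.__module__}.{module_type.__qualname__}"


class MyOrderedDict(OrderedDict):
    """Subclass used to show that the comparison is exact, not isinstance-based."""


print(type_fqn(OrderedDict))                              # collections.OrderedDict
print(type_fqn("collections.OrderedDict"))                # collections.OrderedDict
print(type_fqn(MyOrderedDict) == type_fqn(OrderedDict))   # False
```

Under this scheme, a custom subclass of a module type would need to be passed explicitly to be matched.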
- get_matched_ops_for_op_name(pattern)
Return operations whose name matches the given regex pattern.
Uses re.fullmatch, consistent with how op_name_config patterns are matched in graph mode.
- Parameters:
pattern (str) – A regex pattern to match against op names.
- Returns:
Matching operations.
- Return type:
tuple[OpInfo, …]
- Raises:
ValueError – If pattern is not a valid regex.
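Python's re.compile raises re.error for a malformed pattern, and re.error is not a ValueError subclass. One way to surface a ValueError like the one documented here is to wrap compilation, as in this sketch; compile_pattern is a hypothetical helper and the op names are made up, so the library's actual validation may differ.

```python
import re


def compile_pattern(pattern):
    """Hypothetical helper: compile a regex, reporting bad patterns as ValueError."""
    try:
        return re.compile(pattern)
    except re.error as exc:
        raise ValueError(f"Invalid regex pattern {pattern!r}: {exc}") from exc


op_names = ["linear", "linear_1", "relu"]  # made-up op names

regex = compile_pattern("linear.*")
matched = [name for name in op_names if regex.fullmatch(name)]
print(matched)  # ['linear', 'linear_1']

try:
    compile_pattern("linear(")  # unbalanced parenthesis
except ValueError as exc:
    print(f"rejected: {exc}")
```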
- get_matched_ops_for_op_type(op_type)
Return operations matching the given op type.
- Parameters:
op_type (str) – The operation type to filter by (e.g., "conv2d", "linear").
- Returns:
Matching operations.
- Return type:
tuple[OpInfo, …]
- property summary: ModelSummary
The underlying operation summary.