TorchConverter API reference¶
A converter that produces a Core AI AIProgram from a PyTorch exported program.
Overview¶
from coreai_torch import TorchConverter, get_decomp_table
TorchConverter traverses the exported FX graph of a PyTorch model, maps ATen operators to Core AI operations, and emits an AIProgram ready for on-device execution.
Load a model using add_exported_program() or add_pytorch_module(), then call to_coreai() to produce the artifact. Register custom lowerings with register_torch_lowering() for ops that have no built-in mapping, and supply Metal kernel source for compute-intensive operations with register_custom_kernels().
Constructor¶
TorchConverter()
Creates a TorchConverter instance with no loaded programs and no custom lowerings.
Example:
import torch
from coreai_torch import TorchConverter
converter = TorchConverter()
Methods¶
add_exported_program¶
def add_exported_program(
self,
exported_program: ExportedProgram,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
state_names: Sequence[str] | None = None,
entrypoint_name: str = "main",
) -> TorchConverter
Loads an exported PyTorch program into the converter. Returns self for method chaining.
Warning
The caller must call run_decompositions() on the program before passing it here — use get_decomp_table() to preserve known composite ops in the lowered IR.
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
— |
The exported PyTorch program to convert. |
|
|
|
Names for non-stateful user inputs only. Mutated inputs (buffers and user-input mutations) are renamed via |
|
|
|
Names for non-state outputs (return values) only. Mutation outputs are renamed via |
|
|
|
Names for stateful IO. Each name applies to both the state input and its corresponding mutation output — matching the Core AI runtime’s single-name-per-state model. See IO Naming below. |
|
|
|
Name of the entry-point graph in the generated Core AI module. Must be unique across all staged programs. |
Returns: self — the same TorchConverter instance, for chaining.
Example:
import torch
from coreai_torch import TorchConverter, get_decomp_table
model = MyModel()
exported = torch.export.export(model, example_inputs)
exported = exported.run_decompositions(get_decomp_table())
converter = TorchConverter().add_exported_program(
exported,
input_names=["image"],
output_names=["logits"],
)
coreai_program = converter.to_coreai()
coreai_program.optimize()
add_pytorch_module¶
def add_pytorch_module(
self,
module: nn.Module,
export_fn: Callable[[nn.Module], ExportedProgram],
externalize_modules: list | None = None,
input_names: Sequence[str] | None = None,
output_names: Sequence[str] | None = None,
state_names: Sequence[str] | None = None,
entrypoint_name: str = "main",
) -> TorchConverter
Stages an nn.Module for conversion. export_fn must return a decomposed ExportedProgram (i.e., with run_decompositions() already called). Returns self for method chaining.
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
— |
The PyTorch module to convert. |
|
|
— |
A callable that exports and decomposes the module. Must call |
|
|
|
Optional list of |
|
|
|
Names for non-stateful user inputs only. Mutated inputs (buffers and user-input mutations) are renamed via |
|
|
|
Names for non-state outputs (return values) only. Mutation outputs are renamed via |
|
|
|
Names for stateful IO. See IO Naming below. |
|
|
|
Name of the entry-point graph in the generated Core AI module. Must be unique across all staged programs. |
Returns: self — the same TorchConverter instance, for chaining.
Example:
model = MyModel().eval()
sample = (torch.randn(1, 3, 224, 224),)
converter = TorchConverter().add_pytorch_module(
model,
export_fn=lambda m: torch.export.export(m, args=sample).run_decompositions(
coreai_torch.get_decomp_table()
),
)
coreai_program = converter.to_coreai()
coreai_program.optimize()
to_coreai¶
def to_coreai(
self, *, entrypoints: Sequence[str] | None = None
) -> AIProgram
Converts the loaded program into a Core AI AIProgram. This is the main entry point for conversion. The model must have been loaded via add_exported_program() or add_pytorch_module() before calling this method.
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
|
If provided, only convert programs with these entrypoint names. If |
Returns: AIProgram — a compiled Core AI model ready for deployment and execution.
Example:
converter = TorchConverter().add_exported_program(
exported,
input_names=["image"],
output_names=["logits"],
)
coreai_program = converter.to_coreai()
coreai_program.optimize()
clear¶
def clear(self, *, entrypoints: Sequence[str] | None = None) -> None
Remove staged programs. Custom lowerings registered via register_torch_lowering() are always preserved.
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
|
If provided, remove only programs with these entrypoint names. If |
Example:
converter = TorchConverter()
converter.add_exported_program(ep_a, entrypoint_name="encoder")
converter.add_exported_program(ep_b, entrypoint_name="decoder")
# Remove just the encoder
converter.clear(entrypoints=["encoder"])
# Remove everything
converter.clear()
register_torch_lowering¶
def register_torch_lowering(
self,
qualified_name: str,
allow_override: bool = False,
) -> Callable
Decorator that registers a custom lowering function for an FX node. Use this to provide a Core AI implementation for a custom or unsupported torch operator.
Warning
Lowering functions are written against authoring APIs from coreai-core (such as coreai._compiler.dialects). The leading underscore on _compiler marks this as private upstream API — it may move or change without notice across coreai-core releases.
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
— |
Op name in |
|
|
|
If |
Returns: A decorator function. The decorated function must have the signature:
def lowering_func(
values_map: dict[str, Value],
node: torch.fx.Node,
loc: Location,
) -> Value | list[Value]
Callback Parameter |
Type |
Description |
|---|---|---|
|
|
Map of FX node names to their corresponding Core AI |
|
|
The FX node being lowered. Access |
|
|
The Core AI Location for the current operation. Pass this to Core AI op constructors. |
Raises:
ValueErrorifqualified_nameis not in"namespace::op_name"form.ValueErrorif the namespace is reserved (aten,higher_order,coreai,coreaix).ValueErrorif a lowering already exists for the op andallow_overrideisFalse.
Example — register a lowering for a custom torch op:
import torch
from coreai._compiler.dialects import coreai
from coreai_torch._utils import get_operands
@torch.library.custom_op("my_lib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
return x + scale * y
@scaled_add.register_fake
def _(x, y, scale):
return torch.empty_like(x)
converter = TorchConverter()
@converter.register_torch_lowering("my_lib::scaled_add.default")
def lower_scaled_add(values_map, node, loc):
x, y = get_operands(values_map, node, [0, 1], loc)
scale = node.args[2] # plain Python float
scale_val = coreai.constant(scale, dtype=x.type.element_type)
scaled_y = coreai.broadcasting_mul(y, scale_val, loc=loc)
return coreai.broadcasting_add(x, scaled_y, loc=loc)
coreai_program = converter.add_exported_program(exported).to_coreai()
coreai_program.optimize()
Example — override a built-in ATen op lowering:
import numpy as np
from coreai._compiler.dialects import coreai
from coreai_torch._utils import get_operand
converter = TorchConverter()
@converter.register_torch_lowering("aten::_adaptive_avg_pool2d.default", allow_override=True)
def lower_adaptive_avg_pool2d_static(values_map, node, loc):
x = get_operand(values_map, node, 0, loc)
output_h, output_w = node.args[1]
input_h, input_w = x.type.shape[2], x.type.shape[3]
stride_h, stride_w = input_h // output_h, input_w // output_w
kernel_h = input_h - (output_h - 1) * stride_h
kernel_w = input_w - (output_w - 1) * stride_w
return coreai.broadcasting_divide(
coreai.sumpool2d(
x,
kernel_size=np.array([kernel_h, kernel_w], dtype=np.uint32),
strides=np.array([stride_h, stride_w], dtype=np.uint32),
dilation=coreai.constant([1, 1], dtype=np.uint32),
),
coreai.cast(float(kernel_h * kernel_w), x.type.element_type),
)
coreai_program = converter.add_exported_program(exported).to_coreai()
coreai_program.optimize()
register_custom_kernels¶
def register_custom_kernels(
self,
kernels: Sequence[TorchMetalKernel],
) -> TorchConverter
Registers one or more TorchMetalKernel objects so the converter can convert them into Core AI operations during conversion. Call this before add_exported_program().
Parameter |
Type |
Default |
Description |
|---|---|---|---|
|
|
— |
The custom Metal kernels used in the model. |
Returns: self — the same TorchConverter instance, for chaining.
Example:
import torch
from coreai_torch import TorchConverter, TorchMetalKernel, MetalParameter, get_decomp_table
def torch_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y
kernel = TorchMetalKernel(
"vector_add",
input_names=["x", "y"],
result_names=["output"],
src="output[id] = x[id] + y[id];",
torch_defn=torch_add,
metal_params=[MetalParameter("id", "uint", "thread_position_in_grid")],
)
converter = TorchConverter()
converter.register_custom_kernels([kernel])
exported = torch.export.export(model, args=example_inputs)
exported = exported.run_decompositions(get_decomp_table())
coreai_program = converter.add_exported_program(exported).to_coreai()
coreai_program.optimize()
See the Custom Metal Kernels guide for a full walkthrough.
Standalone functions¶
get_decomp_table¶
from coreai_torch import get_decomp_table
def get_decomp_table() -> dict
Returns the default PyTorch ATen decomposition table minus the operations that TorchConverter lowers as composite ops, so those operations are preserved in the exported graph rather than being decomposed into lower-level primitives.
Each call returns a fresh copy of the table — mutating it does not affect other callers.
Returns: dict — a decomposition table suitable for ExportedProgram.run_decompositions().
With add_exported_program (caller runs decompositions manually):
import torch
from coreai_torch import TorchConverter, get_decomp_table
ep = torch.export.export(model, args=example_inputs)
ep = ep.run_decompositions(get_decomp_table())
coreai_program = TorchConverter().add_exported_program(ep).to_coreai()
coreai_program.optimize()
With add_pytorch_module:
from coreai_torch import TorchConverter
coreai_program = (
TorchConverter()
.add_pytorch_module(
model,
export_fn=lambda m: torch.export.export(m, args=example_inputs).run_decompositions(
coreai_torch.get_decomp_table()
),
)
.to_coreai()
)
coreai_program.optimize()
IO naming¶
The input_names, output_names, and state_names parameters control the names of the graph’s inputs and outputs. These names propagate to the Core AI runtime descriptor (desc.input_names, desc.output_names, desc.state_names).
Parameters¶
Parameter |
Controls |
Scope |
|---|---|---|
|
Non-stateful |
Inputs that are read but not mutated |
|
Return values from |
Outputs that are not mutation side-effects |
|
Mutable state (buffers + mutated user inputs) |
A single name per state, applied to both the input and its corresponding mutation output |
Warning
What counts as state (no opt-out). The converter treats two things as state:
Mutable buffers registered via
self.register_buffer(...)and mutated in-place insideforward()(e.g.,self.buf.add_(x)).User inputs mutated in-place inside
forward()(e.g.,x.mul_(2)on aforward()arg).
Both are detected from the exported program’s graph signature. There is no flag to opt a mutated user input out of state — it follows directly from how the model mutates its inputs and buffers.
If you don’t want a forward() argument treated as state, eliminate the in-place mutation from your model — clone first (x_local = x.clone(); x_local.mul_(2)) or use the out-of-place form (x_scaled = x * 2). The mutation then disappears from the exported graph and the argument becomes a regular non-state input.
Default names¶
When a naming parameter is not provided, the converter uses names derived from the FX graph:
Category |
FX graph source |
Relates to |
Example |
|---|---|---|---|
Input |
Placeholder |
|
|
Output |
Output node’s input |
Internal op name |
|
State (buffer) |
Placeholder |
|
|
State (mutated user input) |
Placeholder |
|
|
Warning
These naming conventions are observed behavior from the FX graph, not a stable contract from PyTorch. They may change across PyTorch versions. Always provide explicit names for production use.
Ordering¶
Parameter |
Order follows |
|---|---|
|
Non-mutated |
|
|
|
Mutable buffers ( |
Warning
The ordering of state_names (buffers first, then mutated user inputs) is based on observed FX graph behavior, not a stable PyTorch contract. The converter asserts that the number of state inputs matches state outputs, but cannot detect silent reordering. Always verify state ordering when upgrading PyTorch versions.
Examples¶
Stateless model — only input_names and output_names apply:
class Linear(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(8, 4)
def forward(self, x):
return self.fc(x)
ep = torch.export.export(Linear().eval(), args=(torch.randn(1, 8),))
ep = ep.run_decompositions(get_decomp_table())
TorchConverter().add_exported_program(
ep,
input_names=["features"],
output_names=["logits"],
).to_coreai().optimize()
Stateful model — all three parameters:
class KVCache(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("kv_cache", torch.zeros(1, 4)) # state[0]
self.register_buffer("pos_idx", torch.zeros(1)) # state[1]
def forward(self, x, y, z):
self.kv_cache.add_(x) # buffer mutation
self.pos_idx.add_(1) # buffer mutation
y.mul_(2) # state[2]: mutated user input
# non-mutated: x -> input[0], z -> input[1]
return self.kv_cache + y, z * 3
ep = torch.export.export(
KVCache().eval(),
args=(torch.randn(1, 4), torch.randn(1, 4), torch.randn(1, 4)),
)
ep = ep.run_decompositions(get_decomp_table())
TorchConverter().add_exported_program(
ep,
state_names=["kv_cache", "pos_idx", "y_state"],
input_names=["query", "context"],
output_names=["attn_out", "scaled"],
).to_coreai().optimize()
The converter applies these names to the graph’s inputs and outputs: mutated buffers and mutated user inputs become model state (in state_names order), followed by the non-mutated input_names and the returned output_names.
Notices¶
PyTorch is a trademark of Meta Platforms, Inc.