Model Exporting#
The recommended way to generate an ExportedProgram for your model is to use PyTorch's torch.export.export. Exporting runs an example input tensor through your model, and captures the operations that are invoked as that input makes its way through the model's layers.
Exporting Limitations
Conversion from the torch.export graph is new in Core ML Tools 8.0. It is currently in a beta state, in line with the status of the export API in PyTorch.

As of Core ML Tools 8.0, representative models such as MobileBert, ResNet, ViT, MobileNet, DeepLab, and OpenELM can be converted, and the total PyTorch op translation test coverage is roughly 70%. If your models already work with torch.jit.trace, you can start trying the torch.export path now, and gradually move to it as PyTorch shifts its support and development there over time. If you hit issues (for example, models converted via the export path running slower than those converted via the jit.trace path), please report them on GitHub.

Also note that torch.export itself has limitations; see here.
Requirements#
This example requires PyTorch and Core ML Tools 8.0 or newer. Install them with the following commands:
```shell
pip install torch
pip install coremltools
```
At the time this example was created, the environment used was torch 2.4.1 and coremltools 8.0.
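To confirm the versions installed in your own environment, a quick check like the following can help (assuming both packages import cleanly):

```python
import torch
import coremltools

# Print the installed versions to verify they meet the requirements
print("torch:", torch.__version__)
print("coremltools:", coremltools.__version__)
```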
Export and Convert your Model#
The following example builds a simple model from scratch and exports it to generate the ExportedProgram object needed by the converter. Follow these steps:
Define a simple layer module to reuse:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define a simple layer module we will reuse in our network
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.wq = nn.Linear(32, 64)
        self.wk = nn.Linear(32, 64)
        self.wv = nn.Linear(32, 64)
        self.wo = nn.Linear(64, 32)

    def forward(self, embedding):
        q = self.wq(embedding)
        k = self.wk(embedding)
        v = self.wv(embedding)
        attention = F.scaled_dot_product_attention(q, k, v)
        o = self.wo(attention)
        return o
```
Define a simple network consisting of several base layers:
```python
# A simple network consisting of several base layers
class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.attention = Attention()
        self.w1 = nn.Linear(32, 16)
        self.w2 = nn.Linear(16, 32)
        self.w3 = nn.Linear(32, 16)

    def forward(self, embedding):
        attention = self.attention(embedding)
        x = embedding + attention
        y = self.w2(F.silu(self.w1(x)) * self.w3(x))
        z = x + y
        return z
```
Instantiate the network:
```python
# Instantiate the network
model = Transformer()
model.eval()
```
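Optionally, confirm the eager model runs on a sample input before exporting (a quick sanity check; the (2, 32) batch matches the example input used below):

```python
# Sanity check: run the eager model on a random batch
with torch.no_grad():
    output = model(torch.rand(2, 32))
print(output.shape)  # torch.Size([2, 32])
```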
Define the example input, which is needed by the exporter, and export the model: run torch.export.export on your model with the example input, and save the resulting exported object. For the example input, you can use one sample of training or validation data, or even randomly generated data, as in the following code snippet:

```python
# Example input, needed by the exporter
example_input = (torch.rand(2, 32),)

# Define the dynamic shapes to be considered by the exporter, if any
batch_dim = torch.export.Dim(name="batch_dim", min=1, max=128)
dynamic_shapes = {"embedding": {0: batch_dim}}

# Generate the ExportedProgram by exporting
exported_model = torch.export.export(model, example_input, dynamic_shapes=dynamic_shapes)
```
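If you want to see what was captured, printing the ExportedProgram shows the graph along with its input specs and the symbolic batch dimension (optional, purely a sanity check):

```python
# Optional: inspect the captured graph and its dynamic-shape constraints
print(exported_model)
```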
Convert the exported model to Core ML:
```python
import coremltools as ct

mlmodel = ct.convert(exported_model)
```
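Once converted, you can run the model from Python and save it. A minimal sketch: the input name "embedding" is assumed here to be the one the converter derives from the forward argument, so verify it with print(mlmodel.input_description) if unsure; prediction requires macOS.

```python
import numpy as np

# Run a prediction; the input name is derived from the ExportedProgram
output = mlmodel.predict({"embedding": np.random.rand(2, 32).astype(np.float32)})

# Save the converted model as an mlpackage
mlmodel.save("transformer.mlpackage")
```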
Difference from Tracing#
For tracing, ct.convert requires the inputs arg from the user. This is no longer required for exporting, since the ExportedProgram object carries all the name, shape, and dtype information: TensorType, RangeDim, and StateType are automatically created from the ExportedProgram if inputs is absent. There are 3 cases where inputs is still necessary:

1. Customize name / dtype
2. ImageType, since torch does not have an image type
3. EnumeratedShapes, since torch.export expresses dynamism by range
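For the first case, a sketch of overriding the input name and dtype is shown below. The name "input_embedding" is a hypothetical choice, and float16 I/O additionally requires a deployment target of iOS16 or newer:

```python
import numpy as np
import coremltools as ct

# Override the input name and dtype via the inputs arg
# ("input_embedding" is an arbitrary name chosen for this sketch)
mlmodel = ct.convert(
    exported_model,
    inputs=[ct.TensorType(name="input_embedding", dtype=np.float16)],
    minimum_deployment_target=ct.target.iOS16,
)
```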
Another difference between tracing and exporting is how dynamic shapes are created. torch.jit.trace simply traces the executed torch ops and has no concept of dynamism, so dynamic shapes are specified and propagated in ct.convert. torch.export, by contrast, rigorously expresses dynamism, so dynamic shapes are first specified and propagated in torch.export; then, when calling ct.convert:

- If RangeDim is desired, nothing more is needed, since it is automatically converted from torch.export.Dim.
- If EnumeratedShapes are desired, the user needs to specify the shape enumeration in the inputs arg, and only the torch.export dynamic dimensions are allowed to have more than one possible size.
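For instance, building on the exported_model from above, a sketch of the EnumeratedShapes case might look like the following. The particular batch sizes are arbitrary choices within the exported batch_dim range, and only dim 0 (the torch.export dynamic dimension) varies:

```python
import coremltools as ct

# Enumerate a few batch sizes; each must lie within the exported
# batch_dim range (1 to 128), and the static dim stays fixed at 32
enumerated = ct.EnumeratedShapes(shapes=[(1, 32), (2, 32), (8, 32)])
mlmodel = ct.convert(
    exported_model,
    inputs=[ct.TensorType(name="embedding", shape=enumerated)],
)
```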