Unified (TensorFlow and PyTorch)
- coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to=None, compute_precision=None, skip_model_load=False, compute_units=ComputeUnit.ALL, package_dir=None, debug=False, pass_pipeline: Optional[PassPipeline] = None)
Convert a TensorFlow or PyTorch model to the Core ML model format as either a neural network or an ML program. Some parameters and requirements differ for TensorFlow and PyTorch conversions.
- Parameters:
- model
TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:
TensorFlow versions 1.x
    Frozen tf.Graph
    Frozen graph (.pb) file path
    HDF5 file path (.h5)
    SavedModel directory path
TensorFlow versions 2.x
    HDF5 file path (.h5)
    SavedModel directory path
    A GraphDef
PyTorch
    A TorchScript object
    Path to a .pt file
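The formats above can partly be told apart by their file-system shape. As a rough illustration only, the hypothetical helper `guess_model_format` below classifies a path string by its extension. It is an assumption-laden sketch, not the detection logic coremltools uses for `source="auto"`, and it cannot recognize in-memory objects such as a `tf.Graph` or a TorchScript object.

```python
from pathlib import Path
from typing import Tuple


def guess_model_format(path: str) -> Tuple[str, str]:
    """Hypothetical helper: map a model path onto the formats listed above.

    Looks only at the path string (no file I/O). This is an illustration,
    not the detection logic behind coremltools' source="auto".
    """
    suffix = Path(path).suffix.lower()
    if suffix == ".pb":
        return ("tensorflow", "frozen_graph")
    if suffix == ".h5":
        return ("tensorflow", "hdf5")
    if suffix == ".pt":
        return ("pytorch", "torchscript_file")
    if suffix == "":
        # Heuristic: a SavedModel is a directory, which usually has no extension.
        return ("tensorflow", "saved_model_dir")
    # Loosely mirrors the documented behavior of raising ValueError when the
    # source framework cannot be determined.
    raise ValueError(f"cannot determine the source framework for {path!r}")
```

A directory without an extension is treated as a SavedModel purely by heuristic; a real check would inspect the directory contents.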
- source : str (optional)
One of [auto, tensorflow, pytorch, milinternal]. auto determines the framework automatically for most cases. Raises ValueError if it fails to determine the source framework.
- inputs : list of TensorType or ImageType
If you specify dtype with TensorType or ImageType, it will be applied to the input of the converted model. For example, the following code snippet will produce a Core ML model with float 16 typed inputs:
    import coremltools as ct
    import numpy as np

    mlmodel = ct.convert(
        keras_model,
        inputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.macOS13,
    )
The following code snippet will produce a Core ML model with the GRAYSCALE_FLOAT16 input image type:
    import coremltools as ct

    # H: image height, W: image width
    mlmodel = ct.convert(
        torch_model,
        inputs=[
            ct.ImageType(shape=(1, 1, H, W), color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)
        ],
        minimum_deployment_target=ct.target.macOS13,
    )
- TensorFlow 1 and 2 (including tf.keras):
The inputs parameter is optional. If not provided, the inputs are placeholder nodes in the model (if the model is a frozen graph) or function inputs (if the model is a tf.function).
If inputs is provided, it must be a flat list.
The inputs must correspond to all or some of the placeholder nodes in the TF model.
If name is specified with TensorType or ImageType, it must correspond to a placeholder op in the TF graph. The input names in the converted Core ML model can later be modified using the ct.utils.rename_feature API.
If dtype is not specified, it defaults to the dtype of the inputs in the TF model.
- PyTorch:
The inputs parameter is required.
The number of elements in inputs must match the number of inputs of the PyTorch model.
inputs may be a nested list or tuple.
TensorType and ImageType must have the shape specified.
If the name argument is specified with TensorType or ImageType, the converted Core ML model will have inputs with the same name.
If dtype is missing, it defaults to float 32.
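The PyTorch rules above amount to a few checks on the inputs list. The sketch below applies them to a hypothetical `InputSpec` dataclass that stands in for `ct.TensorType`/`ct.ImageType`. `InputSpec` and `check_torch_inputs` are illustrations, not coremltools APIs; the real converter performs its own validation.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple


@dataclass
class InputSpec:
    """Hypothetical stand-in for ct.TensorType / ct.ImageType (not the real classes)."""
    name: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None
    dtype: Optional[str] = None


def check_torch_inputs(inputs: Optional[Sequence], num_model_inputs: int) -> List[str]:
    """Apply the PyTorch input rules above and return each leaf spec's effective dtype."""
    if inputs is None:
        raise ValueError("inputs is required for PyTorch conversion")
    if len(inputs) != num_model_inputs:
        raise ValueError(f"expected {num_model_inputs} inputs, got {len(inputs)}")

    dtypes: List[str] = []

    def visit(spec) -> None:
        if isinstance(spec, (list, tuple)):
            for child in spec:  # nested lists/tuples are allowed
                visit(child)
            return
        if spec.shape is None:
            raise ValueError(f"input {spec.name!r} must have a shape")
        dtypes.append(spec.dtype or "float32")  # missing dtype defaults to float 32

    for spec in inputs:
        visit(spec)
    return dtypes
```

The top-level length is compared with the model's input count, while nested tuples are walked only to check shapes and resolve dtypes.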
- outputs : list of TensorType or ImageType (optional)
If you specify dtype with TensorType or ImageType, it will be applied to the output of the converted model. For example, to produce float 16 typed inputs and outputs:
    import coremltools as ct
    import numpy as np

    mlmodel = ct.convert(
        keras_model,
        inputs=[ct.TensorType(dtype=np.float16)],
        outputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.macOS13,
    )
To produce image inputs and outputs:
    import coremltools as ct

    # H: image height, W: image width
    mlmodel = ct.convert(
        torch_model,
        inputs=[ct.ImageType(shape=(1, 3, H, W), color_layout=ct.colorlayout.RGB)],
        outputs=[ct.ImageType(color_layout=ct.colorlayout.RGB)],
        minimum_deployment_target=ct.target.macOS13,
    )
- TensorFlow 1 and 2 (including tf.keras):
If outputs is not specified, the converter infers outputs from the sink nodes in the graph.
If specified, the name with TensorType or ImageType must correspond to a node in the TF graph. In this case, the model will be converted up to that node.
- PyTorch:
If specified, the length of the list must match the number of outputs returned by the PyTorch model.
If name is specified, it is applied to the output names of the converted Core ML model.
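For TensorFlow, sink nodes are nodes whose results no other node consumes, and naming an output node limits conversion to that node and everything it depends on. The sketch below illustrates both ideas on a toy graph, assumed here to be a dict that maps each node name to the names of its inputs. `find_sink_nodes` and `nodes_up_to` are hypothetical helpers, not converter internals.

```python
def find_sink_nodes(graph: dict) -> list:
    """Return the nodes that no other node consumes, in insertion order.

    `graph` maps each node name to the list of node names it reads from.
    """
    consumed = {src for inputs in graph.values() for src in inputs}
    return [name for name in graph if name not in consumed]


def nodes_up_to(graph: dict, output: str) -> set:
    """Return `output` plus every node it transitively depends on."""
    needed, stack = set(), [output]
    while stack:
        node = stack.pop()
        if node not in needed:
            needed.add(node)
            stack.extend(graph[node])
    return needed
```

For the frozen-graph example later on this page, a graph such as `{"input": [], "output": ["input"]}` has the single sink `"output"`, which matches the output the converter reports there.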
- classifier_config : ClassifierConfig class (optional)
The configuration if the MLModel is intended to be a classifier.
- minimum_deployment_target : coremltools.target enumeration (optional)
A member of the coremltools.target enum. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between neural networks and ML programs, see ML Programs.
The converter produces a neural network (neuralnetwork) if:
    minimum_deployment_target <= coremltools.target.iOS14 / coremltools.target.macOS11 / coremltools.target.watchOS7 / coremltools.target.tvOS14
The converter produces an ML program (mlprogram) if:
    minimum_deployment_target >= coremltools.target.iOS15 / coremltools.target.macOS12 / coremltools.target.watchOS8 / coremltools.target.tvOS15
If neither the minimum_deployment_target nor the convert_to parameter is specified, the converter produces the neural network model type with the lowest possible deployment target.
If this parameter is specified and convert_to is also specified, they must be compatible. The following are examples of invalid values:
    # Invalid:
    convert_to="neuralnetwork", minimum_deployment_target=coremltools.target.iOS15
    # Invalid:
    convert_to="mlprogram", minimum_deployment_target=coremltools.target.iOS14
- convert_to : str (optional)
Must be one of ['neuralnetwork', 'mlprogram', 'milinternal']. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between neural networks and ML programs, see ML Programs.
'neuralnetwork': Returns an MLModel (coremltools.models.MLModel) containing a NeuralNetwork proto, which is the original Core ML format. The model saved from this returned object is executable either on iOS13/macOS10.15/watchOS6/tvOS13 and newer, or on iOS14/macOS11/watchOS7/tvOS14 and newer, depending on the layers used in the model.
'mlprogram': Returns an MLModel (coremltools.models.MLModel) containing a MILSpec.Program proto, which is the Core ML program format. The model saved from this returned object is executable on iOS15, macOS12, watchOS8, and tvOS15 and newer.
'milinternal': Returns an MIL program object (coremltools.converters.mil.Program). An MIL program is primarily used for debugging and inspection. It can be converted to an MLModel for execution by using one of the following:
    ct.convert(mil_program, convert_to="neuralnetwork")
    ct.convert(mil_program, convert_to="mlprogram")
If neither the minimum_deployment_target nor the convert_to parameter is specified, the converter produces the neural network model type with the lowest possible deployment target.
- compute_precision : coremltools.precision enumeration or ct.transform.FP16ComputePrecision() (optional)
Use this argument to control the storage precision of the tensors in the ML program. Must be one of the following.
coremltools.precision.FLOAT16 enum: The following transform is applied to produce a float 16 program; that is, a program in which all the intermediate float tensors are of type float 16 (for ops that support that type).
    coremltools.transform.FP16ComputePrecision(op_selector=lambda op: True)
The above transform iterates through all the ops, looking at each op’s inputs and outputs. If they are of type float 32, cast ops are injected to convert those tensors (also known as vars) to type float 16.
coremltools.precision.FLOAT32 enum: No transform is applied. The original float32 tensor dtype in the source model is preserved. Opt into this option if the default converted model is displaying numerical precision issues.
coremltools.transform.FP16ComputePrecision(op_selector=...)
Use this option to control which tensors are cast to float 16. Before casting the inputs/outputs of any op from float32 to float 16, the op_selector function is invoked on the op object. This function must return a boolean value. By default it returns True for every op, but you can customize this. For example:
    coremltools.transform.FP16ComputePrecision(op_selector=lambda op: op.op_type != "linear")
The above casts all the float32 tensors to float 16, except the input/output tensors of any linear op. See more examples below.
None: The default.
When convert_to="mlprogram", the compute_precision parameter defaults to coremltools.precision.FLOAT16.
When convert_to="neuralnetwork", the compute_precision parameter needs to be None and has no meaning.
For example, you can customize the float 16 precision transform to prevent casting all the real_div ops in the program to float 16 precision:
    import coremltools as ct

    def skip_real_div_ops(op):
        if op.op_type == "real_div":
            return False
        return True

    model = ct.convert(
        source_model,
        compute_precision=ct.transform.FP16ComputePrecision(op_selector=skip_real_div_ops),
        minimum_deployment_target=ct.target.iOS15,
    )
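To make the `op_selector` contract concrete, the sketch below counts how many float32 tensors would be cast to float 16 for a given selector. `ToyOp` is a hypothetical stand-in for a MIL op that carries only an `op_type` and dtype strings. The per-tensor count is a simplification that does not model how the real transform places its `cast` ops.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyOp:
    """Hypothetical stand-in for a MIL op: only the fields this sketch needs."""
    op_type: str
    input_dtypes: List[str] = field(default_factory=list)
    output_dtypes: List[str] = field(default_factory=list)


def count_fp16_casts(ops: List[ToyOp], op_selector: Callable[[ToyOp], bool]) -> int:
    """Count float32 inputs/outputs of selected ops that would be cast to float 16."""
    casts = 0
    for op in ops:
        if not op_selector(op):
            continue  # unselected ops keep their float32 tensors
        tensors = op.input_dtypes + op.output_dtypes
        casts += sum(1 for dtype in tensors if dtype == "fp32")
    return casts
```

With a selector like `skip_real_div_ops` from the example above, the tensors of `real_div` ops drop out of the count while every other op is still counted.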
- skip_model_load : bool
Set to True to prevent coremltools from calling into the Core ML framework to compile and load the model post-conversion. In that case, the returned model object cannot be used to make a prediction, but it can still be saved with model.save(). This flag may be used to convert to a newer model type on an older Mac, which may raise a runtime warning if done without turning this flag on.
Example: Use this flag to suppress a runtime warning when converting to an ML program model on macOS 11, since an ML program can only be compiled and loaded on macOS 12 and later.
Defaults to False.
- compute_units : coremltools.ComputeUnit
An enum with the following possible values.
coremltools.ComputeUnit.ALL: Use all compute units available, including the neural engine.
coremltools.ComputeUnit.CPU_ONLY: Limit the model to only use the CPU.
coremltools.ComputeUnit.CPU_AND_GPU: Use both the CPU and GPU, but not the neural engine.
coremltools.ComputeUnit.CPU_AND_NE: Use both the CPU and neural engine, but not the GPU. Available only for macOS >= 13.0.
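The text above states two host-side rules: ML programs can only be compiled and loaded on macOS 12 or later, and `CPU_AND_NE` is available only on macOS 13.0 or later. The sketch below encodes those checks using version tuples such as `(12, 6)` and a local `ComputeUnitName` enum that only mirrors the documented names. Both helpers and the enum are hypothetical and are not part of coremltools.

```python
from enum import Enum


class ComputeUnitName(Enum):
    """Local mirror of the coremltools.ComputeUnit names listed above (not the real enum)."""
    ALL = "ALL"
    CPU_ONLY = "CPU_ONLY"
    CPU_AND_GPU = "CPU_AND_GPU"
    CPU_AND_NE = "CPU_AND_NE"


def should_skip_model_load(model_type: str, host_macos: tuple) -> bool:
    """True when the host cannot compile/load the model (ML programs need macOS 12+)."""
    return model_type == "mlprogram" and host_macos < (12,)


def available_compute_units(host_macos: tuple) -> list:
    """Compute units usable on the host; CPU_AND_NE requires macOS 13.0 or later."""
    units = [ComputeUnitName.ALL, ComputeUnitName.CPU_ONLY, ComputeUnitName.CPU_AND_GPU]
    if host_macos >= (13,):
        units.append(ComputeUnitName.CPU_AND_NE)
    return units
```

On a Mac, `platform.mac_ver()[0]` returns the release as a string such as `"13.4.1"`, which can be split on `.` to build the tuple. On other systems it returns an empty string.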
- package_dir : str
Post conversion, the model is saved at a temporary location and loaded to form the MLModel object ready for prediction.
If package_dir is provided, the model will be saved at this location rather than in a temporary directory.
If not None, this must be a path to a directory with the extension .mlpackage.
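A caller might validate `package_dir` before converting. The hypothetical `check_package_dir` below encodes only the stated rule: the value is either `None` or a path whose extension is `.mlpackage`. It is a sketch and does not create the directory or check that it is writable.

```python
from pathlib import Path
from typing import Optional


def check_package_dir(package_dir: Optional[str]) -> Optional[Path]:
    """Hypothetical pre-check for package_dir: None, or a path ending in .mlpackage."""
    if package_dir is None:
        return None  # the converter uses a temporary directory instead
    path = Path(package_dir)
    if path.suffix != ".mlpackage":
        raise ValueError(f"package_dir must end in .mlpackage, got {package_dir!r}")
    return path
```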
- debug : bool
This flag should generally be False except for debugging purposes. Setting this flag to True produces the following behavior:
For Torch conversion, it will print the list of supported and unsupported ops found in the model if conversion fails due to an unsupported op.
For TensorFlow conversion, it will display extra logging and visualizations.
- pass_pipeline : PassPipeline
Manage graph passes. You can control which graph passes to run and the order of the graph passes. You can also specify options for each pass. See the details in the docstring of PassPipeline (coremltools/converters/mil/mil/passes/pass_pipeline.py).
To avoid fusing the conv and batchnorm ops, skip the corresponding pass as shown in the following example:
    import coremltools as ct

    pipeline = ct.PassPipeline()
    pipeline.remove_passes({"common::fuse_conv_batchnorm"})
    ct.convert(model, pass_pipeline=pipeline)
To avoid folding too-large const ops that lead to a large model, set the pass option as shown in the following example:
    import coremltools as ct

    pipeline = ct.PassPipeline()
    pipeline.set_options("common::const_elimination", {"skip_const_by_size": "1e6"})
    ct.convert(model, pass_pipeline=pipeline)
- Returns:
- model : coremltools.models.MLModel or coremltools.converters.mil.Program
A Core ML MLModel object or MIL program object (see convert_to).
Examples
TensorFlow 1, 2 (model is a frozen graph):
    >>> with tf.Graph().as_default() as graph:
    ...     x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2, 3), name="input")
    ...     y = tf.nn.relu(x, name="output")
Automatically infer inputs and outputs:
    >>> mlmodel = ct.convert(graph)
    >>> test_input = np.random.rand(1, 2, 3) - 0.5
    >>> results = mlmodel.predict({"input": test_input})
    >>> print(results['output'])
TensorFlow 2 (model is a tf.keras model path):
    >>> x = tf.keras.Input(shape=(32,), name='input')
    >>> y = tf.keras.layers.Dense(16, activation='softmax')(x)
    >>> keras_model = tf.keras.Model(x, y)

    >>> h5_path = "model.h5"  # any writable path ending in .h5
    >>> keras_model.save(h5_path)
    >>> mlmodel = ct.convert(h5_path)

    >>> test_input = np.random.rand(2, 32)
    >>> results = mlmodel.predict({'input': test_input})
    >>> print(results['Identity'])
PyTorch:
    >>> model = torchvision.models.mobilenet_v2()
    >>> model.eval()
    >>> example_input = torch.rand(1, 3, 256, 256)
    >>> traced_model = torch.jit.trace(model, example_input)

    >>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256))
    >>> mlmodel = ct.convert(traced_model, inputs=[input])
    >>> results = mlmodel.predict({"input_name": example_input.numpy()})
    >>> print(results['1651'])  # 1651 is the node name given by PyTorch's JIT
See Conversion Options for more advanced options.