Unified (TensorFlow and PyTorch)

coremltools.converters._converters_entry._determine_source(model, source, outputs)

Resolve source (which can be 'auto') to the precise source framework.
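As a rough illustration of what resolving source='auto' involves, the sketch below guesses the framework from a file extension or from the package that defines the model's class. The function name determine_source and its heuristics are assumptions made for illustration; they are not the actual logic of _determine_source.

```python
import os

# Values accepted by the `source` parameter, as listed in the documentation.
VALID_SOURCES = ("auto", "tensorflow", "pytorch", "mil")


def determine_source(model, source="auto"):
    """Hypothetical sketch: resolve `source` to a concrete framework name."""
    if source not in VALID_SOURCES:
        raise ValueError(f"source must be one of {VALID_SOURCES}, got {source!r}")
    if source != "auto":
        return source  # an explicit choice is passed through unchanged

    if isinstance(model, str):
        # Assumed heuristic: infer the framework from the file extension.
        ext = os.path.splitext(model)[1].lower()
        if ext in (".pt", ".pth"):
            return "pytorch"
        if ext in (".pb", ".h5", ".hdf5") or os.path.isdir(model):
            return "tensorflow"
    else:
        # Assumed heuristic: look at the top-level package of the model's class.
        package = type(model).__module__.split(".")[0]
        if package == "torch":
            return "pytorch"
        if package in ("tensorflow", "keras"):
            return "tensorflow"

    # Matches the documented behavior: fail loudly when auto-detection fails.
    raise ValueError("Unable to determine the source framework; pass source explicitly.")
```

For example, determine_source("model.pt") returns "pytorch", while an object from an unrecognized package raises ValueError.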

coremltools.converters._converters_entry._validate_inputs(model, exact_source, inputs, outputs, classifier_config, **kwargs)

Validate and process model, inputs, outputs, and classifier_config based on exact_source (which cannot be 'auto').
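A minimal sketch of the per-framework input and output rules listed under the inputs and outputs parameters of convert. It assumes a hypothetical InputSpec dataclass standing in for TensorType/ImageType and a hypothetical validate_inputs function; the real _validate_inputs also processes model and classifier_config, which this sketch leaves out.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class InputSpec:
    """Hypothetical stand-in for ct.TensorType / ct.ImageType."""
    name: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None


def _flatten(items):
    """Yield the leaf specs of a (possibly nested) list/tuple structure."""
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from _flatten(item)
        else:
            yield item


def validate_inputs(exact_source, inputs=None, outputs=None):
    """Apply the documented TensorFlow / PyTorch rules; raise ValueError on violation."""
    if exact_source == "auto":
        raise ValueError("exact_source must already be resolved (not 'auto')")

    if exact_source == "tensorflow":
        if inputs is not None:
            # TensorFlow: inputs must be a flat list of named specs.
            if any(isinstance(i, (list, tuple)) for i in inputs):
                raise ValueError("TensorFlow inputs must be a flat list")
            if any(i.name is None for i in inputs):
                raise ValueError("every TensorFlow input needs a name")
        if outputs is not None and not all(isinstance(o, str) for o in outputs):
            raise ValueError("outputs must be a list of node-name strings")

    elif exact_source == "pytorch":
        # PyTorch: inputs are required, may be nested, need name and shape.
        if inputs is None:
            raise ValueError("inputs are required for PyTorch models")
        for spec in _flatten(inputs):
            if spec.name is None or spec.shape is None:
                raise ValueError("every PyTorch input needs both name and shape")
        if outputs is not None:
            raise ValueError("outputs must not be specified for PyTorch models")

    return inputs, outputs
```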

coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to='nn_proto', **kwargs)

Convert TensorFlow or PyTorch models to the Core ML model format. Whether a parameter is required may differ between frameworks (see below). Note that this function is aliased as ct.convert in the tutorials.

Parameters
model:

TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:

For TensorFlow versions 1.x:
For TensorFlow versions 2.x:
For PyTorch:
source: str (optional)

One of ['auto', 'tensorflow', 'pytorch', 'mil']. With 'auto', the framework is determined automatically in most cases; a ValueError is raised if the source framework cannot be determined.

inputs: list of `TensorType` or `ImageType`
TensorFlow 1 and 2:
  • inputs are optional. If not provided, the inputs are the Placeholder nodes in the model (if the model is a frozen graph) or the function inputs (if the model is a tf.function).

  • inputs must correspond to all or some of the Placeholder nodes in the TF model.

  • TensorType and ImageType in inputs must have name specified; shape is optional.

  • If inputs is provided, it must be a flat list.

PyTorch:
  • inputs are required.

  • inputs may be a nested list or tuple.

  • TensorType and ImageType in inputs must have both name and shape specified.

outputs: list[str] (optional)
TensorFlow 1 and 2:
  • outputs are optional.

  • If specified, outputs is a list of strings representing node names.

  • If outputs are not specified, the converter infers the outputs as all terminal Identity nodes.

PyTorch:
  • outputs must not be specified.
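To make the TensorFlow defaults concrete, the toy function below picks Placeholder nodes as inputs and Identity nodes that no other node consumes as outputs. The dict-based graph format and the name infer_default_io are hypothetical simplifications; a real converter works on a TensorFlow GraphDef.

```python
def infer_default_io(graph):
    """Return (inputs, outputs) for a toy graph.

    `graph` maps node name -> (op_type, [names of the node's input nodes]).
    This representation is an assumption for illustration only.
    """
    # Every node that feeds some other node is "consumed", so it is not terminal.
    consumed = {src for _, srcs in graph.values() for src in srcs}
    inputs = [name for name, (op, _) in graph.items() if op == "Placeholder"]
    outputs = [
        name
        for name, (op, _) in graph.items()
        if op == "Identity" and name not in consumed
    ]
    return inputs, outputs


# Usage: the intermediate Identity node "input_id" is consumed, so only the
# terminal "Identity" node is reported as an output.
graph = {
    "input": ("Placeholder", []),
    "input_id": ("Identity", ["input"]),
    "w": ("Const", []),
    "matmul": ("MatMul", ["input_id", "w"]),
    "Identity": ("Identity", ["matmul"]),
}
print(infer_default_io(graph))  # (['input'], ['Identity'])
```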

classifier_config: ClassifierConfig class (optional)

The configuration to use if the MLModel is intended to be a classifier.

minimum_deployment_target: coremltools.target enumeration (optional)
  • one of the members of enum coremltools.target.

  • When not specified or None, the converter aims for the lowest possible deployment target.
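One way to read "the lowest possible deployment target" is: take the newest target required by any feature the model uses, and keep it unless the caller asks for something newer. The sketch below encodes that reading with a hypothetical Target enum and select_target function; both are assumptions and are not coremltools.target or the converter's actual policy.

```python
from enum import IntEnum


class Target(IntEnum):
    """Hypothetical ordered deployment targets (a larger value is newer)."""
    IOS13 = 13
    IOS14 = 14


def select_target(required_by_ops, minimum_deployment_target=None):
    """Pick the lowest target supporting every op, honoring a user-supplied floor."""
    # The model needs at least the newest target that any of its ops requires.
    needed = max(required_by_ops, default=Target.IOS13)
    if minimum_deployment_target is None:
        return needed
    if minimum_deployment_target < needed:
        raise ValueError(
            f"model needs {needed.name}, but minimum_deployment_target is "
            f"{minimum_deployment_target.name}"
        )
    return minimum_deployment_target
```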

convert_to: str (optional)
  • Must be one of ['nn_proto', 'mil'].

  • 'nn_proto': Returns an MLModel containing a NeuralNetwork proto.

  • 'mil': Returns a MIL Program object. The MIL program is primarily used for debugging and currently cannot be compiled to an executable.
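The two convert_to values differ in where the pipeline stops: 'mil' returns the intermediate program, while 'nn_proto' continues and lowers it to a NeuralNetwork model. The sketch below illustrates that branching with hypothetical Program and NNModel stand-ins and a hypothetical finish_conversion function; the one-to-one op-to-layer lowering is a simplification for illustration only.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Program:
    """Hypothetical stand-in for a MIL program (just a list of op names)."""
    ops: List[str] = field(default_factory=list)


@dataclass
class NNModel:
    """Hypothetical stand-in for an MLModel wrapping a NeuralNetwork proto."""
    layers: List[str]


def finish_conversion(program, convert_to="nn_proto"):
    """Return the program itself or a lowered model, depending on convert_to."""
    if convert_to == "mil":
        return program  # hand back the intermediate program, e.g. for debugging
    if convert_to == "nn_proto":
        # Simplified lowering: map each MIL op to one NeuralNetwork layer.
        return NNModel(layers=[f"nn_{op}" for op in program.ops])
    raise ValueError(f"convert_to must be one of ['nn_proto', 'mil'], got {convert_to!r}")
```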

Returns
model: coremltools.models.MLModel or coremltools.converters.mil.Program

A Core ML MLModel object or a MIL Program object (see convert_to).

Examples

TensorFlow 1, 2 (model is a frozen graph):

>>> import numpy as np
>>> import tensorflow as tf
>>> import coremltools as ct
>>> with tf.Graph().as_default() as graph:
...     x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2, 3), name="input")
...     y = tf.nn.relu(x, name="output")

# Automatically infer inputs and outputs

>>> mlmodel = ct.convert(graph)
>>> test_input = np.random.rand(1, 2, 3) - 0.5
>>> results = mlmodel.predict({"input": test_input})
>>> print(results['output'])

TensorFlow 2 (model is a path to a tf.keras model saved in HDF5 format):

>>> x = tf.keras.Input(shape=(32,), name='input')
>>> y = tf.keras.layers.Dense(16, activation='softmax')(x)
>>> keras_model = tf.keras.Model(x, y)
>>> h5_path = 'model.h5'  # any writable file path
>>> keras_model.save(h5_path)
>>> mlmodel = ct.convert(h5_path)
>>> test_input = np.random.rand(2, 32)
>>> results = mlmodel.predict({'input': test_input})
>>> print(results['Identity'])

PyTorch:

>>> import torch
>>> import torchvision
>>> model = torchvision.models.mobilenet_v2()
>>> model.eval()
>>> example_input = torch.rand(1, 3, 256, 256)
>>> traced_model = torch.jit.trace(model, example_input)
>>> input_type = ct.TensorType(name='input_name', shape=(1, 3, 256, 256))
>>> mlmodel = ct.convert(traced_model, inputs=[input_type])
>>> results = mlmodel.predict({"input_name": example_input.numpy()})
>>> print(results['1651']) # 1651 is the node name given by PyTorch's JIT

See the neural-network-conversion documentation for more advanced options.