Unified (TensorFlow and PyTorch)

coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to=None, compute_precision=None, skip_model_load=False, compute_units=ComputeUnit.ALL, package_dir=None, debug=False, pass_pipeline: Optional[PassPipeline] = None)

Convert a TensorFlow or PyTorch model to the Core ML model format as either a neural network or an ML program. Some parameters and requirements differ for TensorFlow and PyTorch conversions.

Parameters:
model

TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:

source: str (optional)

One of [auto, tensorflow, pytorch, milinternal]. auto determines the framework automatically for most cases. Raises ValueError if it fails to determine the source framework.

inputs: list of TensorType or ImageType
  • If you specify dtype with TensorType or ImageType, it will be applied to the input of the converted model. For example, the following code snippet will produce a Core ML model with float 16 typed inputs.

    import coremltools as ct
    import numpy as np
    
    mlmodel = ct.convert(
        keras_model,
        inputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.macOS13,
    )
    
  • The following code snippet will produce a Core ML model with the GRAYSCALE_FLOAT16 input image type:

    import coremltools as ct
    
    # H : image height, W: image width
    mlmodel = ct.convert(
        torch_model,
        inputs=[
            ct.ImageType(shape=(1, 1, H, W), color_layout=ct.colorlayout.GRAYSCALE_FLOAT16)
        ],
        minimum_deployment_target=ct.target.macOS13,
    )
    
  • TensorFlow 1 and 2 (including tf.keras):

    • The inputs parameter is optional. If not provided, the inputs are placeholder nodes in the model (if the model is a frozen graph) or function inputs (if the model is a tf.function).

    • If inputs is provided, it must be a flat list.

    • The inputs must correspond to all or some of the placeholder nodes in the TF model.

    • If name is specified with TensorType or ImageType, it must correspond to a placeholder op in the TF graph. The input names in the converted Core ML model can later be modified using the ct.utils.rename_feature API.

    • If dtype is not specified, it defaults to the dtype of the inputs in the TF model.

    • If minimum_deployment_target >= ct.target.macOS13 and compute_precision is float 16, float 32 inputs default to float 16 when inputs is not provided or dtype is not specified.

  • PyTorch:

    • TorchScript Models:
      • The inputs parameter is required.

      • The number of elements in inputs must match the number of inputs of the PyTorch model.

      • inputs may be a nested list or tuple.

      • TensorType and ImageType must have the shape specified.

      • If the name argument is specified with TensorType or ImageType, the converted Core ML model will have inputs with the same name.

      • If dtype is missing:
        • For minimum_deployment_target <= ct.target.macOS12, it defaults to float 32.

        • For minimum_deployment_target >= ct.target.macOS13 with float 16 compute_precision, it defaults to float 16.

    • Torch Exported Models:
      • The inputs parameter is not supported.

      • The inputs parameter is inferred from the Torch ExportedProgram.

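The dtype defaulting rules above can be condensed into a small decision function. This is a minimal sketch, not coremltools code. It rests on three assumptions. First, default_float_dtype is a hypothetical helper. Second, deployment targets are reduced to a macOS major version. Third, float 32 is kept when compute_precision is float 32 on macOS13 and newer, a case the rules above do not state explicitly.

```python
from typing import Optional


def default_float_dtype(
    macos_major: int,
    fp16_compute_precision: bool,
    explicit_dtype: Optional[str] = None,
) -> str:
    """Hypothetical helper: the dtype a float 32 model input or output ends up with."""
    # A dtype passed explicitly via TensorType/ImageType always wins.
    if explicit_dtype is not None:
        return explicit_dtype
    # macOS13+ with float 16 compute precision: float 32 I/O defaults to float 16.
    if macos_major >= 13 and fp16_compute_precision:
        return "float16"
    # macOS12 or older keeps float 32.
    # Assumption: macOS13+ with float 32 compute precision also keeps float 32.
    return "float32"


print(default_float_dtype(12, True))                            # float32
print(default_float_dtype(13, True))                            # float16
print(default_float_dtype(13, True, explicit_dtype="float32"))  # float32
```

Non-float inputs (for example int32) are outside the scope of this sketch. The text above only describes how float 32 inputs and outputs are defaulted.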
outputs: list of TensorType or ImageType (optional)
  • If you specify dtype with TensorType or ImageType, it will be applied to the output of the converted model. For example, to produce float 16 typed inputs and outputs:

    import coremltools as ct
    import numpy as np
    
    mlmodel = ct.convert(
        keras_model,
        inputs=[ct.TensorType(dtype=np.float16)],
        outputs=[ct.TensorType(dtype=np.float16)],
        minimum_deployment_target=ct.target.macOS13,
    )
    
  • To produce image inputs and outputs:

    import coremltools as ct
    
    # H: image height, W: image width
    mlmodel = ct.convert(
        torch_model,
        inputs=[ct.ImageType(shape=(1, 3, H, W), color_layout=ct.colorlayout.RGB)],
        outputs=[ct.ImageType(color_layout=ct.colorlayout.RGB)],
        minimum_deployment_target=ct.target.macOS13,
    )
    
  • TensorFlow 1 and 2 (including tf.keras):

    • If outputs is not specified, the converter infers outputs from the sink nodes in the graph.

    • If specified, the name with TensorType or ImageType must correspond to a node in the TF graph. In this case, the model will be converted up to that node.

    • If minimum_deployment_target >= ct.target.macOS13 and compute_precision is float 16, outputs inferred as float 32 default to float 16 when dtype is not specified.

  • PyTorch: TorchScript Models
    • If specified, the length of the list must match the number of outputs returned by the PyTorch model.

    • If name is specified, it is applied to the output names of the converted Core ML model.

    • If minimum_deployment_target >= ct.target.macOS13 and compute_precision is float 16, outputs inferred as float 32 default to float 16 when dtype is not specified.

  • PyTorch: Torch Exported Models:
    • The outputs parameter is not supported.

    • The outputs parameter is inferred from Torch ExportedProgram.

classifier_config: ClassifierConfig class (optional)

The configuration if the MLModel is intended to be a classifier.

minimum_deployment_target: coremltools.target enumeration (optional)

A member of the coremltools.target enum. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between ML programs and neural networks, see ML Programs.

  • The converter produces a neural network (neuralnetwork) if:

    minimum_deployment_target <= coremltools.target.iOS14/
                                 coremltools.target.macOS11/
                                 coremltools.target.watchOS7/
                                 coremltools.target.tvOS14:
    
  • The converter produces an ML program (mlprogram) if:

    minimum_deployment_target >= coremltools.target.iOS15/
                                  coremltools.target.macOS12/
                                  coremltools.target.watchOS8/
                                  coremltools.target.tvOS15:
    
  • If neither the minimum_deployment_target nor the convert_to parameter is specified, the converter produces an ML program with the lowest possible minimum deployment target.

  • If this parameter and convert_to are both specified, they must be compatible. The following combinations are invalid:

    # Invalid:
    convert_to="mlprogram", minimum_deployment_target=coremltools.target.iOS14
    
    # Invalid:
    convert_to="neuralnetwork", minimum_deployment_target=coremltools.target.iOS15
    
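The rules above that connect convert_to and minimum_deployment_target can be sketched as follows. This is an illustration under stated assumptions, not the converter's implementation:

- resolve_model_type is a hypothetical helper.
- Targets are represented only by their iOS major version. iOS14 stands for macOS11/watchOS7/tvOS14, and iOS15 stands for macOS12/watchOS8/tvOS15.
- The ValueError it raises is a stand-in for whatever error the real converter reports.

```python
from typing import Optional

# Assumption: a deployment target is reduced to its iOS major version.
ML_PROGRAM_MIN_IOS = 15  # iOS15/macOS12/watchOS8/tvOS15 and newer -> ML program


def resolve_model_type(convert_to: Optional[str], min_ios: Optional[int]) -> str:
    """Hypothetical helper applying the documented convert_to/target rules."""
    if convert_to is None:
        if min_ios is None:
            # Neither parameter given: an ML program is produced.
            return "mlprogram"
        # Only the target given: it alone decides the representation.
        return "mlprogram" if min_ios >= ML_PROGRAM_MIN_IOS else "neuralnetwork"

    if min_ios is not None and convert_to in ("mlprogram", "neuralnetwork"):
        target_supports_program = min_ios >= ML_PROGRAM_MIN_IOS
        if (convert_to == "mlprogram") != target_supports_program:
            raise ValueError(
                f"convert_to={convert_to!r} is incompatible with a minimum target of iOS{min_ios}"
            )
    # Only convert_to given, or a compatible pair: convert_to decides.
    return convert_to


print(resolve_model_type(None, None))           # mlprogram
print(resolve_model_type(None, 14))             # neuralnetwork
print(resolve_model_type("neuralnetwork", 13))  # neuralnetwork
```

With this sketch, resolve_model_type("mlprogram", 14) and resolve_model_type("neuralnetwork", 15) both raise ValueError. These calls mirror the two invalid combinations listed above.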
convert_to: str (optional)

Must be one of ['mlprogram', 'neuralnetwork', 'milinternal']. The value of this parameter determines the type of the model representation produced by the converter. To learn about the differences between ML programs and neural networks, see ML Programs.

  • 'mlprogram': Returns an MLModel (coremltools.models.MLModel) containing a MILSpec.Program proto, which is the Core ML program format. The model saved from this returned object is executable on iOS15/macOS12/watchOS8/tvOS15 and newer.

  • 'neuralnetwork': Returns an MLModel (coremltools.models.MLModel) containing a NeuralNetwork proto, which is the original Core ML format. The model saved from this returned object is executable either on iOS13/macOS10.15/watchOS6/tvOS13 and newer, or on iOS14/macOS11/watchOS7/tvOS14 and newer, depending on the layers used in the model.

  • 'milinternal': Returns an MIL program object (coremltools.converters.mil.Program). An MIL program is primarily used for debugging and inspection. It can be converted to an MLModel for execution by using one of the following:

    ct.convert(mil_program, convert_to="neuralnetwork")
    ct.convert(mil_program, convert_to="mlprogram")
    
  • If neither the minimum_deployment_target nor the convert_to parameter is specified, the converter produces an ML program with the lowest possible minimum deployment target.

compute_precision: coremltools.precision enumeration or ct.transform.FP16ComputePrecision() (optional)

Use this argument to control the storage precision of the tensors in the ML program. Must be one of the following.

  • coremltools.precision.FLOAT16 enum: The following transform is applied to produce a float 16 program; that is, a program in which all the intermediate float tensors are of type float 16 (for ops that support that type).

    coremltools.transform.FP16ComputePrecision(op_selector=lambda op: True)
    

    The above transform iterates through all the ops, looking at each op’s inputs and outputs. If they are of type float 32, cast ops are injected to convert those tensors (also known as vars) to type float 16. Similarly, int32 vars will also be cast to int16.

  • coremltools.precision.FLOAT32 enum: No transform is applied.

    The original float32 tensor dtype in the source model is preserved. Opt into this option if the default converted model is displaying numerical precision issues.

  • coremltools.transform.FP16ComputePrecision(op_selector=...)

    Use this option to control which tensors are cast to float 16. Before casting the inputs/outputs of any op from float32 to float 16, the op_selector function is invoked on the op object. This function must return a boolean value. By default it returns True for every op, but you can customize this.

    For example:

    coremltools.transform.FP16ComputePrecision(op_selector=lambda op: op.op_type != "linear")
    

    The above casts all the float32 tensors to be float 16, except the input/output tensors to any linear op. See more examples below.

  • None: The default
    • When convert_to="mlprogram", the compute_precision parameter defaults to coremltools.precision.FLOAT16.

    • When convert_to="neuralnetwork", the compute_precision parameter needs to be None and has no meaning.

    • For example, you can customize the float 16 precision transform to prevent casting all the real_div ops in the program to float 16 precision:

      import coremltools as ct
      
      def skip_real_div_ops(op):
          if op.op_type == "real_div":
              return False
          return True
      
      model = ct.convert(
          source_model,
          compute_precision=ct.transform.FP16ComputePrecision(op_selector=skip_real_div_ops),
          minimum_deployment_target=ct.target.iOS15,
      )
      
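This stdlib-only sketch shows what a float 32 to float 16 cast can do to values. It round-trips Python floats through IEEE 754 half precision using the struct module's 'e' format. The effect it shows is why coremltools.precision.FLOAT32 may help when a converted model shows precision issues. The sketch illustrates float16 rounding in general. It does not call coremltools or model how Core ML executes float 16 ops.

```python
import struct


def to_float16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (struct 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


print(to_float16(0.1))         # 0.0999755859375
print(to_float16(2049.0))      # 2048.0: not every integer above 2048 fits in float16
print(to_float16(1.0 + 1e-4))  # 1.0: float16 spacing just above 1.0 is 2**-10

try:
    to_float16(70000.0)  # the largest finite float16 value is 65504
except OverflowError as exc:
    print("overflow:", exc)
```

Values that fall between representable float16 numbers are rounded. Magnitudes above 65504 cannot be represented at all. Ops whose tensors may hit either case could be candidates for exclusion with op_selector, or for keeping the whole model in FLOAT32.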
skip_model_load: bool

Set to True to prevent coremltools from calling into the Core ML framework to compile and load the model after conversion. In that case, the returned model object cannot be used to make predictions, but it can still be saved with model.save(). Use this flag when converting to a newer model type on an older Mac; without it, such a conversion may raise a runtime warning.

Example: Use this flag to suppress a runtime warning when converting to an ML program model on macOS 11, since an ML program can only be compiled and loaded on macOS 12 and newer.

Defaults to False.
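Whether skip_model_load is worth setting depends on the host macOS version and the model type. The following sketch is a hypothetical pre-flight check, not part of coremltools, and rests on these assumptions:

- suggest_skip_model_load and MIN_LOAD_MACOS are illustrative names.
- The macOS 12 floor for ML programs comes from the example above.
- The macOS 10.15 floor for neural networks is an assumption, based on the oldest deployment target listed under convert_to.

```python
import platform
from typing import Optional, Tuple

# Oldest macOS on which each model type can be compiled and loaded.
# mlprogram: macOS 12 (from the example above).
# neuralnetwork: assumed macOS 10.15, the oldest target listed under convert_to.
MIN_LOAD_MACOS = {
    "mlprogram": (12, 0),
    "neuralnetwork": (10, 15),
}


def parse_macos_version(version: str) -> Optional[Tuple[int, int]]:
    """Turn '11.7.1' into (11, 7); return None for an empty string."""
    if not version:
        return None
    parts = [int(p) for p in version.split(".")[:2]]
    if len(parts) == 1:
        parts.append(0)
    return parts[0], parts[1]


def suggest_skip_model_load(model_type: str, host_version: str) -> bool:
    """Hypothetical helper: True if the converted model is not expected to load here."""
    host = parse_macos_version(host_version)
    if host is None:
        # Not macOS (platform.mac_ver() reports ''): Core ML is unavailable.
        return True
    return host < MIN_LOAD_MACOS[model_type]


# platform.mac_ver()[0] is the macOS release string, or '' on other systems.
print(suggest_skip_model_load("mlprogram", platform.mac_ver()[0]))
```

On a macOS 11 host, for example, suggest_skip_model_load("mlprogram", "11.7") returns True. This matches the ML program example above.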

compute_units: coremltools.ComputeUnit

The set of processing units the model can use to make predictions. After conversion, the model is loaded with the provided set of compute units and returned.

An enum with the following possible values:

  • coremltools.ComputeUnit.ALL: Use all compute units available, including the neural engine.

  • coremltools.ComputeUnit.CPU_ONLY: Limit the model to only use the CPU.

  • coremltools.ComputeUnit.CPU_AND_GPU: Use both the CPU and GPU, but not the neural engine.

  • coremltools.ComputeUnit.CPU_AND_NE: Use both the CPU and neural engine, but not the GPU. Available only for macOS >= 13.0.

package_dir: str

Post conversion, the model is saved at a temporary location and loaded to form the MLModel object ready for prediction.

  • If package_dir is provided, the model is saved at this location instead of in a temporary directory.

  • If not None, this must be a path to a directory with the extension .mlpackage.
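package_dir must name a directory with the .mlpackage extension, so a script can check the value before starting a potentially long conversion. check_package_dir below is a hypothetical helper written for illustration. It only applies the extension rule stated above and does not create or inspect the directory.

```python
from pathlib import Path
from typing import Optional, Union


def check_package_dir(package_dir: Optional[Union[str, Path]]) -> Optional[Path]:
    """Hypothetical pre-flight check for the package_dir rules above."""
    if package_dir is None:
        # None: the converter saves to a temporary directory instead.
        return None
    path = Path(package_dir)
    if path.suffix != ".mlpackage":
        raise ValueError(f"package_dir must end with '.mlpackage', got {str(path)!r}")
    return path


print(check_package_dir("exports/MyModel.mlpackage"))
```

A validated path can then be passed to the converter as package_dir=str(path).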

debug: bool

This flag should generally be False except for debugging purposes. Setting this flag to True produces the following behavior:

  • For Torch conversion, it will print the list of supported and unsupported ops found in the model if conversion fails due to an unsupported op.

  • For TensorFlow conversion, it displays extra logging and visualizations.

pass_pipeline: PassPipeline

Manage graph passes. You can control which graph passes to run and the order of the graph passes. You can also specify options for each pass. See the details in the docstring of PassPipeline (coremltools/converters/mil/mil/passes/pass_pipeline.py).

  • To avoid fusing the conv and batchnorm ops, skip the corresponding pass as shown in the following example:

    pipeline = ct.PassPipeline()
    pipeline.remove_passes({"common::fuse_conv_batchnorm"})
    mlmodel = ct.convert(model, pass_pipeline=pipeline)
    
  • To avoid folding too-large const ops that lead to a large model, set pass option as shown in the following example:

    pipeline = ct.PassPipeline()
    pipeline.set_options("common::const_elimination", {"skip_const_by_size": "1e6"})
    mlmodel = ct.convert(model, pass_pipeline=pipeline)
    

We also provide a set of predefined pass pipelines that you can pass directly.

  • To skip all graph passes, use:

    mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.EMPTY)
    
  • To run only the cleanup graph passes, such as constant_elimination and dead_code_elimination, use:

    mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.CLEANUP)
    
  • To convert a source model with sparse weights to a sparse format Core ML model, you can use:

    mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING)
    
  • To convert a source model with palettized weights to a compressed format Core ML model, you can use:

    mlmodel = ct.convert(model, pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION)
    
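Conceptually, you can think of a pass pipeline as an ordered list of pass names plus per-pass options. The remove_passes and set_options calls in the examples above adjust those two things. The ToyPassPipeline class below is a hypothetical, self-contained model of that idea, for illustration only. It is not coremltools' PassPipeline, and its three pass names are an illustrative subset rather than the real default pipeline.

```python
from typing import Any, Dict, List, Set


class ToyPassPipeline:
    """Toy model of an ordered graph-pass pipeline; NOT coremltools' PassPipeline."""

    def __init__(self, passes: List[str]):
        self.passes = list(passes)  # passes would run in this order
        self.options: Dict[str, Dict[str, Any]] = {}

    def remove_passes(self, names: Set[str]) -> None:
        # Drop the named passes while keeping the relative order of the rest.
        self.passes = [p for p in self.passes if p not in names]

    def set_options(self, pass_name: str, options: Dict[str, Any]) -> None:
        # Store per-pass options, keyed by pass name.
        self.options[pass_name] = dict(options)


pipeline = ToyPassPipeline([
    "common::const_elimination",
    "common::fuse_conv_batchnorm",
    "common::dead_code_elimination",
])
pipeline.remove_passes({"common::fuse_conv_batchnorm"})
pipeline.set_options("common::const_elimination", {"skip_const_by_size": "1e6"})
print(pipeline.passes)  # ['common::const_elimination', 'common::dead_code_elimination']
```

Filtering rather than rebuilding the list is what keeps the remaining passes in their original order. That order matters because later passes typically operate on the graph produced by earlier ones.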
Returns:
model: coremltools.models.MLModel or coremltools.converters.mil.Program

A Core ML MLModel object or MIL program object (see convert_to).

Examples

TensorFlow 1 and 2 (model is a frozen graph):

>>> import numpy as np
>>> import tensorflow as tf
>>> import coremltools as ct
>>> with tf.Graph().as_default() as graph:
...     x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2, 3), name="input")
...     y = tf.nn.relu(x, name="output")

Automatically infer inputs and outputs:

>>> mlmodel = ct.convert(graph)
>>> test_input = np.random.rand(1, 2, 3) - 0.5
>>> results = mlmodel.predict({"input": test_input})
>>> print(results['output'])

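TensorFlow 2 (model is a tf.keras model path):

>>> import numpy as np
>>> import tensorflow as tf
>>> import coremltools as ct
>>> x = tf.keras.Input(shape=(32,), name='input')
>>> y = tf.keras.layers.Dense(16, activation='softmax')(x)
>>> keras_model = tf.keras.Model(x, y)
>>> h5_path = 'keras_model.h5'  # example file path
>>> keras_model.save(h5_path)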
>>> mlmodel = ct.convert(h5_path)
>>> test_input = np.random.rand(2, 32)
>>> results = mlmodel.predict({'input': test_input})
>>> print(results['Identity'])

PyTorch:

TorchScript Models:

>>> import torch
>>> import torchvision
>>> import coremltools as ct
>>> model = torchvision.models.mobilenet_v2()
>>> model.eval()
>>> example_input = torch.rand(1, 3, 256, 256)
>>> traced_model = torch.jit.trace(model, example_input)
>>> input_type = ct.TensorType(name="input_name", shape=(1, 3, 256, 256))
>>> mlmodel = ct.convert(traced_model, inputs=[input_type])
>>> results = mlmodel.predict({"input_name": example_input.numpy()})
>>> print(results["1651"])  # output name assigned by PyTorch's JIT; it may differ for other models or versions

See Conversion Options for more advanced options.