Unified (TensorFlow and PyTorch)
- coremltools.converters._converters_entry._determine_source(model, source, outputs)
  Resolve `source` (which can be `auto`) to the precise source framework.
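The reference does not spell out how `auto` is resolved. As a rough illustration only, the pure-Python sketch below shows one way such inference could work, using the model formats listed under `convert`: string paths are classified by extension (or treated as a SavedModel directory), and other objects by the top-level package of their class. The function name `guess_source` and every rule in it are assumptions, not coremltools' implementation. Like the documented behavior, it raises `ValueError` when it cannot decide, and it does not attempt to detect MIL programs.

```python
import os


def guess_source(model, source="auto"):
    """Hypothetical sketch of resolving source="auto"; NOT coremltools' code."""
    valid = {"auto", "tensorflow", "pytorch", "mil"}
    if source not in valid:
        raise ValueError("source must be one of {}".format(sorted(valid)))
    if source != "auto":
        return source  # an explicit choice is passed through unchanged

    if isinstance(model, str):
        ext = os.path.splitext(model)[1].lower()
        if ext == ".pt":
            return "pytorch"  # TorchScript file
        if ext in (".pb", ".h5"):
            return "tensorflow"  # frozen graph or HDF5 file
        if os.path.isdir(model):
            return "tensorflow"  # assumption: a directory is a SavedModel
    else:
        # Inspect the class's top-level package so that neither TensorFlow
        # nor PyTorch has to be imported just to test the type.
        package = type(model).__module__.split(".")[0]
        if package == "torch":
            return "pytorch"
        if package in ("tensorflow", "keras"):
            return "tensorflow"

    raise ValueError("Unable to determine the source framework for {!r}".format(model))
```

For example, `guess_source("model.pt")` returns `"pytorch"`, while an unrecognized object such as `42` raises `ValueError`.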
- coremltools.converters._converters_entry._validate_inputs(model, exact_source, inputs, outputs, classifier_config, **kwargs)
  Validate and process `model`, `inputs`, `outputs`, and `classifier_config` based on `exact_source` (which cannot be `auto`).
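The per-framework rules that this validation enforces are listed under `convert` (`inputs` and `outputs`). The sketch below restates those rules as pure-Python checks so they can be read in one place. `InputSpec` is a hypothetical stand-in for `TensorType`/`ImageType`, `check_inputs_outputs` is a hypothetical helper, and raising `ValueError` is an assumption; this is not the library's implementation.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class InputSpec:
    """Hypothetical stand-in for ct.TensorType / ct.ImageType."""
    name: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None


def _flatten(items):
    """Yield leaf entries from an arbitrarily nested list/tuple."""
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from _flatten(item)
        else:
            yield item


def check_inputs_outputs(exact_source, inputs=None, outputs=None):
    """Sketch of the documented rules; raises ValueError on a violation."""
    if exact_source == "auto":
        raise ValueError("exact_source cannot be 'auto'")

    if exact_source == "tensorflow":
        if inputs is not None:
            # TF inputs are optional, but if given they must be a flat list.
            if not isinstance(inputs, list) or any(
                isinstance(i, (list, tuple)) for i in inputs
            ):
                raise ValueError("TensorFlow inputs must be a flat list")
            for i in inputs:
                if i.name is None:  # shape is optional for TF
                    raise ValueError("TensorFlow inputs must have a name")
        if outputs is not None and (
            not isinstance(outputs, list)
            or not all(isinstance(o, str) for o in outputs)
        ):
            raise ValueError("outputs must be a list of node-name strings")

    elif exact_source == "pytorch":
        if inputs is None:
            raise ValueError("PyTorch conversion requires inputs")
        # PyTorch inputs may be nested; every leaf needs a name and a shape.
        for i in _flatten(inputs):
            if i.name is None or i.shape is None:
                raise ValueError("PyTorch inputs must have name and shape")
        if outputs is not None:
            raise ValueError("outputs must not be specified for PyTorch")
```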
- coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to='nn_proto', **kwargs)
  Convert TensorFlow or PyTorch models to the Core ML model format. Whether a parameter is required may differ between frameworks (see below). Note that this function is aliased as `ct.convert` in the tutorials.
- Parameters
- model:
    TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:
  - For TensorFlow versions 1.x:
    - Frozen `tf.Graph`
    - Frozen graph (`.pb`) file path
    - HDF5 file path (`.h5`)
    - SavedModel directory path
  - For TensorFlow versions 2.x:
    - HDF5 file path (`.h5`)
    - SavedModel directory path
  - For PyTorch:
    - A TorchScript object
    - Path to a `.pt` file
- source: str (optional)
    One of [`auto`, `tensorflow`, `pytorch`, `mil`]. `auto` determines the framework automatically in most cases. A `ValueError` is raised if the source framework cannot be determined.
- inputs: list of `TensorType` or `ImageType`
  - TensorFlow 1 and 2:
    - `inputs` are optional. If not provided, the inputs are the Placeholder nodes in the model (if the model is a frozen graph) or the function inputs (if the model is a tf function).
    - `inputs` must correspond to all or some of the Placeholder nodes in the TF model.
    - `TensorType` and `ImageType` in `inputs` must have `name` specified; `shape` is optional.
    - If `inputs` is provided, it must be a flat list.
  - PyTorch:
    - `inputs` are required.
    - `inputs` may be a nested list or tuple.
    - `TensorType` and `ImageType` in `inputs` must have both `name` and `shape` specified.
- outputs: list[str] (optional)
  - TensorFlow 1 and 2:
    - `outputs` are optional.
    - If specified, `outputs` is a list of strings representing node names.
    - If `outputs` are not specified, the converter infers the outputs as all terminal identity nodes.
  - PyTorch:
    - `outputs` must not be specified.
- classifier_config: ClassifierConfig class (optional)
    The configuration to use if the MLModel is intended to be a classifier.
- minimum_deployment_target: coremltools.target enumeration (optional)
    One of the members of the enum `coremltools.target`. When not specified or None, the converter targets the lowest minimum deployment target possible.
- convert_to: str (optional)
    Must be one of ['nn_proto', 'mil'].
    - 'nn_proto': Returns an MLModel containing a NeuralNetwork proto.
    - 'mil': Returns a MIL program object. A MIL program is primarily used for debugging purposes and currently cannot be compiled to an executable.
- Returns
- model: coremltools.models.MLModel or coremltools.converters.mil.Program
    A Core ML MLModel object or a MIL Program object (see `convert_to`).
Examples
TensorFlow 1, 2 (`model` is a frozen graph):
>>> import numpy as np
>>> import tensorflow as tf
>>> import coremltools as ct
>>> with tf.Graph().as_default() as graph:
...     # tf.compat.v1.placeholder is available in both TF 1.x and 2.x
...     x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2, 3), name="input")
...     y = tf.nn.relu(x, name="output")
# Automatically infer inputs and outputs
>>> mlmodel = ct.convert(graph)
>>> test_input = np.random.rand(1, 2, 3) - 0.5
>>> results = mlmodel.predict({"input": test_input})
>>> print(results['output'])
TensorFlow 2 (`model` is a tf.keras model path):
>>> x = tf.keras.Input(shape=(32,), name='input')
>>> y = tf.keras.layers.Dense(16, activation='softmax')(x)
>>> keras_model = tf.keras.Model(x, y)
>>> h5_path = 'keras_model.h5'  # example path; any writable .h5 path works
>>> keras_model.save(h5_path)
>>> mlmodel = ct.convert(h5_path)
>>> test_input = np.random.rand(2, 32)
>>> results = mlmodel.predict({'input': test_input})
>>> print(results['Identity'])
PyTorch:
>>> import torch
>>> import torchvision
>>> model = torchvision.models.mobilenet_v2()
>>> model.eval()
>>> example_input = torch.rand(1, 3, 256, 256)
>>> traced_model = torch.jit.trace(model, example_input)
>>> input = ct.TensorType(name='input_name', shape=(1, 3, 256, 256))
>>> mlmodel = ct.convert(traced_model, inputs=[input])
# The key passed to predict must match the TensorType name
>>> results = mlmodel.predict({"input_name": example_input.numpy()})
# The output name (e.g. '1651') is assigned by PyTorch's JIT; read it from the spec
>>> output_name = mlmodel.get_spec().description.output[0].name
>>> print(results[output_name])
See neural-network-conversion for more advanced options.