Unified (TensorFlow and PyTorch)
-
coremltools.converters._converters_entry._determine_source(model, source, outputs)
Resolve `source` (which can be `auto`) to the precise source framework.
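The following is a minimal sketch of how `auto` might be resolved. It is not coremltools' implementation. `determine_source_sketch` is a hypothetical name, and the extension and package-name heuristics are assumptions chosen to match the model formats listed under `convert` on this page. Following the `source` parameter description, it raises ValueError when it cannot decide.

```python
import os


def determine_source_sketch(model, source="auto"):
    """Hypothetical resolver for source="auto"; NOT coremltools' own logic."""
    valid = ("auto", "tensorflow", "pytorch", "mil")
    if source not in valid:
        raise ValueError(f"source must be one of {valid}, got {source!r}")
    if source != "auto":
        return source  # an explicit framework is accepted as-is

    if isinstance(model, str):
        # Paths: guess from the file extension or directory layout (assumed heuristic).
        ext = os.path.splitext(model)[1].lower()
        if ext == ".pt":
            return "pytorch"  # serialized TorchScript file
        if ext in (".pb", ".h5") or os.path.isdir(model):
            return "tensorflow"  # frozen graph, HDF5 file, or SavedModel directory
    else:
        # In-memory objects: guess from the package that defines their type.
        package = type(model).__module__.split(".")[0]
        if package == "torch":
            return "pytorch"
        if package in ("tensorflow", "keras"):
            return "tensorflow"

    raise ValueError("Unable to determine the source framework; pass `source` explicitly.")
```

For example, `determine_source_sketch("mobilenet.pt")` returns `"pytorch"`, while an unrecognized object or path raises ValueError rather than guessing.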
-
coremltools.converters._converters_entry._validate_inputs(model, exact_source, inputs, outputs, classifier_config, **kwargs)
Validate and process `model`, `inputs`, `outputs`, and `classifier_config` based on `exact_source` (which cannot be `auto`).
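The per-framework rules documented for the `inputs` and `outputs` parameters of `convert` can be summarized as a small validator. This is a sketch only: `InputSpec` and `validate_inputs_sketch` are hypothetical names standing in for `ct.TensorType`/`ct.ImageType` and the private function. The real function also processes `model` and `classifier_config`, which this sketch ignores.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class InputSpec:
    """Hypothetical stand-in for ct.TensorType / ct.ImageType (name and shape only)."""
    name: Optional[str] = None
    shape: Optional[Tuple[int, ...]] = None


def _leaves(items):
    """Yield the InputSpec leaves of a possibly nested list/tuple."""
    for item in items:
        if isinstance(item, (list, tuple)):
            yield from _leaves(item)
        else:
            yield item


def validate_inputs_sketch(exact_source, inputs=None, outputs=None):
    """Check the documented per-framework rules; raise ValueError on a violation."""
    if exact_source == "auto":
        raise ValueError("exact_source must be resolved before validation")

    if exact_source == "tensorflow":
        if inputs is not None:
            if not isinstance(inputs, list) or any(isinstance(i, (list, tuple)) for i in inputs):
                raise ValueError("TensorFlow inputs must be a flat list")
            if any(spec.name is None for spec in inputs):
                raise ValueError("TensorFlow inputs must have a name")  # shape is optional
        if outputs is not None and not all(isinstance(o, str) for o in outputs):
            raise ValueError("outputs must be a list of node-name strings")

    elif exact_source == "pytorch":
        if inputs is None:
            raise ValueError("inputs are required for PyTorch")
        if outputs is not None:
            raise ValueError("outputs must not be specified for PyTorch")
        for spec in _leaves(inputs):  # nested lists/tuples are allowed
            if spec.name is None or spec.shape is None:
                raise ValueError("PyTorch inputs need both name and shape")
```

The asymmetry mirrors the documentation: TensorFlow can fall back on Placeholder nodes, so its `inputs` are optional, whereas PyTorch `inputs` must be given in full.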
-
coremltools.converters._converters_entry.convert(model, source='auto', inputs=None, outputs=None, classifier_config=None, minimum_deployment_target=None, convert_to='nn_proto', **kwargs)
Convert TensorFlow or PyTorch models to the Core ML model format. Whether a parameter is required may differ between frameworks (see below). Note that this function is aliased as `ct.convert` in the tutorials.
- Parameters
- model:
TensorFlow 1, TensorFlow 2, or PyTorch model in one of the following formats:
- For TensorFlow versions 1.x:
  - Frozen `tf.Graph`
  - Frozen graph (`.pb`) file path
  - HDF5 file path (`.h5`)
  - SavedModel directory path
- For TensorFlow versions 2.x:
  - HDF5 file path (`.h5`)
  - SavedModel directory path
- For PyTorch:
  - A TorchScript object
  - Path to a `.pt` file
- source: str (optional)
One of [`auto`, `tensorflow`, `pytorch`, `mil`]. `auto` determines the framework automatically in most cases and raises ValueError if it fails to determine the source framework.
- inputs: list of `TensorType` or `ImageType`
- TensorFlow 1 and 2:
  - `inputs` are optional. If not provided, the inputs are the Placeholder nodes in the model (if the model is a frozen graph) or the function inputs (if the model is a tf function).
  - `inputs` must correspond to all or some of the Placeholder nodes in the TF model.
  - `TensorType` and `ImageType` in `inputs` must have `name` specified; `shape` is optional.
  - If `inputs` is provided, it must be a flat list.
- PyTorch:
  - `inputs` are required.
  - `inputs` may be a nested list or tuple.
  - `TensorType` and `ImageType` in `inputs` must have both `name` and `shape` specified.
- outputs: list[str] (optional)
- TensorFlow 1 and 2:
  - `outputs` are optional.
  - If specified, `outputs` is a list of strings representing node names.
  - If `outputs` are not specified, the converter infers the outputs as all terminal identity nodes.
- PyTorch:
  - `outputs` must not be specified.
- classifier_config: ClassifierConfig class (optional)
The configuration to use if the MLModel is intended to be a classifier.
- minimum_deployment_target: coremltools.target enumeration (optional)
One of the members of the `coremltools.target` enum. When not specified or None, the converter aims for the lowest possible deployment target.
- convert_to: str (optional)
Must be one of [`'nn_proto'`, `'mil'`].
  - `'nn_proto'`: returns an MLModel containing a NeuralNetwork proto.
  - `'mil'`: returns a MIL Program object. The MIL program is primarily used for debugging and currently cannot be compiled to an executable.
- Returns
- model: coremltools.models.MLModel or coremltools.converters.mil.Program
A Core ML MLModel object or a MIL Program object (see `convert_to`).
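The documented `source` and `convert_to` choices, and the return type that each `convert_to` value implies, fit in a small lookup. The helper `expected_return_type` is hypothetical. It only returns the type names as strings and does not call coremltools, so it is a summary of this page rather than a description of the converter's internals.

```python
VALID_SOURCES = ("auto", "tensorflow", "pytorch", "mil")

# Documented mapping from `convert_to` to the type of object returned.
RETURN_TYPES = {
    "nn_proto": "coremltools.models.MLModel",     # contains a NeuralNetwork proto
    "mil": "coremltools.converters.mil.Program",  # for debugging; not compilable
}


def expected_return_type(source="auto", convert_to="nn_proto"):
    """Hypothetical argument check: name the type convert() should return."""
    if source not in VALID_SOURCES:
        raise ValueError(f"source must be one of {VALID_SOURCES}, got {source!r}")
    if convert_to not in RETURN_TYPES:
        raise ValueError(
            f"convert_to must be one of {tuple(RETURN_TYPES)}, got {convert_to!r}"
        )
    return RETURN_TYPES[convert_to]
```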
Examples
TensorFlow 1, 2 (`model` is a frozen graph):
>>> import numpy as np
>>> import tensorflow as tf
>>> import coremltools as ct
>>> with tf.Graph().as_default() as graph:
...     x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2, 3), name="input")
...     y = tf.nn.relu(x, name="output")
>>> # Automatically infer inputs and outputs
>>> mlmodel = ct.convert(graph)
>>> test_input = np.random.rand(1, 2, 3) - 0.5
>>> results = mlmodel.predict({"input": test_input})
>>> print(results['output'])
TensorFlow 2 (`model` is a tf.keras model path):
>>> x = tf.keras.Input(shape=(32,), name='input')
>>> y = tf.keras.layers.Dense(16, activation='softmax')(x)
>>> keras_model = tf.keras.Model(x, y)
>>> h5_path = 'keras_model.h5'  # example path; any writable location works
>>> keras_model.save(h5_path)
>>> mlmodel = ct.convert(h5_path)
>>> test_input = np.random.rand(2, 32)
>>> results = mlmodel.predict({'input': test_input})
>>> print(results['Identity'])
PyTorch:
>>> import torch
>>> import torchvision
>>> import coremltools as ct
>>> model = torchvision.models.mobilenet_v2()
>>> model = model.eval()  # switch to inference mode before tracing
>>> example_input = torch.rand(1, 3, 256, 256)
>>> traced_model = torch.jit.trace(model, example_input)
>>> input_type = ct.TensorType(name='input', shape=(1, 3, 256, 256))
>>> mlmodel = ct.convert(traced_model, inputs=[input_type])
>>> results = mlmodel.predict({"input": example_input.numpy()})
>>> print(results['1651'])  # 1651 is the node name given by PyTorch's JIT
See neural-network-conversion for more advanced options.