# Source code for coremltools.models.array_feature_extractor
# Copyright (c) 2017, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
from .. import SPECIFICATION_VERSION
from ..proto import Model_pb2 as _Model_pb2
from . import datatypes
from ._interface_management import set_transform_interface_params
# [docs]
def create_array_feature_extractor(
    input_features, output_name, extract_indices, output_type=None
):
    """
    Create a feature extractor that selects elements from an input array.

    Parameters
    ----------
    input_features: list
        A list containing exactly one ``(name, datatype)`` tuple, whose
        datatype must be a ``datatypes.Array``.

    output_name:
        Name of the single output feature.

    extract_indices: int or list of int
        Index (or indices) of the input array elements to extract. Every
        index must satisfy ``0 <= index < num_elements`` of the input array.
        If it's an integer, the output type defaults to ``datatypes.Double()``.
        If it's a list or tuple, the output type defaults to a
        ``datatypes.Array`` with one element per index.

    output_type: optional
        Datatype of the output feature. When ``None``, the default described
        above is used; pass another datatype (e.g. an integer type) to
        override it.

    Returns
    -------
    Model_pb2.Model
        A model specification whose ``arrayFeatureExtractor.extractIndex``
        holds the requested indices.

    Raises
    ------
    ValueError
        If ``input_features`` does not contain exactly one entry, if
        ``extract_indices`` is an empty list, or if an index is out of range.
    TypeError
        If the input feature is not a ``datatypes.Array``, or if
        ``extract_indices`` is not an integer or a list of integers.
    """
    # Validate the input description explicitly. ``assert`` statements are
    # removed under ``python -O``, so they must not guard user input.
    if len(input_features) != 1:
        raise ValueError(
            "input_features must contain exactly one (name, datatype) tuple; "
            "got %d." % len(input_features)
        )
    input_type = input_features[0][1]
    if not isinstance(input_type, datatypes.Array):
        raise TypeError("The input feature must be of type datatypes.Array.")

    # Normalize extract_indices to a list and choose the default output type.
    if isinstance(extract_indices, int):
        extract_indices = [extract_indices]
        if output_type is None:
            output_type = datatypes.Double()
    elif isinstance(extract_indices, (list, tuple)):
        if not all(isinstance(x, int) for x in extract_indices):
            raise TypeError("extract_indices must be an integer or a list of integers.")
        if not extract_indices:
            # An extractor that selects nothing would yield a zero-length output.
            raise ValueError("extract_indices must not be empty.")
        if output_type is None:
            output_type = datatypes.Array(len(extract_indices))
    else:
        raise TypeError("extract_indices must be an integer or a list of integers.")

    # Check both bounds up front. The proto field cannot represent negative
    # values, and an index past the end would address a nonexistent element.
    num_elements = input_type.num_elements
    for idx in extract_indices:
        if not 0 <= idx < num_elements:
            raise ValueError(
                "extract_indices contains %d, which is out of range for an "
                "input array with %d elements." % (idx, num_elements)
            )

    # Build the model specification.
    spec = _Model_pb2.Model()
    spec.specificationVersion = SPECIFICATION_VERSION
    spec.arrayFeatureExtractor.extractIndex.extend(extract_indices)

    output_features = [(output_name, output_type)]
    set_transform_interface_params(spec, input_features, output_features)
    return spec