# Source code for coremltools.converters.mil.mil.ops.defs.iOS17.recurrent

#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil.ops.defs.iOS15.recurrent import gru as _gru_iOS15
from coremltools.converters.mil.mil.ops.defs.iOS15.recurrent import lstm as _lstm_iOS15
from coremltools.converters.mil.mil.ops.defs.iOS15.recurrent import rnn as _rnn_iOS15
from coremltools.converters.mil.mil.ops.defs.iOS17 import _IOS17_TARGET


@register_op(opset_version=_IOS17_TARGET)
class gru(_gru_iOS15):
    """
    Gated Recurrent Unit (GRU)

    The only difference between this version and the iOS 15
    :py:class:`~.iOS15.recurrent.gru` is adding the support for fp16.
    """

    # Tensor inputs all share the "T" type domain; the remaining inputs are
    # optional constant string/bool options.
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        initial_h=TensorInputType(type_domain="T"),
        weight_ih=TensorInputType(const=True, type_domain="T"),
        weight_hh=TensorInputType(const=True, type_domain="T"),
        bias=TensorInputType(const=True, optional=True, type_domain="T"),
        direction=TensorInputType(const=True, optional=True, type_domain=types.str),
        output_sequence=TensorInputType(const=True, optional=True, type_domain=types.bool),
        recurrent_activation=TensorInputType(const=True, optional=True, type_domain=types.str),
        activation=TensorInputType(const=True, optional=True, type_domain=types.str),
    )

    # Types listed for the "T" type domain: fp16 and fp32.
    type_domains = {
        "T": (types.fp16, types.fp32),
    }
@register_op(opset_version=_IOS17_TARGET)
class lstm(_lstm_iOS15):
    """
    Long Short-Term Memory (LSTM)

    The only difference between this version and the iOS 15
    :py:class:`~.iOS15.recurrent.lstm` is adding the support for fp16.
    """

    # Tensor inputs (including the optional ``*_back`` weights and ``clip``)
    # all share the "T" type domain; the remaining inputs are optional
    # constant string/bool options.
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        initial_h=TensorInputType(type_domain="T"),
        initial_c=TensorInputType(type_domain="T"),
        weight_ih=TensorInputType(const=True, type_domain="T"),  # ifoz layout
        weight_hh=TensorInputType(const=True, type_domain="T"),  # ifoz layout
        bias=TensorInputType(const=True, optional=True, type_domain="T"),  # ifoz layout
        peephole=TensorInputType(const=True, optional=True, type_domain="T"),  # ifo layout
        weight_ih_back=TensorInputType(const=True, optional=True, type_domain="T"),  # ifoz layout
        weight_hh_back=TensorInputType(const=True, optional=True, type_domain="T"),  # ifoz layout
        bias_back=TensorInputType(const=True, optional=True, type_domain="T"),  # ifoz layout
        peephole_back=TensorInputType(const=True, optional=True, type_domain="T"),  # ifo layout
        direction=TensorInputType(const=True, optional=True, type_domain=types.str),
        output_sequence=TensorInputType(const=True, optional=True, type_domain=types.bool),
        recurrent_activation=TensorInputType(const=True, optional=True, type_domain=types.str),
        cell_activation=TensorInputType(const=True, optional=True, type_domain=types.str),
        activation=TensorInputType(const=True, optional=True, type_domain=types.str),
        clip=TensorInputType(const=True, optional=True, type_domain="T"),
    )

    # Types listed for the "T" type domain: fp16 and fp32.
    type_domains = {
        "T": (types.fp16, types.fp32),
    }
@register_op(opset_version=_IOS17_TARGET)
class rnn(_rnn_iOS15):
    """
    Recurrent Neural Network (RNN)

    The only difference between this version and the iOS 15
    :py:class:`~.iOS15.recurrent.rnn` is adding the support for fp16.
    """

    # Tensor inputs all share the "T" type domain; the remaining inputs are
    # optional constant string/bool options.
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        initial_h=TensorInputType(type_domain="T"),
        weight_ih=TensorInputType(const=True, type_domain="T"),
        weight_hh=TensorInputType(const=True, type_domain="T"),
        bias=TensorInputType(const=True, optional=True, type_domain="T"),
        direction=TensorInputType(const=True, optional=True, type_domain=types.str),
        output_sequence=TensorInputType(const=True, optional=True, type_domain=types.bool),
        activation=TensorInputType(const=True, optional=True, type_domain=types.str),
    )

    # Types listed for the "T" type domain: fp16 and fp32.
    type_domains = {
        "T": (types.fp16, types.fp32),
    }