# Source code for coremltools.converters.mil.mil.ops.defs.iOS18.recurrent

# Copyright (c) 2024, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil.ops.defs.iOS17.recurrent import gru as _gru_iOS17
from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET


@register_op(opset_version=_IOS18_TARGET)
class gru(_gru_iOS17):
    """
    Gated Recurrent Unit (GRU)

    This op differs from the iOS 17 :py:class:`~.iOS17.recurrent.gru` only in
    its ``reset_after`` parameter. ``reset_after`` is optional and defaults to
    ``False``. When ``True``, the reset gate is applied before the elementwise
    matrix multiplication.
    """

    # NOTE(review): ``input_bias`` is declared in the spec below but is not
    # described in the docstring above -- confirm whether it is also new
    # relative to the iOS 17 op and document it if so.
    # NOTE(review): "applied before" reads oddly against the name
    # ``reset_after`` -- confirm the wording against the MIL op specification.
    input_spec = InputSpec(
        # Non-constant tensor inputs in type domain "T".
        x=TensorInputType(type_domain="T"),
        initial_h=TensorInputType(type_domain="T"),
        # Required constant tensors in type domain "T".
        weight_ih=TensorInputType(const=True, type_domain="T"),
        weight_hh=TensorInputType(const=True, type_domain="T"),
        # Optional constant tensor in type domain "T".
        bias=TensorInputType(const=True, optional=True, type_domain="T"),
        # Optional constant string / boolean settings.
        direction=TensorInputType(
            const=True, optional=True, type_domain=types.str
        ),
        output_sequence=TensorInputType(
            const=True, optional=True, type_domain=types.bool
        ),
        recurrent_activation=TensorInputType(
            const=True, optional=True, type_domain=types.str
        ),
        activation=TensorInputType(
            const=True, optional=True, type_domain=types.str
        ),
        # Optional constant boolean flag described in the class docstring.
        reset_after=TensorInputType(
            const=True, optional=True, type_domain=types.bool
        ),
        # Optional constant tensor in type domain "T".
        input_bias=TensorInputType(
            const=True, optional=True, type_domain="T"
        ),
    )