Source code for coremltools.converters.mil.mil.ops.defs.coreml_dialect.ops

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.input_type import InputSpec, StateInputType, TensorInputType
from coremltools.converters.mil.mil.operation import Operation
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op


[docs] @register_op(namespace="coreml") class coreml_update_state(Operation): """ Copy the content of a variable into a state and return the copy of the variable. The type of the variable must match the type that is wrapped inside the state. This is a coreml dialect op to simplify the program. When loading into MIL, the following transformation is done: .. code-block:: %x = coreml_update_state(state=%state, value=%value) --> write_state(state=%state, value=%value) %x = read_state(input=%state) Parameters ---------- state: state<ST> (Required) value: ST (Required) Returns ------- ST Attributes ---------- ST: tensor """ input_spec = InputSpec( state=StateInputType(), value=TensorInputType(type_domain="T"), ) type_domains = { "T": ( types.fp16, types.fp32, types.int8, types.int16, types.int32, types.uint8, types.uint16, types.bool, ), } def type_inference(self): state_wrapped_type = self.state._sym_type.wrapped_type() if not state_wrapped_type == self.value.sym_type: raise ValueError( f"State wrapped type {state_wrapped_type.__type_info__()} not matched with the value's sym_type {self.value.sym_type.__type_info__()}." ) return self.value.sym_type