Source code for coremltools.converters.mil.mil.passes.defs.optimize_state

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from typing import List

from coremltools.converters.mil.mil import Block
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Operation, Program
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.scope import ScopeInfo


@register_pass(namespace="common")
class canonicalize_inplace_pattern(AbstractGraphPass):
    """
    Put every ``coreml_update_state`` right after the op that produces its value.

    As a functional-graph framework, Core ML represents an in-place operation as

    .. code-block::

        read_state -> functional operation -> write_state

    Because a topological order is not unique, in the list representation of ops
    ``write_state`` may sit anywhere after the functional op. The canonical order
    preferred here has ``write_state`` immediately follow the functional op.

    In practice:

    1. PyMIL does not use a ``write_state`` op. It uses ``coreml_update_state``,
       which is the composition ``write_state -> read_state``.
    2. The ``read_state`` op plays no role in the pattern match or the transform.

    So the matched pattern is

    .. code-block::

        functional operation -> coreml_update_state

    and the ``coreml_update_state`` is then reordered. For example

    .. code-block::

        Given:

            mul = mul(state, x)
            add = add(mul, y)
            update = coreml_update_state(state, mul)

        Return:

            mul = mul(state, x)
            update = coreml_update_state(state, mul)
            add = add(mul, y)
    """

    def apply(self, prog: Program) -> None:
        """Run the pass on every function of ``prog``."""
        for func in prog.functions.values():
            self._apply_block(func)

    @block_context_manager
    def _apply_block(self, block: Block) -> None:
        """Visit the ops of ``block`` (nested blocks first) and canonicalize each match."""
        # Iterate over a snapshot, because the rewrite adds and removes ops in ``block``.
        ops_snapshot = list(block.operations)
        for candidate in ops_snapshot:
            # General boilerplate: skip ops that are no longer attached to a block
            # (presumably removed by an earlier rewrite).
            if candidate.enclosing_block is None:
                continue

            for inner_block in candidate.blocks:
                self._apply_block(inner_block)

            # The live op list changes underneath the snapshot, but only in
            # ``coreml_update_state`` ops. Those cannot start the pattern and return
            # quickly, so there is no need to break out and re-iterate.
            self._try_match_and_transform_pattern(candidate, block, ops_snapshot)

    def _try_match_and_transform_pattern(
        self, op: Operation, block: Block, block_operation_list: List[Operation]
    ) -> None:
        """
        Re-emit each ``coreml_update_state`` fed by ``op`` directly after ``op``.

        ``block_operation_list`` is the snapshot of the block's ops that the caller
        is iterating; it is used to find the op that followed ``op``.
        """
        # State ops themselves never start the pattern.
        if op.op_type in ("read_state", "coreml_update_state"):
            return

        for stale_update in self._try_find_child_coreml_update_state_ops(op):
            # The insertion point is whatever followed ``op`` in the snapshot.
            # NOTE(review): the snapshot predates any rewrite; if ``op`` feeds several
            # ``coreml_update_state`` ops, this entry may already have been removed by
            # a previous iteration -- confirm that case cannot occur or is handled.
            op_position = block_operation_list.index(op)
            successor = block_operation_list[op_position + 1]

            # Build the replacement under the same scopes as ``op``.
            inherited_scopes = self._construct_scope_info_list_from_op_scopes(op)
            with mb.scope(*inherited_scopes):
                relocated_output = mb.coreml_update_state(
                    state=stale_update.state,
                    value=stale_update.value,
                    before_op=successor,
                )

            # Dead code must be eliminated here, because the dead code elimination
            # graph pass does not work for coreml_update_state.
            uses_redirected = block.try_replace_uses_of_var_after_op(
                anchor_op=stale_update,
                old_var=stale_update.outputs[0],
                new_var=relocated_output,
            )
            if uses_redirected:
                block.remove_ops([stale_update])

    @staticmethod
    def _try_find_child_coreml_update_state_ops(op: Operation) -> List[Operation]:
        """Return the ``coreml_update_state`` ops that consume any output of ``op``."""
        return [
            consumer
            for produced in op.outputs
            for consumer in produced.child_ops
            if consumer.op_type == "coreml_update_state"
        ]

    @staticmethod
    def _construct_scope_info_list_from_op_scopes(op: Operation) -> List[ScopeInfo]:
        """Turn each ``(source, data)`` entry of ``op.scopes`` into a ``ScopeInfo``."""
        return [ScopeInfo(source=source, data=data) for source, data in op.scopes.items()]
@register_pass(namespace="common")
class prefer_state_in_downstream(AbstractGraphPass):
    """
    Make downstream ops read the ``coreml_update_state`` output instead of the value written.

    As a functional-graph framework, Core ML represents an in-place operation as

    .. code-block::

        read_state -> functional operation -> write_state

    When the output of the in-place operation is used downstream, there are two
    possible patterns. One reuses the state memory

    .. code-block::

        read_state -> functional operation -> write_state -> read_state -> ...

    the other wastes memory by keeping the functional output alive

    .. code-block::

                                            |-> write_state
        read_state -> functional operation -|
                                            |-> ...

    The reuse-state pattern is preferred.

    In practice:

    1. PyMIL does not use a ``write_state`` op. It uses ``coreml_update_state``,
       which is the composition ``write_state -> read_state``.
    2. With the canonical in-place pattern (guaranteed by the graph pass
       ``canonicalize_inplace_pattern``), it is enough to replace the uses of the
       functional output with the ``coreml_update_state`` output.

    For example

    .. code-block::

        Given:

            mul = mul(state, x)
            update = coreml_update_state(state, mul)
            add = add(mul, y)

        Return:

            mul = mul(state, x)
            update = coreml_update_state(state, mul)
            add = add(update, y)
    """

    def apply(self, prog: Program) -> None:
        """Run the pass on every function of ``prog``."""
        for func in prog.functions.values():
            self._apply_block(func)

    @block_context_manager
    def _apply_block(self, block: Block) -> None:
        """Visit the ops of ``block`` (nested blocks first) and rewrite each match."""
        for candidate in list(block.operations):
            # General boilerplate: skip ops that are no longer attached to a block
            # (presumably removed by an earlier rewrite).
            if candidate.enclosing_block is None:
                continue

            for inner_block in candidate.blocks:
                self._apply_block(inner_block)

            self._try_match_and_transform_pattern(candidate, block)

    def _try_match_and_transform_pattern(self, op: Operation, block: Block) -> None:
        """
        If ``op`` is a ``coreml_update_state``, redirect later uses of its value to its output.

        Nothing happens when the written value is both a block input and a block
        output, or when no op other than ``coreml_update_state`` ops consumes it.
        """
        if op.op_type != "coreml_update_state":
            return

        written_value = op.value

        # If the var is both a block input and a block output, it must not be replaced.
        if written_value in block.outputs and written_value in block.inputs.values():
            return

        # Do nothing when the var feeds no other op at all, or feeds only other
        # coreml_update_state ops; both cases mean there is no consumer to redirect.
        sibling_consumers = [child for child in written_value.child_ops if child != op]
        has_functional_consumer = any(
            child.op_type != "coreml_update_state" for child in sibling_consumers
        )
        if not has_functional_consumer:
            return

        block.try_replace_uses_of_var_after_op(
            anchor_op=op,
            old_var=written_value,
            new_var=op.outputs[0],
        )