# Source code for coremltools.converters.mil.mil.passes.defs.cleanup.noop_elimination

#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


import numpy as np

from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class noop_elimination(AbstractGraphPass):
    """
    Remove ops that have no effect.

    .. code-block::

        Given:
            %1 (1, 96, 128, 64, fp32) = ...
            %2 (1, 96, 128, 64, fp32) = reshape(%1)
            ...
            %3 (1, 96, 128, 64, fp32) = add(%2, constant)
            ...

        Result:
            %1 (1, 96, 128, 64, fp32) = ...
            %3 (1, 96, 128, 64, fp32) = add(%1, constant)
            ...
    """

    # Op types this pass knows how to test for "has no effect".
    _SUPPORTED_OPS = {
        "identity",
        "add",
        "mul",
        "floor_div",
        "pow",
        "real_div",
        "sub",
        "reshape",
        "split",
        "slice_by_index",
        "slice_by_size",
        "pad",
        "tile",
        "transpose",
        "upsample_nearest_neighbor",
        "upsample_bilinear",
        "resize_bilinear",
        "crop",
        "linear_activation",
    }

    def apply(self, prog):
        """Run no-op elimination on every function of ``prog``."""
        for f in prog.functions.values():
            self._noop_elimination_block_wrapper(f)

    @staticmethod
    def _match_pattern(op):
        """
        Select the removal routine for ``op``.

        Returns a callable ``fn(op) -> bool`` that tries to remove ``op`` and
        reports whether it did, or ``None`` when ``op`` is not a candidate
        (unsupported op type, more than one output, or its output is an output
        of the enclosing block).
        """
        # Function-scope import: only the floor_div check below needs it, and
        # it keeps this class self-contained with respect to the module's
        # top-level imports.
        from coremltools.converters.mil.mil import types

        def _replace_and_remove(op, new_var, anchor_op):
            # Shared tail of every removal routine: redirect the consumers of
            # op's single output to ``new_var`` and, only if the block accepted
            # the replacement, delete the op.
            block = op.enclosing_block
            if block.try_replace_uses_of_var_after_op(
                anchor_op=anchor_op,
                old_var=op.outputs[0],
                new_var=new_var,
            ):
                block.remove_ops([op])
                return True
            return False

        def _bypass_with_input(op, input_var):
            # Replace op's output by ``input_var``, anchored at the op that
            # produces ``input_var``.
            return _replace_and_remove(op, input_var, anchor_op=input_var.op)

        def remove_identity(op):
            # Unlike the other routines, the replacement is anchored at the
            # identity op itself.
            return _replace_and_remove(op, op.x, anchor_op=op)

        def _remove_elementwise_binary(op, x, y):
            # We remove the ops that has op.x == x or op.y == y.
            # Passing ``None`` for x (or y) disables the check on that side.
            def has_all_elements_equal_to(var, value):
                if value is None:
                    return False

                if var.val is not None:
                    return np.all(var.val == value)
                if var.op is not None and var.op.op_type == "fill":
                    fill_value = var.op.value.val
                    return fill_value is not None and (fill_value == value)
                return False

            if has_all_elements_equal_to(op.x, x):
                input_var = op.y
            elif has_all_elements_equal_to(op.y, y):
                input_var = op.x
            else:
                return False

            # We might be using elementwise as broadcasting
            if input_var.sym_type != op.outputs[0].sym_type:
                return False

            return _bypass_with_input(op, input_var)

        def remove_elementwise(op):
            op_type = op.op_type
            if op_type == "add":
                return _remove_elementwise_binary(op, 0, 0)
            if op_type == "mul":
                return _remove_elementwise_binary(op, 1, 1)
            if op_type == "floor_div":
                # floor_div(x, 1) equals x only for integer tensors; for
                # floating point inputs it computes floor(x), so dropping the
                # op would change the result.
                if not types.is_int(op.x.dtype):
                    return False
                return _remove_elementwise_binary(op, None, 1)
            if op_type in {"pow", "real_div"}:
                return _remove_elementwise_binary(op, None, 1)
            if op_type == "sub":
                return _remove_elementwise_binary(op, None, 0)
            return False

        def remove_slice_by_index(op):
            if op.x.sym_type != op.outputs[0].sym_type:
                return False

            # A negative stride can keep the type unchanged while still
            # altering the data, so such slices are never removed.
            if op.stride is not None and op.stride.val is not None:
                stride = op.stride.val.flatten().tolist()
                if any(s < 0 for s in stride):
                    return False

            return _bypass_with_input(op, op.x)

        def remove_same_shape(op):
            # Candidate for removal when the output type equals the input type.
            if op.x.sym_type != op.outputs[0].sym_type:
                return False
            return _bypass_with_input(op, op.x)

        def remove_linear(op):
            # Only alpha == 1 and beta == 0 qualifies.
            if op.alpha.val != 1 or op.beta.val != 0:
                return False
            return _bypass_with_input(op, op.x)

        def remove_transpose(op):
            # Normalize negative axes, then require the permutation to already
            # be in ascending order.
            rank = len(op.perm.val)
            perm = np.array([p if p >= 0 else p + rank for p in op.perm.val])
            if (perm != np.sort(perm)).any():
                return False
            return _bypass_with_input(op, op.x)

        op_to_removal_fn = {
            "identity": remove_identity,
            "add": remove_elementwise,
            "mul": remove_elementwise,
            "floor_div": remove_elementwise,
            "pow": remove_elementwise,
            "real_div": remove_elementwise,
            "sub": remove_elementwise,
            "reshape": remove_same_shape,
            "split": remove_same_shape,
            "slice_by_index": remove_slice_by_index,
            "slice_by_size": remove_same_shape,
            "pad": remove_same_shape,
            "tile": remove_same_shape,
            "transpose": remove_transpose,
            "upsample_nearest_neighbor": remove_same_shape,
            "upsample_bilinear": remove_same_shape,
            "resize_bilinear": remove_same_shape,
            "crop": remove_same_shape,
            "linear_activation": remove_linear,
        }

        # Cheap checks first, so unsupported ops never index their outputs or
        # scan the block outputs.
        if op.op_type not in noop_elimination._SUPPORTED_OPS:
            return None

        if len(op.outputs) != 1:
            return None

        # abort if op output is a block output
        if op.outputs[0] in op.enclosing_block.outputs:
            return None

        return op_to_removal_fn[op.op_type]

    @block_context_manager
    def _noop_elimination_block_wrapper(self, block):
        """Repeatedly sweep ``block`` until a sweep removes nothing."""

        def _noop_elimination_block(block):
            # Iterate over a snapshot: a successful removal mutates
            # block.operations.
            for op in list(block.operations):
                # Bring each nested block to a fixed point first.
                for b in op.blocks:
                    block_changed = True
                    while block_changed:
                        block_changed = _noop_elimination_block(b)
                # Ops that own nested blocks are never removal candidates.
                if len(op.blocks) > 0:
                    continue

                remove_fn = noop_elimination._match_pattern(op)
                if remove_fn is not None:
                    status = remove_fn(op)
                    # has to break as the downstream iterator is affected.
                    if status:
                        return status
            return False

        block_changed = True
        while block_changed:
            block_changed = _noop_elimination_block(block)