# Copyright (c) 2023, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
import collections
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import _are_ops_identical, block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
@register_pass(namespace="common")
class remove_redundant_ops(AbstractGraphPass):
    """
    If there are multiple ops with "identical" inputs, then they are redundant and all but one of them can be removed.
    This pass checks and removes such ops.

    Since all inputs to ops in MIL are named, two ops with the same ``op_type`` can be compared by comparing their
    correspondingly named inputs. Inputs are treated as identical if one of the following is true:

    - The input is a constant var, in which case its value should have the same dtype and numerical value.
    - The input is a non-constant var, in which case it should be the same var object.

    This pass iterates over the block inputs and the ops of each block. For each such var (for an op, its first
    output var), it builds candidate op lists from the child ops of that var. Each candidate list contains ops
    of the same ``op_type``, arranged in topological order. Within each list, the second, third, and subsequent
    ops are each compared with the first op, and removed if identical to it. For example:

    .. code-block::

        Input:
            %0 = op0(...)
            %1 = op1(...)
            %2 = const(val=4.5)
            %3 = const(val=4.5)
            %4 = op2(%1, %0, %2)
            %5 = op3(%1, %0, %3)

        Output:
            %0 = op0(...)
            %1 = op1(...)
            %2 = const(val=4.5)
            %3 = const(val=4.5) # this will get removed later by dead code elimination pass
            %4 = op2(%1, %0, %2)

    In the example above (where ``op2`` and ``op3`` are assumed to be the same ``op_type``), ``op3`` is removed
    and all uses of ``%5`` are replaced by ``%4``.

    For more examples, see "TestRemoveRedundantOpsPass".
    """

    # Op types that must never be removed by this pass, even when identical.
    # Empty by default; extend this tuple to exclude additional op types.
    _NON_REDUNDANT_OPS = ()

    def apply(self, prog):
        """Run the pass on every function of ``prog``."""
        for f in prog.functions.values():
            self._remove_redundant_ops_in_block_wrapper(f)

    @staticmethod
    def _is_op_eligible_to_be_removed(op):
        """
        Return True if ``op`` may be removed as a duplicate.

        Ops that own nested blocks, ops whose type starts with ``random``, and ops listed in
        ``_NON_REDUNDANT_OPS`` are never removed. (Presumably ``random*`` ops are excluded because two
        such ops with identical inputs still produce independent samples.)
        """
        return not (
            len(op.blocks) != 0
            or op.op_type.startswith("random")
            or op.op_type in remove_redundant_ops._NON_REDUNDANT_OPS
        )

    @staticmethod
    def _get_candidate_ops_list(prospective_ops_list):
        """
        Filter and order ``prospective_ops_list`` for redundancy checking.

        Returns an empty list if the ops do not all belong to the same block. Otherwise returns the
        eligible ops (see ``_is_op_eligible_to_be_removed``), with repeats dropped, sorted by their
        position in ``block.operations`` (which is topologically sorted).
        """
        enclosing_blocks = [op.enclosing_block for op in prospective_ops_list]
        # all candidate ops must belong to the same block
        if len(set(enclosing_blocks)) > 1:
            return []

        # Dict keys drop ops that appear more than once (child_ops may repeat an op).
        op_to_index = collections.OrderedDict()
        for op in prospective_ops_list:
            if op in op_to_index:
                continue
            if remove_redundant_ops._is_op_eligible_to_be_removed(op):
                op_to_index[op] = enclosing_blocks[0].operations.index(op)

        # Sort the ops according to their index in block.operations, which is topologically sorted.
        return [op for op, _ in sorted(op_to_index.items(), key=lambda item: item[1])]

    @staticmethod
    def _get_candidate_ops_lists_from_var(var):
        """
        Return a list of lists.

        Each element is a subset of the child ops of ``var`` that satisfies the following conditions:

        - the ops are of the same op_type;
        - ops are not repeated in it (the ``.child_ops`` property of a var may sometimes contain an op
          more than once);
        - the ops are ordered by their position in the ``block.operations`` list (which is topologically
          sorted), with ops appearing earlier in that list appearing first here.

        Only lists containing at least two ops are returned.
        """
        op_types_to_ops = collections.OrderedDict()
        for op in var.child_ops:
            op_types_to_ops.setdefault(op.op_type, []).append(op)

        candidate_ops_lists = []
        for ops in op_types_to_ops.values():
            if len(ops) > 1:
                candidate_ops_list = remove_redundant_ops._get_candidate_ops_list(ops)
                if len(candidate_ops_list) > 1:
                    candidate_ops_lists.append(candidate_ops_list)
        return candidate_ops_lists

    @staticmethod
    def _try_to_remove_ops(candidate_ops_list):
        """
        Remove every op in ``candidate_ops_list[1:]`` that is identical to ``candidate_ops_list[0]``.

        ``candidate_ops_list`` must be in topological order. Returns True if any op was removed.
        """
        # Removing ops later in the topological order is much easier, as their output vars
        # can simply be replaced by the output var of first_op; this doesn't require
        # changing any op order in the block.
        if len(candidate_ops_list) < 2:
            return False
        first_op = candidate_ops_list[0]
        block = first_op.enclosing_block

        # Currently, we only consider ops with exactly one output: the replace-var logic
        # below only handles that case, and ``op.outputs[0]`` would raise IndexError
        # for an op with no outputs.
        if len(first_op.outputs) != 1:
            return False

        # An op whose output is a block output is never removed.
        ops_to_remove = [
            op
            for op in candidate_ops_list[1:]
            if op.outputs[0] not in block.outputs and _are_ops_identical(first_op, op)
        ]
        if not ops_to_remove:
            return False

        # Redirect uses of the removed ops' output vars to first_op's output var. This is safe
        # because every op in ops_to_remove appears after first_op, so first_op.outputs[0]
        # is already in scope wherever those output vars are used.
        for op in ops_to_remove:
            op.enclosing_block.replace_uses_of_var_after_op(
                anchor_op=op, old_var=op.outputs[0], new_var=first_op.outputs[0]
            )
        block.remove_ops(ops_to_remove)
        return True

    @staticmethod
    def _try_to_transform(parent_var):
        """
        Scan the child ops of ``parent_var`` to find and remove identical ops, if any.

        Returns True if any redundant ops were found and removed.
        """
        block_changed = False
        # Every candidate list is processed; do not stop at the first success.
        for ops_list in remove_redundant_ops._get_candidate_ops_lists_from_var(parent_var):
            if remove_redundant_ops._try_to_remove_ops(ops_list):
                block_changed = True
        return block_changed

    @block_context_manager
    def _remove_redundant_ops_in_block_wrapper(self, block):
        """
        Repeatedly remove redundant ops from ``block`` (and, recursively, its nested blocks)
        until a full scan makes no change.
        """

        def _remove_redundant_ops_in_block(block):
            # One scan over ``block``; returns True if the block was modified.
            if isinstance(block.inputs, dict):
                block_input_var_list = list(block.inputs.values())
            elif isinstance(block.inputs, (list, tuple)):
                block_input_var_list = block.inputs
            else:
                raise ValueError("Unrecognized type of block.inputs, its neither a list nor dict.")

            graph_updated = False

            # iterate over the block inputs; report changes here too so the caller's
            # fixed-point loop sees every modification.
            for input_var in block_input_var_list:
                if len(input_var.child_ops) > 1 and self._try_to_transform(input_var):
                    graph_updated = True

            # iterate over the ops in the block
            for op in block.operations:
                if op.op_type == "const":
                    continue

                for nested_block in op.blocks:
                    nested_changed = True
                    while nested_changed:
                        nested_changed = _remove_redundant_ops_in_block(nested_block)

                if len(op.outputs) > 0 and len(op.outputs[0].child_ops) > 1:
                    # currently, we only check the first output of the op;
                    # this can be extended, if required, to check the other outputs.
                    if self._try_to_transform(op.outputs[0]):
                        # has to return, as block.operations (being iterated) was modified.
                        return True
            return graph_updated

        # Each True result means at least one op was removed, so this loop terminates.
        block_changed = True
        while block_changed:
            block_changed = _remove_redundant_ops_in_block(block)