# Source code for coremltools.converters.mil.mil.passes.defs.cleanup.dedup_op_and_var_names

#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import collections
import itertools

from coremltools.converters.mil.mil import Function
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class dedup_op_and_var_names(AbstractGraphPass):
    """
    For each function, this pass renames ops and variables with the same name
    as any preceding ops/variables across all scopes in the given function,
    where the precedence is implementation-specific. Note that an op name and
    variable names are tracked separately, so an op may have the same name as
    a variable.

    The pass preserves input and output names. Raises ValueError if we cannot
    dedup without changing the input/output var names.

    .. code-block::

        def prog(x):
            x = mb.cast(x=x, dtype="fp16", name="castop")
            x = mb.cast(x=x, dtype="fp32", name="castop")
            x = mb.square(x=x, name="square_last")
            return x

        # Before dedup pass, the op names are ["castop", "castop", "square_last"].
        # After dedup pass, the op names are ["castop", "castop_1", "square_last"].
    """

    def apply(self, prog):
        """Deduplicate op and var names independently in every function of ``prog``."""
        for function in prog.functions.values():
            # Function inputs/outputs are the user-facing interface, so their
            # names must stay fixed; refuse to proceed if they already clash.
            interface_vars = set(list(function.inputs.values()) + function.outputs)
            self._ensure_unique_var_names(interface_vars)

            # Reserve the interface names so no other var can be renamed to them.
            reserved_var_names = {var.name for var in interface_vars}
            self._deduplicate_block(
                function,
                set(function.outputs),
                reserved_var_names,
                set(),
            )

    @staticmethod
    def _gen_new_name(seen_names, curr_name):
        """
        Return ``curr_name`` if it is unused, otherwise ``curr_name`` followed by
        the smallest positive suffix ("_1", "_2", ...) that is not in ``seen_names``.
        """
        if curr_name not in seen_names:
            return curr_name
        candidates = (curr_name + "_" + str(i) for i in itertools.count(start=1))
        return next(name for name in candidates if name not in seen_names)

    def _deduplicate_block(self, block, func_outputs, seen_var_names, seen_op_names):
        """
        Recursively rename ops and vars in ``block`` (and the blocks nested in
        its ops) so that no two of them share a name.

        func_outputs: set of function output vars; these are never renamed.
        seen_var_names: set[str] of var names already taken (updated in place).
        seen_op_names: set[str] of op names already taken (updated in place).
        """

        def claim_var_name(var):
            # Rename ``var`` if needed, then mark its final name as taken.
            var.name = self._gen_new_name(seen_var_names, var.name)
            seen_var_names.add(var.name)

        # A Function's inputs were already reserved by ``apply``; only the
        # inputs of nested (non-function) blocks need deduplicating here.
        if not isinstance(block, Function):
            for block_input in block.inputs:
                claim_var_name(block_input)

        # Iterate over a snapshot of the op list.
        for op in list(block.operations):
            # Inner scopes are processed before the op that owns them.
            for nested_block in op.blocks:
                self._deduplicate_block(
                    nested_block, func_outputs, seen_var_names, seen_op_names
                )

            if op.name is not None:
                op.name = self._gen_new_name(seen_op_names, op.name)
                seen_op_names.add(op.name)

            for out_var in op.outputs:
                # Function outputs keep their names (already reserved).
                if out_var not in func_outputs:
                    claim_var_name(out_var)

    @staticmethod
    def _ensure_unique_var_names(v_set):
        """
        v_set: set[Variable]

        All variables in v_set should have different names. Raise ValueError
        otherwise.
        """
        name_counts = collections.Counter(v.name for v in v_set)
        dup_names = [name for name, count in name_counts.items() if count > 1]
        if dup_names:
            raise ValueError(f"Var names {dup_names} is used both as function's input and output")