# Source code for coremltools.converters.mil.mil.passes.defs.cleanup.const_deduplication

#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import hashlib
from typing import Dict, List, Tuple, Union

import numpy as np

from coremltools.converters.mil.mil import Block, ListVar, Program, Var, types
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class const_deduplication(AbstractGraphPass):
    """
    Remove duplicated large constants (tensor with 100+ elements)

    For example

    .. code-block::

        Input graph (where weight and bias are large constants):
            weight_q = const(weight)
            weight_k = const(weight)
            bias_q = const(bias)
            bias_k = const(bias)
            q_embedding = linear(x=q, weight=weight_q, bias=bias_q)
            k_embedding = linear(x=k, weight=weight_k, bias=bias_k)

        Output graph:
            weight_q = const(weight)
            bias_q = const(bias)
            q_embedding = linear(x=q, weight=weight_q, bias=bias_q)
            k_embedding = linear(x=k, weight=weight_q, bias=bias_q)

    Concretely, this graph pass consists of two stages:

    (1) Deduplication of ``const`` op:

    We consider a ``const`` as duplicated if there exists such a previous ``const``
    that has same dtype and value

    (2) Deduplication of ``constexpr_*`` op:

    We consider a ``constexpr_*`` as duplicated if there exists such a previous
    ``constexpr_*`` that has the same ``op_type`` and input attributes.

    Support options:

    - ``const_threshold``: Skip deduplicating ``const`` ops that have smaller
      number of elements than a threshold. Defaults to ``100``. i.e. the constants
      with ``size < 100`` will not be deduplicated.
    """

    # const with size < _const_threshold will not be deduplicated
    _const_threshold = 100
    # Number of leading elements hashed when fingerprinting a constant's value.
    LENGTH_OF_HASHKEY = 100
    # Per-dtype absolute tolerance used by the exact-value comparison that
    # resolves hash collisions in find_constants.
    DTYPE2ATOL = {
        types.fp16: 6e-8,
        types.fp32: 1e-12,
    }

    @property
    def const_threshold(self) -> int:
        """Current size threshold below which consts are left alone."""
        return const_deduplication._const_threshold

    @const_threshold.setter
    def const_threshold(self, val: int) -> None:
        # Note: stored on the class, so setting it affects every instance
        # of this pass, not just this one.
        if not isinstance(val, int):
            raise ValueError(f"Expect option 'const_threshold' to be type of int. Got {type(val)}.")
        const_deduplication._const_threshold = val

    def apply(self, prog) -> None:
        """Run const/constexpr deduplication independently on each function."""
        for f in prog.functions.values():
            self._constant_deduplication_block(f)

    def _deduplicate_const_across_functions(self, prog: Program) -> None:
        """
        When there are duplicated consts across functions, we cannot create a
        common const op to be shared. Instead, we set the weight_id to the consts,
        to allow them share the same blob file value when lowering into milproto.
        """
        # We first make sure that consts are deduplicated within each function,
        # to make sure we can maximize the weight sharing.
        self.apply(prog)

        # check no weight_id is set yet in the program
        for block in prog.functions.values():
            for op in block.operations:
                if op.op_type != "const":
                    continue
                if op.weight_id is not None:
                    raise ValueError(f"const op {op.name} already has weight_id {op.weight_id}")

        # deduplication across functions: every member of a duplicate group
        # gets the same weight_id, so they all map to one blob at lowering time.
        blocks = list(prog.functions.values())
        unique2duplicates_const = self.find_constants(blocks)
        for i, (k, v) in enumerate(unique2duplicates_const.items()):
            if len(v) == 0:
                continue
            # There could be cases where two functions are pointing to the same block
            all_vars = [k] + list(v)
            all_vars = list(set(all_vars))
            for duplicate in all_vars:
                duplicate.op.weight_id = i

    def remove_duplicate_ops(
        self, block: Block, unique2duplicates: Dict[Var, List[Var]], force_replace: bool
    ) -> None:
        """
        Rewire every duplicated var's uses to its unique representative and
        delete the now-dead op. Vars that are block outputs are skipped since
        removing them would change the block's interface.
        """
        for unique in unique2duplicates:
            for duplicate in unique2duplicates[unique]:
                if duplicate in block.outputs:
                    continue
                block.replace_uses_of_var_after_op(
                    anchor_op=duplicate.op,
                    old_var=duplicate,
                    new_var=unique,
                    force_replace=force_replace,
                )
                block.remove_ops([duplicate.op])

    @block_context_manager
    def _constant_deduplication_block(self, block: Block) -> None:
        """Deduplicate consts then constexprs within one block, recursing into child blocks first."""
        for op in list(block.operations):
            for b in op.blocks:
                self._constant_deduplication_block(b)

        # Deduplication of ``const`` op
        unique2duplicates_const = self.find_constants([block])
        self.remove_duplicate_ops(block, unique2duplicates_const, force_replace=False)

        # Deduplication of ``constexpr_*`` op
        # Note that, the ``find_constexpr`` must go after ``find_constants`` +
        # ``remove_duplicate_ops`` for ``const`` ops. Since after the above two
        # functions, ``const`` ops with identical values are deduplicated into a
        # single ``Var`` object, which allows ``find_constexpr`` to directly
        # compare the ``const`` input attr pointers instead of the actual values.
        unique2duplicates_constexpr = self.find_constexprs([block])
        self.remove_duplicate_ops(block, unique2duplicates_constexpr, force_replace=True)

    @staticmethod
    def find_constexprs(blocks: List[Block]) -> Dict[Var, List[Var]]:
        """
        Given a list of blocks, return all constexpr in the blocks in such a format:
            {unique_var_0: [duplicated_var_0_0, duplicated_var_0_1, ...],
             unique_var_1: [duplicated_var_1_0, duplicated_var_1_1, ...],
             ...
            }
        """
        hashkey_2_duplicates: Dict[Tuple, List[Var]] = {}
        for block in blocks:
            for op in list(block.operations):
                if "constexpr" not in op.op_type:
                    continue
                if hasattr(op, "weight_key"):
                    hash_key = [op.op_type, op.weight_key]
                else:
                    hash_key = [op.op_type]
                for v in op.inputs.values():
                    hash_key.append(v.dtype)
                    # Large (already-deduplicated) const inputs are compared by
                    # Var identity; small values are compared by their string form.
                    if v.val is None or const_deduplication.should_be_deduplicated(v.val):
                        hash_key.append(v)
                    else:
                        hash_key.append(str(v.val))
                hash_key = tuple(hash_key)
                if hash_key not in hashkey_2_duplicates:
                    hashkey_2_duplicates[hash_key] = [op.outputs[0]]
                else:
                    hashkey_2_duplicates[hash_key].append(op.outputs[0])

        return {v[0]: v[1:] for v in hashkey_2_duplicates.values()}

    @staticmethod
    def should_be_deduplicated(val: Union[str, bool, np.ndarray]) -> bool:
        """Return True only for array values at least _const_threshold elements large."""
        assert val is not None, "val should only be type of (str, bool, np.ndarray)"
        if isinstance(val, (str, bool)):
            return False
        if np.prod(val.shape) < const_deduplication._const_threshold:
            return False
        return True

    @staticmethod
    def find_constants(blocks: List[Block]) -> Dict[Var, List[Var]]:
        """
        Given a list of blocks, return all constants in the blocks in such a format:
            {unique_var_0: [duplicated_var_0_0, duplicated_var_0_1, ...],
             unique_var_1: [duplicated_var_1_0, duplicated_var_1_1, ...],
             ...
            }
        """
        unique2duplicates: Dict[Var, List[Var]] = {}

        # instead of brute-force C_N^2 comparison, use a hash map to be O(N)
        constant_dict: Dict[Tuple[str, types.type, Tuple[int], str], List[Var]] = {}
        for block in blocks:
            for op in list(block.operations):
                if op.op_type != "const":
                    continue

                constant_var = op.outputs[0]
                if isinstance(constant_var, ListVar):
                    continue
                if not const_deduplication.should_be_deduplicated(constant_var.val):
                    continue

                shape = constant_var.shape
                dtype = constant_var.dtype
                value = constant_var.val
                # Fingerprint only a prefix of the flattened value; full equality
                # is verified below with np.allclose on hash collisions.
                digest = hashlib.sha1(
                    np.ascontiguousarray(value.reshape(-1)[: const_deduplication.LENGTH_OF_HASHKEY])
                ).hexdigest()
                if hasattr(op, "weight_key"):
                    key = (op.weight_key, dtype, shape, digest)
                else:
                    key = (dtype, shape, digest)

                if key not in constant_dict:
                    constant_dict[key] = [constant_var]
                    unique2duplicates[constant_var] = []
                else:
                    hash_collisions = constant_dict[key]
                    existing_constant_var = None
                    for var in hash_collisions:
                        if np.allclose(
                            value,
                            var.val,
                            rtol=0.0,
                            atol=const_deduplication.DTYPE2ATOL.get(dtype, 1e-12),
                        ):
                            existing_constant_var = var
                            break
                    if existing_constant_var is None:
                        # Genuine collision: same fingerprint, different value.
                        hash_collisions.append(constant_var)
                        unique2duplicates[constant_var] = []
                    else:
                        unique2duplicates[existing_constant_var].append(constant_var)

        return unique2duplicates