Source code for coremltools.converters.mil.mil.passes.defs.cleanup.expand_dynamic_linear

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import numpy as np

from coremltools.converters.mil.mil import Block
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Operation, Program, Var
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass


@register_pass(namespace="common")
class expand_dynamic_linear(AbstractGraphPass):
    """
    ``linear`` requires const or constexpr ``weight`` and ``bias``. In op translation,
    we nevertheless prefer ``linear`` whenever an operand is a descendant of const,
    since such an operand may be folded into const or fused into constexpr by later
    graph passes. If that const folding or constexpr fusion did not happen, this pass
    cleans up the too-ambitious ``linear`` ops by replacing them with ``matmul`` ops.
    """

    def apply(self, prog: Program) -> None:
        for f in prog.functions.values():
            self._expand_dynamic_linear_block(f)

    @block_context_manager
    def _expand_dynamic_linear_block(self, block: Block) -> None:
        # Iterate over a shallow copy to hide changes to block.operations during the
        # loop, since the newly expanded matmul + add ops do not need to be revisited.
        for op in list(block.operations):
            for b in op.blocks:
                self._expand_dynamic_linear_block(b)
            if op.op_type == "linear":
                self._try_expand_dynamic_linear(op, block)

    @staticmethod
    def _is_operand_static(var: Var) -> bool:
        if var is None:
            return True
        op = var.op
        if op is None:
            return False
        op_type = op.op_type
        return op_type == "const" or op_type.startswith("constexpr_")

    def _try_expand_dynamic_linear(self, op: Operation, block: Block) -> None:
        assert op.op_type == "linear", "Should only apply to linear op"

        is_weight_static = self._is_operand_static(op.weight)
        is_bias_static = self._is_operand_static(op.bias)

        if is_weight_static:
            if is_bias_static:
                # Static weight and bias: linear is good as is.
                return
            else:
                # Static weight with dynamic bias: keep linear for the weight matmul,
                # then add for the bias.
                matmul = mb.linear(x=op.x, weight=op.weight, before_op=op)
                add = mb.add(x=matmul, y=op.bias, before_op=op, name=op.name)
                block.replace_uses_of_var_after_op(
                    anchor_op=op,
                    old_var=op.outputs[0],
                    new_var=add,
                )
                op.remove_from_block()
        else:
            # Dynamic weight: have to expand to at least matmul.
            result = mb.matmul(x=op.x, y=op.weight, transpose_y=True, before_op=op)
            if is_bias_static:
                # Static bias: try skipping the add if the bias is all zero.
                force_replace = False
                # If no bias is provided, it defaults to 0 and the add can be skipped;
                # if a bias is provided, its value needs to be inspected.
                is_nonzero_bias = False
                if op.bias is not None:
                    bias_op = op.bias.op
                    bias_op_type = bias_op.op_type
                    if bias_op_type == "const":
                        is_nonzero_bias = np.any(op.bias.val != 0)
                    elif bias_op_type == "constexpr_affine_dequantize":
                        is_nonzero_bias = not bias_op.is_all_zeros()
                    else:
                        # Cowardly treat other types of compressed bias as if nonzero.
                        is_nonzero_bias = True
                    # For a compressed all-zero bias, skipping the add means the result
                    # (matmul output) descends only from the weight and not from the
                    # bias, i.e. we need to force replacing the descendant of the bias.
                    if not is_nonzero_bias:
                        force_replace = True
                if is_nonzero_bias:
                    result = mb.add(x=result, y=op.bias, before_op=op, name=op.name)
                block.replace_uses_of_var_after_op(
                    anchor_op=op,
                    old_var=op.outputs[0],
                    new_var=result,
                    force_replace=force_replace,
                )
                op.remove_from_block()
            else:
                # Dynamic bias: have to further expand to matmul + add.
                result = mb.add(x=result, y=op.bias, before_op=op, name=op.name)
                block.replace_uses_of_var_after_op(
                    anchor_op=op,
                    old_var=op.outputs[0],
                    new_var=result,
                )
                op.remove_from_block()
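

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module).
# A minimal example of this pass in action, assuming a standard coremltools
# install: the weight below descends from a const (via ``transpose``) but is
# not itself a const or constexpr op, so the pass treats it as dynamic and
# expands the ``linear`` into ``matmul`` (the add is skipped here because the
# default bias is all zero). Shapes, values, and the transpose construction
# are arbitrary; in practice such programs are emitted by the frontends.
if __name__ == "__main__":
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        # The weight has a known value (transpose of a const), but its op type is
        # ``transpose``, so ``_is_operand_static`` reports it as dynamic.
        weight = mb.transpose(
            x=np.arange(12, dtype=np.float32).reshape(4, 3), perm=[1, 0]
        )
        return mb.linear(x=x, weight=weight)

    PASS_REGISTRY["common::expand_dynamic_linear"](prog)
    print(prog)  # the linear op should now appear as a matmul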