Source code for

#  Copyright (c) 2020, Apple Inc. All rights reserved.
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at

from typing import List, Optional

import numpy as np

from coremltools import _logger as logger
from import Block
from import Builder as mb
from import Operation, Program, Var
from import AbstractGraphPass
from import (
from import register_pass

@register_pass(namespace="common")
class fuse_layernorm_or_instancenorm(AbstractGraphPass):
    """
    A graph optimization pass on PyMIL to detect and fuse several variants of ``layer_norm`` or
    ``instance_norm``. Pattern 1 corresponds to either ``layer_norm`` or ``instance_norm``.
    Patterns 2-4 are ``instance_norm``. Pattern 5 is ``layer_norm``. You can find these patterns
    in the methods for this class in the source code. To quickly view the source code, click the
    **[source]** button at the end of the class definition.
    """

    _DEBUG = False  # set to true to plot the block before and after the transformation

    def apply(self, prog: Program):
        """
        Run the fusion on every function of ``prog`` until no further fusion occurs.

        :param prog: Program whose functions are rewritten in place.
        """
        for f in prog.functions.values():
            block_changed = True
            while block_changed:
                if self._DEBUG:
                    import graphviz

                    graphviz.Source(
                        f.get_dot_string(
                            highlight_debug_op_types=["instance_norm"],
                        )
                    ).view(filename="/tmp/block_before_fuse_layernorm_or_instancenorm")
                # Lazy %-style argument: the function is only rendered to text
                # when debug logging is actually enabled.
                logger.debug("Block before fuse_layernorm_or_instancenorm transform:\n%s", f)

                block_changed = self._fuse_layernorm_or_instancenorm_block(f)

                if self._DEBUG:
                    graphviz.Source(
                        f.get_dot_string(
                            highlight_debug_op_types=["instance_norm"],
                        )
                    ).view(filename="/tmp/block_after_fuse_layernorm_or_instancenorm")
                logger.debug("Block after fuse_layernorm_or_instancenorm transform:\n%s", f)

    @staticmethod
    def _check_reduce_op(reduce_op: Operation, mode: str = "reduce_mean") -> bool:
        """
        Check whether or not the ``reduction`` op satisfies following conditions:

        - Mode is expected.
        - Does not change rank (``keep_dims`` is ``True``).
        - The ``axes`` is known at compile time.

        Parameters
        ----------

        param reduce_op : ``reduce_op`` to check on. ``None`` is accepted and yields ``False``.

        param mode : ``reduce`` mode
        """
        if reduce_op is None:
            return False
        if reduce_op.op_type != mode:
            return False
        if reduce_op.keep_dims is None or reduce_op.keep_dims.val is None:
            return False
        # Truthiness instead of ``is False``: an identity test would let a
        # NumPy boolean false value slip through as if keep_dims were set.
        if not reduce_op.keep_dims.val:
            return False
        if reduce_op.axes is None or reduce_op.axes.val is None:
            return False
        return True

    @staticmethod
    def _check_child_op_types(
        op: Operation, child_op_types: List[str], check_order: bool = True
    ) -> bool:
        """
        Returns ``True`` for child op types matching ``child_op_types``, otherwise returns ``False``.

        Parameters
        ----------

        param op : Current op. Must have exactly one output; ``None`` yields ``False``.

        param child_op_types : Expected child op types.

        param check_order : Ensure child in given order, defaults to ``True``.
        """
        if op is None or len(op.outputs) != 1:
            return False
        child_ops = list(op.outputs[0].child_ops)
        if len(child_ops) != len(child_op_types):
            return False
        ops_types = [c.op_type for c in child_ops]
        if not check_order:
            # Compare as multisets; sorted() builds new lists, so the
            # caller's ``child_op_types`` is left untouched.
            ops_types = sorted(ops_types)
            child_op_types = sorted(child_op_types)
        return ops_types == child_op_types

    @staticmethod
    def _try_get_child_op_type(
        op: Operation, child_op_type: str, index: int = 0
    ) -> Optional[Operation]:
        """
        Returns child op if type matches, otherwise returns ``None``.

        Parameters
        ----------

        param op : Current op. Must have exactly one output; ``None`` yields ``None``.

        param child_op_type : Expected child op type.

        param index : Child op index.
        """
        if op is None:
            return None
        if len(op.outputs) != 1:
            return None
        child_ops = list(op.outputs[0].child_ops)
        if index >= len(child_ops):
            return None
        if child_ops[index].op_type != child_op_type:
            return None
        return child_ops[index]

    @staticmethod
    def _try_apply_transform(
        reduce_op: Operation,
        block: Block,
        gamma_var: Var,
        beta_var: Var,
        epsilon_var: Var,
        end_op: Operation,
        ops_to_remove: List[Operation],
    ) -> bool:
        """
        Insert instance_norm / layer_norm and delete all ops.

        :param reduce_op: Start operation of the pattern.
        :param block: Block
        :param gamma_var: Gamma variable (may be ``None``).
        :param beta_var: Beta variable (may be ``None``).
        :param epsilon_var: Epsilon variable.
        :param end_op: End operation of the pattern.
        :param ops_to_remove: Operations to remove. This list is extended in place.
        :return: ``True`` if the pattern was replaced, ``False`` if nothing was changed.
        """
        if not _check_no_output_connection(block, ops_to_remove):
            return False

        axes = reduce_op.axes.val
        rank = len(reduce_op.x.shape)

        # check whether the pattern is instance_norm or layer_norm
        is_layernorm = False
        is_instancenorm = False
        is_require_rank4_transpose = False

        # Normalize every axis to its negative form so the comparisons below
        # do not depend on how the axes were originally written.
        negative_axes = [a - rank if a >= 0 else a for a in axes]
        negative_axes.sort()

        gamma_rank = gamma_var.rank if gamma_var is not None else -1
        beta_rank = beta_var.rank if beta_var is not None else -1

        if gamma_rank == len(axes) and beta_rank == len(axes):
            # axes for layer_norm must be [-1] or [-1, -2] or [-1, -2, -3] and so on
            if negative_axes == list(range(-len(negative_axes), 0)):
                is_layernorm = True

        if rank == 4 and negative_axes == [-3]:
            is_layernorm = (gamma_var is None and beta_var is None) or (
                gamma_rank == 1 and beta_rank == 1
            )
            # In this branch gamma / beta are handed to layer_norm as raw
            # values, and the ops producing them are scheduled for removal.
            if gamma_var:
                ops_to_remove.append(gamma_var.op)
                gamma_var = gamma_var.val
            else:
                gamma_var = None

            if beta_var:
                ops_to_remove.append(beta_var.op)
                beta_var = beta_var.val
            else:
                beta_var = None

        if rank == 4 and (negative_axes == [-2, -1] or negative_axes == [-3, -2]):
            if (
                len(np.squeeze(gamma_var.val).shape) == 1
                and len(np.squeeze(beta_var.val).shape) == 1
            ):
                is_instancenorm = True
                if negative_axes == [-3, -2]:
                    is_require_rank4_transpose = True

        if not (is_instancenorm or is_layernorm):
            return False

        # remove all the ops, and replace with a layer_norm or instance_norm op
        out_name = end_op.outputs[0].name

        if is_require_rank4_transpose:
            x = mb.transpose(
                x=reduce_op.x,
                perm=[0, 3, 1, 2],
                name=out_name + "_transpose_nhwc_nchw",
                before_op=end_op,
            )
        if is_instancenorm:
            x = mb.instance_norm(
                x=x if is_require_rank4_transpose else reduce_op.x,
                gamma=np.squeeze(gamma_var.val),
                beta=np.squeeze(beta_var.val),
                epsilon=epsilon_var,
                name=out_name + "_instancenorm" if is_require_rank4_transpose else out_name,
                before_op=end_op,
            )
            ops_to_remove.extend([gamma_var.op, beta_var.op])
        else:  # is_layernorm
            x = mb.layer_norm(
                x=x if is_require_rank4_transpose else reduce_op.x,
                axes=axes,
                gamma=gamma_var,
                beta=beta_var,
                epsilon=epsilon_var,
                name=out_name + "_layernorm" if is_require_rank4_transpose else out_name,
                before_op=end_op,
            )
        if is_require_rank4_transpose:
            x = mb.transpose(
                x=x,
                perm=[0, 2, 3, 1],
                name=out_name + "_transpose_nchw_nhwc",
                before_op=end_op,
            )

        end_op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=end_op, old_var=end_op.outputs[0], new_var=x
        )
        # Remove all the ops at once
        block.remove_ops(ops_to_remove)
        return True

    def _try_match_and_transform_pattern_1(self, reduce_op, block) -> bool:
        """
        Identify the pattern:

        ``y = gamma * (x - mean) / sqrt(variance + epsilon) + beta``

        ``y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])``

        .. code-block::

            x --> reduce_mean --> sub --> square --> reduce_mean --> add(epsilon) --> rsqrt
            |             |        ^                                                    |
            |             |        |                                                    V
            |-----------------------                                               mul (gamma)
            |             |                                                             |
            |             |                                                     --------|---------
            |             |                                                     |                |
            |             |                                                     |                V
            |             |-------------------------------------------------------------------> mul
            |                                                                   |                |
            |                                                                   V                |
            |----------------------------------------------------------------> mul               |
                                                                                |                V
                                                                                |           sub (beta) --> add --> [...]
                                                                                |                           ^
                                                                                |----------------------------

        This pattern corresponds to either ``layer_norm`` or ``instance_norm``.

        It is ``instance_norm`` if all of the following are true:
            - ``input`` is rank 4.
            - ``axes`` of ``reduce_mean`` is ``[-2, -1]`` or ``[-3, -2]``
              (when ``[-3, -2]``, a channel first to channel last transpose would be inserted).
            - ``gamma`` and ``beta`` are rank 1, after ``squeeze``.

        It is ``layer_norm`` if all of the following are true:
            - ``axes`` is either ``[-1]``, ``[-1, -2]``, or ``[-1, -2, -3]``, and so on.
            - ``rank`` of ``gamma`` and ``beta`` is equal to the length of the ``axes``.
        """
        ops_to_remove = []
        root_var = reduce_op.x

        if root_var.shape is None:
            return False

        # check that root_var feeds into exactly 3 ops
        if len(list(root_var.child_ops)) != 3:
            return False
        if root_var.op is not None and not self._check_child_op_types(
            root_var.op, child_op_types=["reduce_mean", "sub", "mul"]
        ):
            return False

        # check 1st reduce_mean op
        if not self._check_reduce_op(reduce_op):
            return False
        ops_to_remove.append(reduce_op)

        # check 1st sub op
        if not self._check_child_op_types(reduce_op, ["sub", "mul"], check_order=False):
            return False
        child_ops_reduce_mean = list(reduce_op.outputs[0].child_ops)
        op_a = child_ops_reduce_mean[0]
        op_b = child_ops_reduce_mean[1]
        sub_op1 = op_a if op_a.op_type == "sub" else op_b
        if not (sub_op1.x == root_var and sub_op1.y == reduce_op.outputs[0]):
            return False
        ops_to_remove.append(sub_op1)

        # check square op
        square_op = self._try_get_child_op_type(sub_op1, "square")
        if square_op is None:
            return False
        ops_to_remove.append(square_op)

        # check second reduce mean
        reduce_op2 = self._try_get_child_op_type(square_op, "reduce_mean")
        if not self._check_reduce_op(reduce_op2):
            return False
        ops_to_remove.append(reduce_op2)

        # check add op (with epsilon)
        add_op1 = self._try_get_child_op_type(reduce_op2, "add")
        if add_op1 is None:
            return False
        epsilon_var = add_op1.y if add_op1.x == reduce_op2.outputs[0] else add_op1.x
        if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
            return False  # must be scalar
        ops_to_remove.append(add_op1)

        # check rsqrt
        rsqrt_op = self._try_get_child_op_type(add_op1, "rsqrt")
        if rsqrt_op is None:
            return False
        ops_to_remove.append(rsqrt_op)

        # check mul (gamma)
        mul_op1 = self._try_get_child_op_type(rsqrt_op, "mul")
        if mul_op1 is None:
            return False
        gamma_var = mul_op1.y if mul_op1.x == rsqrt_op.outputs[0] else mul_op1.x
        if gamma_var.val is None:
            return False
        ops_to_remove.append(mul_op1)

        # check 2 muls after the gamma mul
        if not self._check_child_op_types(mul_op1, ["mul", "mul"]):
            return False
        child_ops = list(mul_op1.outputs[0].child_ops)
        mul_op2 = child_ops[0]
        mul_op3 = child_ops[1]
        mul_op2_other_var = mul_op2.x if mul_op2.y == mul_op1.outputs[0] else mul_op2.y
        mul_op3_other_var = mul_op3.x if mul_op3.y == mul_op1.outputs[0] else mul_op3.y
        if not (
            (mul_op2_other_var == root_var and mul_op3_other_var == reduce_op.outputs[0])
            or (mul_op2_other_var == reduce_op.outputs[0] and mul_op3_other_var == root_var)
        ):
            return False
        # One of the two muls scales the input, the other scales the mean.
        if mul_op2_other_var == root_var:
            mul_root_op = mul_op2
            mul_mean_op = mul_op3
        else:
            mul_root_op = mul_op3
            mul_mean_op = mul_op2
        ops_to_remove.append(mul_mean_op)
        ops_to_remove.append(mul_root_op)

        # check sub with beta
        sub_op2 = self._try_get_child_op_type(mul_mean_op, "sub")
        if sub_op2 is None:
            return False
        if sub_op2.y != mul_mean_op.outputs[0]:
            return False
        beta_var = sub_op2.x
        if beta_var.val is None:
            return False
        ops_to_remove.append(sub_op2)

        # check last add op
        add_op2 = self._try_get_child_op_type(sub_op2, "add")
        if add_op2 is None:
            return False
        if not (add_op2.x == mul_root_op.outputs[0] or add_op2.y == mul_root_op.outputs[0]):
            return False
        ops_to_remove.append(add_op2)

        return self._try_apply_transform(
            reduce_op, block, gamma_var, beta_var, epsilon_var, add_op2, ops_to_remove
        )

    def _try_match_and_transform_pattern_2(self, reduce_op, block) -> bool:
        """
        Identify the pattern:

        ``y = (x - mean) / pow(variance + epsilon) * gamma + beta``

        This pattern corresponds to, and should be fused as, ``instance_norm``.

        All of the following conditions must be satisfied:

        1. ``input`` is rank 4 tensor.
        2. ``reduce`` operates on spatial dimensions ``axes=[-2, -1]``, or ``axes=[-3, -2]`` (a
           channel first to channel last transpose would be inserted in such cases).
        3. ``gamma`` and ``beta`` are both shape ``(C,)`` after ``squeeze``, where ``C`` is number of channels.

        .. code-block::

            |----> sub -----|                          const (0.5)
            |       ^       |                               |
            |       |       V                               V
            x ---> mean  square --> mean1 --> add_eps ---> pow        const_gamma   const_beta
            |       |                                       |              |             |
            |       V                                       V              V             V
            |----> sub1 --------------------------------> real_div --> mul_gamma --> add_beta --> ...
        """
        ops_to_remove = []
        root_var = reduce_op.x

        if root_var.shape is None:
            return False

        # check that root_var feeds into exactly 3 ops
        if len(root_var.child_ops) != 3:
            return False
        if root_var.op is not None and not self._check_child_op_types(
            root_var.op, child_op_types=["reduce_mean", "sub", "sub"]
        ):
            return False

        # check 1st reduce_mean op
        if not self._check_reduce_op(reduce_op):
            return False
        ops_to_remove.append(reduce_op)

        # check 1st sub op
        if not self._check_child_op_types(reduce_op, ["sub", "sub"]):
            return False
        child_ops_reduce_mean = list(reduce_op.outputs[0].child_ops)
        reduce_mean_child_op_a = child_ops_reduce_mean[0]
        reduce_mean_child_op_b = child_ops_reduce_mean[1]
        # One of sub op directly goes square, the other one goes real_div.
        # Guard against a sub whose output has no consumer: indexing [0] on an
        # empty list would raise instead of reporting "no match". With the
        # guard such a sub is treated as the non-square branch and the
        # real_div check further down rejects the pattern.
        child_ops_of_a = list(reduce_mean_child_op_a.outputs[0].child_ops)
        if child_ops_of_a and child_ops_of_a[0].op_type == "square":
            sub_op0 = reduce_mean_child_op_a
            sub_op1 = reduce_mean_child_op_b
        else:
            sub_op0 = reduce_mean_child_op_b
            sub_op1 = reduce_mean_child_op_a
        if not (sub_op0.x == root_var and sub_op0.y == reduce_op.outputs[0]):
            return False
        if not (sub_op1.x == root_var and sub_op1.y == reduce_op.outputs[0]):
            return False
        ops_to_remove.append(sub_op0)
        ops_to_remove.append(sub_op1)

        # check square op
        square_op = self._try_get_child_op_type(sub_op0, "square")
        if square_op is None:
            return False
        ops_to_remove.append(square_op)

        # check second reduce mean
        reduce_op2 = self._try_get_child_op_type(square_op, "reduce_mean")
        if not self._check_reduce_op(reduce_op2):
            return False
        ops_to_remove.append(reduce_op2)

        # check add op (with epsilon)
        add_eps_op = self._try_get_child_op_type(reduce_op2, "add")
        if add_eps_op is None:
            return False
        epsilon_var = add_eps_op.y if add_eps_op.x == reduce_op2.outputs[0] else add_eps_op.x
        if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
            return False  # must be scalar
        ops_to_remove.append(add_eps_op)

        # check pow
        pow_op = self._try_get_child_op_type(add_eps_op, "pow")
        if pow_op is None:
            return False
        if pow_op.y.val is None or not np.isclose(pow_op.y.val, 0.5):
            return False
        ops_to_remove.append(pow_op)

        # check real_div
        real_div_op = self._try_get_child_op_type(pow_op, "real_div")
        if real_div_op is None:
            return False
        if not (real_div_op.x == sub_op1.outputs[0] and real_div_op.y == pow_op.outputs[0]):
            return False
        ops_to_remove.append(real_div_op)

        # check mul with gamma
        mul_gamma_op = self._try_get_child_op_type(real_div_op, "mul")
        if mul_gamma_op is None:
            return False
        gamma_var = mul_gamma_op.y if mul_gamma_op.x == real_div_op.outputs[0] else mul_gamma_op.x
        if gamma_var.val is None:
            return False
        ops_to_remove.append(mul_gamma_op)

        # check add with beta
        add_beta_op = self._try_get_child_op_type(mul_gamma_op, "add")
        if add_beta_op is None:
            return False
        beta_var = add_beta_op.y if add_beta_op.x == mul_gamma_op.outputs[0] else add_beta_op.x
        if beta_var.val is None:
            return False
        ops_to_remove.append(add_beta_op)

        return self._try_apply_transform(
            reduce_op, block, gamma_var, beta_var, epsilon_var, add_beta_op, ops_to_remove
        )

    def _try_match_and_transform_pattern_3(self, reduce_op, block) -> bool:
        """
        Detect ``InstanceNorm`` pattern in TensorFlow-Addons.

        This pattern corresponds to, and should be fused as, ``instance_norm``.

        All of the following conditions must be satisfied:

        1. ``input`` is rank 4 tensor.
        2. ``reduce`` operates on spatial dimensions ``axes=[-2, -1]``, or ``axes=[-3, -2]`` (a
           channel first to channel last transpose would be inserted in such cases).
        3. ``gamma`` and ``beta`` are absent. Default values for ``gamma`` and ``beta`` would be used.

        .. code-block::

                   |-------------------------------------------------|
                   |                                                 |
                   |                                                 V
            x --> mean   square --> mean1 --> add_eps --> rsqrt --> mul2 --> mul_sub
            |      |       ^                                |                   |
            |      V       |                                |                   |
            | --> sub -----|                                |                   |
            |                                               V                   V
            |--------------------------------------------> mul1 -------------> add --> ...
        """
        ops_to_remove = []
        root_var = reduce_op.x

        if root_var.shape is None:
            return False

        # check that root_var feeds into exactly 3 ops
        if len(root_var.child_ops) != 3:
            return False
        if root_var.op is not None and not self._check_child_op_types(
            root_var.op, ["sub", "mul", "reduce_mean"]
        ):
            return False

        # check 1st reduce_mean op
        if not self._check_reduce_op(reduce_op):
            return False
        ops_to_remove.append(reduce_op)

        # check 1st sub op
        if not self._check_child_op_types(reduce_op, ["sub", "mul"], check_order=False):
            return False
        child_ops_reduce_mean = list(reduce_op.outputs[0].child_ops)
        reduce_mean_child_op_a = child_ops_reduce_mean[0]
        reduce_mean_child_op_b = child_ops_reduce_mean[1]
        sub_op1 = (
            reduce_mean_child_op_a
            if reduce_mean_child_op_a.op_type == "sub"
            else reduce_mean_child_op_b
        )
        if not (sub_op1.x == root_var and sub_op1.y == reduce_op.outputs[0]):
            return False
        ops_to_remove.append(sub_op1)

        # check square op
        square_op = self._try_get_child_op_type(sub_op1, "square")
        if square_op is None:
            return False
        ops_to_remove.append(square_op)

        # check second reduce mean (_check_reduce_op already rejects None)
        reduce_op2 = self._try_get_child_op_type(square_op, "reduce_mean")
        if not self._check_reduce_op(reduce_op2):
            return False
        ops_to_remove.append(reduce_op2)

        # check add op (with epsilon)
        add_eps_op = self._try_get_child_op_type(reduce_op2, "add")
        if add_eps_op is None:
            return False
        epsilon_var = add_eps_op.y if add_eps_op.x == reduce_op2.outputs[0] else add_eps_op.x
        if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
            return False  # must be scalar
        ops_to_remove.append(add_eps_op)

        # check rsqrt
        rsqrt_op = self._try_get_child_op_type(add_eps_op, "rsqrt")
        if rsqrt_op is None:
            return False
        ops_to_remove.append(rsqrt_op)

        # check mul 1
        mul_op1 = self._try_get_child_op_type(rsqrt_op, "mul")
        if mul_op1 is None:
            return False
        if not (
            (mul_op1.x == root_var and mul_op1.y == rsqrt_op.outputs[0])
            or (mul_op1.x == rsqrt_op.outputs[0] and mul_op1.y == root_var)
        ):
            return False
        ops_to_remove.append(mul_op1)

        # check mul 2
        mul_op2 = self._try_get_child_op_type(rsqrt_op, "mul", index=1)
        if mul_op2 is None:
            return False
        if not (
            (mul_op2.x == reduce_op.outputs[0] and mul_op2.y == rsqrt_op.outputs[0])
            or (mul_op2.x == rsqrt_op.outputs[0] and mul_op2.y == reduce_op.outputs[0])
        ):
            return False
        ops_to_remove.append(mul_op2)

        # check mul (sub)
        mul_sub_op = self._try_get_child_op_type(mul_op2, "mul")
        if mul_sub_op is None:
            return False
        if mul_sub_op.y.val is None or mul_sub_op.y.val != -1:
            return False
        ops_to_remove.append(mul_sub_op)

        # check last add op
        add_op = self._try_get_child_op_type(mul_sub_op, "add")
        if add_op is None:
            return False
        if not (
            (add_op.x == mul_op1.outputs[0] and add_op.y == mul_sub_op.outputs[0])
            or (add_op.x == mul_sub_op.outputs[0] and add_op.y == mul_op1.outputs[0])
        ):
            return False
        ops_to_remove.append(add_op)

        # The matched pattern has no gamma / beta, so supply all-ones and
        # all-zeros constants sized from dimension 1 of the input.
        gamma_var = mb.const(
            val=np.ones(shape=(1, root_var.shape[1], 1, 1)),
            name="_fuse_layernorm_or_instancenorm_gamma",
        )
        beta_var = mb.const(
            val=np.zeros(shape=(1, root_var.shape[1], 1, 1)),
            name="_fuse_layernorm_or_instancenorm_beta",
        )
        return self._try_apply_transform(
            reduce_op, block, gamma_var, beta_var, epsilon_var, add_op, ops_to_remove
        )

    def _try_match_and_transform_pattern_4(self, reduce_op: Operation, block: Block) -> bool:
        """
        Identify the pattern:

        ``y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])``

        This pattern corresponds to, and should be fused as, ``instance_norm``.

        All of the following conditions must be satisfied:

        1. ``input`` is rank 4 tensor.
        2. ``reduce`` operates on spatial dimensions ``axes=[-2, -1]`` or ``axes=[-3, -2]`` (a
           channel first to channel last transpose would be inserted in such cases).
        3. ``gamma`` and ``beta`` are both shape ``(C,)`` after ``squeeze``, where ``C`` is number of channels.

        .. code-block::

            |-----------|
            |           V
            |------> mul_square1 -----> sum1 -----> mul_mean1
            |                                           |
            |                                           V
            x --> sum --> mul_mean ==> mul_square --> sub_variance --> add_eps --> rsqrt
            |                |                                                      |
            |                |                                                      V
            |                |                                                  mul_gamma
            |                |                                                      |
            |                |                                             |----------------|
            |                |                                             |                V
            |                |---------------------------------------------+-------------> mul2
            |                                                              V                |
            |-----------------------------------------------------------> mul1              |
                                                                           |                V
                                                                           |            sub_beta --> add --> [...]
                                                                           |                          ^
                                                                           |--------------------------|
        """
        ops_to_remove = []
        root_var = reduce_op.x

        if root_var.shape is None:
            return False

        # check that root_var feeds into exactly 4 ops
        if len(root_var.child_ops) != 4:
            return False
        if (
            root_var.op is not None
            and not self._check_child_op_types(
                root_var.op, child_op_types=["mul", "mul", "reduce_sum", "mul"]
            )
            and not self._check_child_op_types(
                # The _check_child_op_types checks for the exact order of the child_ops.
                root_var.op,
                child_op_types=["mul", "mul", "mul", "reduce_sum"],
            )
        ):
            return False

        # check 1st reduce_sum op
        if not self._check_reduce_op(reduce_op, mode="reduce_sum"):
            return False
        ops_to_remove.append(reduce_op)

        # check mul (mean) op
        mul_mean_op = self._try_get_child_op_type(reduce_op, "mul")
        if mul_mean_op is None:
            return False
        if mul_mean_op.y.shape != ():
            return False
        ops_to_remove.append(mul_mean_op)

        # check 1st mul (square) op
        if not self._check_child_op_types(mul_mean_op, child_op_types=["mul", "mul", "mul"]):
            return False
        # both 0 and 1 should be mul square op
        mul_square_op = self._try_get_child_op_type(mul_mean_op, "mul")
        if mul_square_op is None:
            return False
        if self._try_get_child_op_type(mul_mean_op, "mul", index=1) is None:
            return False
        ops_to_remove.append(mul_square_op)

        # Check another branch

        # check 2nd mul (square) op
        # both 0 and 1 should be mul square op 1
        mul_square_op2 = list(root_var.child_ops)[0]
        ops_to_remove.append(mul_square_op2)

        # check 2nd reduce sum
        reduce_op2 = self._try_get_child_op_type(mul_square_op2, child_op_type="reduce_sum")
        if not self._check_reduce_op(reduce_op2, "reduce_sum"):
            return False
        ops_to_remove.append(reduce_op2)

        # check mul after 2nd reduce op
        mul_mean_op2 = self._try_get_child_op_type(reduce_op2, "mul")
        if mul_mean_op2 is None:
            return False
        if mul_mean_op2.y.shape != ():
            return False
        ops_to_remove.append(mul_mean_op2)

        # check sub (variance)
        sub_variance_op = self._try_get_child_op_type(mul_mean_op2, "sub")
        if sub_variance_op is None:
            return False
        if sub_variance_op.y != mul_square_op.outputs[0]:
            return False
        ops_to_remove.append(sub_variance_op)

        # check add op (epsilon)
        add_eps_op = self._try_get_child_op_type(sub_variance_op, "add")
        if add_eps_op is None:
            return False
        epsilon_var = add_eps_op.y if add_eps_op.x == sub_variance_op.outputs[0] else add_eps_op.x
        if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
            return False  # must be scalar
        ops_to_remove.append(add_eps_op)

        # check rsqrt
        rsqrt_op = self._try_get_child_op_type(add_eps_op, "rsqrt")
        if rsqrt_op is None:
            return False
        ops_to_remove.append(rsqrt_op)

        # check mul (gamma)
        mul_gamma_op = self._try_get_child_op_type(rsqrt_op, "mul")
        if mul_gamma_op is None:
            return False
        gamma_var = mul_gamma_op.y if mul_gamma_op.x == rsqrt_op.outputs[0] else mul_gamma_op.x
        if gamma_var.val is None:
            return False
        ops_to_remove.append(mul_gamma_op)

        # check 2 muls after the gamma mul
        if not self._check_child_op_types(mul_gamma_op, ["mul", "mul"]):
            return False
        mul_gamma_child_ops = list(mul_gamma_op.outputs[0].child_ops)
        mul_op1 = mul_gamma_child_ops[0]
        mul_op2 = mul_gamma_child_ops[1]
        mul_op1_other_var = mul_op1.x if mul_op1.y == mul_gamma_op.outputs[0] else mul_op1.y
        mul_op2_other_var = mul_op2.x if mul_op2.y == mul_gamma_op.outputs[0] else mul_op2.y
        if not (
            (mul_op1_other_var == root_var and mul_op2_other_var == mul_square_op.x)
            or (mul_op1_other_var == mul_square_op.x and mul_op2_other_var == root_var)
        ):
            return False
        # Normalize so that mul_op1 is the mul fed by root_var and mul_op2 the
        # one fed by the mean.
        if mul_op1_other_var != root_var:
            mul_op1, mul_op2 = mul_op2, mul_op1
        ops_to_remove.append(mul_op1)
        ops_to_remove.append(mul_op2)

        # check sub with beta
        sub_beta_op = self._try_get_child_op_type(mul_op2, "sub")
        if sub_beta_op is None:
            return False
        if sub_beta_op.y != mul_op2.outputs[0]:
            return False
        beta_var = sub_beta_op.x
        if beta_var.val is None:
            return False
        ops_to_remove.append(sub_beta_op)

        # check last add op
        add_op = self._try_get_child_op_type(sub_beta_op, "add")
        if add_op is None:
            return False
        if not (
            (add_op.x == mul_op1.outputs[0] and add_op.y == sub_beta_op.outputs[0])
            or (add_op.y == mul_op1.outputs[0] and add_op.x == sub_beta_op.outputs[0])
        ):
            return False
        ops_to_remove.append(add_op)

        return self._try_apply_transform(
            reduce_op, block, gamma_var, beta_var, epsilon_var, add_op, ops_to_remove
        )

    def _try_match_and_transform_pattern_5(self, reduce_op, block) -> bool:
        """
        Detect BC1S ``LayerNorm`` pattern as in ml-ane-transformers.

        Identify two patterns, the first:

        ``y = (x - mean(x)) * rsqrt(variance(X) + eps)``

        ``y = (x - mean(x)) * rsqrt(mean((x - mean(x))^2) + eps)``

        .. code-block::

            x --> reduce_mean --|
            |                   |   |---|
            |                   V   |   V
            |----------------> sub --> mul --> reduce_mean --> add(epsilon) --> rsqrt
                                |                                                 |
                                |                                                 V
                                |-----------------------------------------------> mul --> [...]

        If the optional elementwise weight and bias are set, the second pattern is:

        ``y = [(x - mean(x)) * rsqrt(mean((x - mean(x))^2) + eps) + beta] * gamma``

        Note that this is different from the torch and MIL definitions of beta and gamma,
        so beta is scaled by gamma and applied before it.

        .. code-block::

            x --> reduce_mean --|
            |                   |   |---|
            |                   V   |   V
            |----------------> sub --> mul --> reduce_mean --> add(epsilon) --> rsqrt
                                |                                                 |
                                |                                                 V
                                |-----------------------------------------------> mul
                                                                                  |
                                                                                  V
                                                                              add(beta)
                                                                                  |
                                                                                  V
                                                                              mul(gamma) --> [...]

        These patterns correspond to a specific ``layer_norm``:
            - ``rank`` is 4.
            - ``axes`` is ``[1]``
            - ``gamma`` and ``beta`` are applied as in ml-ane-transformers, in the opposite order of torch.
        """
        ops_to_remove = []
        root_var = reduce_op.x

        # check that root_var feeds into at least 2 ops
        if len(list(root_var.child_ops)) < 2:
            return False

        # Do not enforce that the only child ops are reduce_mean and sub as in other
        # patterns. There are models where the root op is used after the layer norm.

        # check 1st reduce_mean op
        if not self._check_reduce_op(reduce_op):
            return False
        if len(reduce_op.axes.val) != 1 or reduce_op.axes.val != [1] or not reduce_op.keep_dims.val:
            return False
        ops_to_remove.append(reduce_op)

        # check 1st sub op
        if not self._check_child_op_types(reduce_op, ["sub"], check_order=False):
            return False
        child_ops_reduce_mean = list(reduce_op.outputs[0].child_ops)
        sub_op1 = child_ops_reduce_mean[0]
        if sub_op1 is None or not self._check_child_op_types(
            sub_op1, child_op_types=["mul", "mul", "mul"]
        ):
            return False
        if not (sub_op1.x == root_var and sub_op1.y == reduce_op.outputs[0]):
            return False
        ops_to_remove.append(sub_op1)

        # check mul op (equivalent to a square op)
        square_op = self._try_get_child_op_type(sub_op1, "mul")
        if square_op is None or not self._check_child_op_types(
            square_op, child_op_types=["reduce_mean"]
        ):
            return False
        if square_op.x != square_op.y:
            return False
        ops_to_remove.append(square_op)

        # check second reduce mean
        reduce_op2 = self._try_get_child_op_type(square_op, "reduce_mean")
        if not self._check_reduce_op(reduce_op2) or not self._check_child_op_types(
            reduce_op2, child_op_types=["add"]
        ):
            return False
        if (
            len(reduce_op2.axes.val) != 1
            or reduce_op2.axes.val != [1]
            or not reduce_op2.keep_dims.val
        ):
            return False
        ops_to_remove.append(reduce_op2)

        # check add op (with epsilon)
        add_op1 = self._try_get_child_op_type(reduce_op2, "add")
        if add_op1 is None or not self._check_child_op_types(
            add_op1, child_op_types=["rsqrt"]
        ):
            return False
        epsilon_var = add_op1.y if add_op1.x == reduce_op2.outputs[0] else add_op1.x
        if epsilon_var.val is None or len(epsilon_var.val.shape) != 0:
            return False  # must be scalar
        ops_to_remove.append(add_op1)

        # check rsqrt
        rsqrt_op = self._try_get_child_op_type(add_op1, "rsqrt")
        if rsqrt_op is None or not self._check_child_op_types(
            rsqrt_op, child_op_types=["mul"]
        ):
            return False
        ops_to_remove.append(rsqrt_op)

        # Last op in pattern if there is no elementwise affine.
        mul_op = self._try_get_child_op_type(rsqrt_op, "mul")
        if mul_op is None:
            return False
        if mul_op.y != sub_op1.outputs[0] and mul_op.x != sub_op1.outputs[0]:
            return False
        ops_to_remove.append(mul_op)

        # Default values if no gamma or beta ops.
        end_op = mul_op
        gamma_var = None
        beta_var = None

        add_beta_op = self._try_get_child_op_type(mul_op, "add")
        mul_gamma_op = self._try_get_child_op_type(add_beta_op, "mul")

        has_beta_and_gamma = add_beta_op is not None and mul_gamma_op is not None

        # mul_op cannot be used except as an input to add_beta_op.
        if has_beta_and_gamma and not self._check_child_op_types(
            mul_op, child_op_types=["add"]
        ):
            # It would be possible to fuse this pattern as:
            # layer_norm(x, gamma=None, beta=None) -> add(beta) -> mul(gamma) -> ...
            #                                      |-> other mul_op child ops
            # For simplicity don't handle this edge case.
            return False

        # add_beta_op cannot be used except as an input to mul_gamma_op.
        if has_beta_and_gamma and not self._check_child_op_types(
            add_beta_op, child_op_types=["mul"]
        ):
            # It would be possible to fuse this pattern as:
            # layer_norm(x, gamma=None, beta=None) -> add(beta) -> mul(gamma) -> ...
            #                                                   |-> other add_beta_op child ops
            # For simplicity don't handle this edge case.
            return False

        if add_beta_op is None and mul_gamma_op is None:
            # Gamma and beta are optional in layer_norm.
            pass
        elif add_beta_op is None or mul_gamma_op is None:
            # If only one of gamma or beta is present, they could
            # be folded into the layer_norm op. For simplicity
            # don't handle this edge case.
            return False

        if has_beta_and_gamma:
            beta_var = add_beta_op.y if add_beta_op.x == mul_op.outputs[0] else add_beta_op.x
            gamma_var = (
                mul_gamma_op.y if mul_gamma_op.x == add_beta_op.outputs[0] else mul_gamma_op.x
            )

            if beta_var.val is None or gamma_var.val is None:
                return False

            gamma_var = mb.const(
                val=np.squeeze(gamma_var.val),
                name="_fuse_layernorm_gamma",
            )

            # Scale beta by gamma. Note: this un-scaling introduces a small amount
            # of precision loss.
            beta_var = mb.const(
                val=np.squeeze(beta_var.val) * gamma_var.val,
                name="_fuse_layernorm_beta"
            )

            ops_to_remove.extend([add_beta_op, mul_gamma_op])
            end_op = mul_gamma_op

        return self._try_apply_transform(
            reduce_op, block, gamma_var, beta_var, epsilon_var, end_op, ops_to_remove
        )

    @block_context_manager
    def _fuse_layernorm_or_instancenorm_block(self, block: Block):
        """
        Try every pattern on each reduce op of ``block``, recursing into nested blocks first.

        :param block: Block to scan and rewrite in place.
        :return: ``True`` if at least one fusion happened directly in ``block``.
        """
        fusion_occurred = False
        # Iterate over a snapshot: successful fusions remove ops from the block.
        for op in list(block.operations):
            if op.enclosing_block is None:
                continue

            for b in op.blocks:
                block_changed = True
                while block_changed:
                    block_changed = self._fuse_layernorm_or_instancenorm_block(b)
            if len(op.blocks) > 0:
                continue

            # start pattern match if reduce_mean op is encountered
            if op.op_type == "reduce_mean":
                if self._try_match_and_transform_pattern_1(op, block):
                    fusion_occurred = True
                elif self._try_match_and_transform_pattern_2(op, block):
                    fusion_occurred = True
                elif self._try_match_and_transform_pattern_3(op, block):
                    fusion_occurred = True
                elif self._try_match_and_transform_pattern_5(op, block):
                    fusion_occurred = True
            elif op.op_type == "reduce_sum":
                if self._try_match_and_transform_pattern_4(op, block):
                    fusion_occurred = True
        return fusion_occurred