# Copyright (c) 2023, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
import copy
import numpy as np
from coremltools import _logger as logger
from coremltools.converters.mil.mil import Block
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Operation, types
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import (
_check_child_op_type,
_check_no_output_connection,
block_context_manager,
)
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.types.symbolic import any_symbolic
[docs]@register_pass(namespace="common")
class add_conv_transpose_output_shape(AbstractGraphPass):
"""
The ``conv_transpose`` input ``output_shape`` is an optional input.
Since we can infer the output shape from ``type_inference``, we add
``output_shape`` input whenever it is known to be constant at
compile time. For example:
.. code-block::
Given:
%1: (1, 5, 39, fp32) = conv_transpose(...) # no output_shape input.
Result:
%2: (3, i32) = const(val=[1,5,39])
%3: (1, 5, 39, fp32) = conv_transpose(..., output_shape=%2)
"""
def apply(self, prog):
for f in prog.functions.values():
self._handle_block(f)
@staticmethod
def _match_pattern(op):
return (
op.op_type == "conv_transpose"
and op.output_shape is None
and not any_symbolic(op.outputs[0].shape)
)
@block_context_manager
def _handle_block(self, block):
for op in list(block.operations):
for b in op.blocks:
self._handle_block(b)
if not self._match_pattern(op):
continue
# matched pattern
x = mb.conv_transpose(
**op.inputs,
output_shape=op.outputs[0].shape,
name=op.name + "_has_output_shape",
before_op=op,
)
op.enclosing_block.replace_uses_of_var_after_op(
anchor_op=op, old_var=op.outputs[0], new_var=x
)
block.remove_ops([op])
[docs]@register_pass(namespace="common")
class compose_conv1d(AbstractGraphPass):
"""
In `TensorFlow <https://github.com/tensorflow/tensorflow/blob/r1.15/tensorflow/python/ops/nn_ops.py#L1657>`_,
``tf.keras.layers.Conv1D`` is a composite op:
.. code-block::
expand a dummy dim -> Conv2D -> squeeze the dummy dim
In `PyTorch <https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/Convolution.cpp#L1087>`_,
this is also true for some backends (``mkldnn`` and ``xpu``).
This decomposition wrecks the coremltools ``conv1d`` graph passes,
so we should recompose the fragments back to MIL ``conv``, which natively supports ``conv1d``:
.. code-block::
Pattern 1:
Given:
%2 = expand_dims(%1, axes=-2) or expand_dims(%1, axes=2), %1.rank = 3
%3 = conv(%2)
%4 = squeeze(%3, axes=-2) or squeeze(%3, axes=2)
...
Result:
%4 = conv(%1)
...
Pattern 2 (TensorFlow channel_last):
Given:
%2 = expand_dims(%1, axes=-3) or expand_dims(%1, axes=1), %1.rank = 3
%3 = transpose(%2, perm=(0, 3, 1, 2))
%4 = conv(%3)
%5 = transpose(%4, perm=(0, 2, 3, 1))
%6 = squeeze(%5, axes=-3) or squeeze(%5, axes=1)
...
Result:
%3 = transpose(%1, perm=(0, 2, 1))
%4 = conv(%3)
%6 = transpose(%4, perm=(0, 2, 1))
...
"""
def apply(self, prog):
for f in prog.functions.values():
self._compose_conv1d_block(f)
@block_context_manager
def _compose_conv1d_block(self, block: Block):
def help_compose_conv1d_block(block: Block) -> bool:
for op in list(block.operations):
for b in op.blocks:
self._compose_conv1d_block(b)
# must start with expanding a 3-D tensor,
# who has batch, channel, length dimensions
if op.op_type != "expand_dims" or op.x.rank != 3:
continue
# try pattern `expand_dim` -> `conv2d` -> `squeeze`
if self._try_match_and_transform_pattern(op, block):
# has to break as the downstream iterator is affected
return True
# try pattern `expand_dim` -> `transpose` -> `conv2d` -> `transpose` -> `squeeze`
if self._try_match_and_transform_pattern_channel_last(op, block):
# has to break as the downstream iterator is affected
return True
return False
block_changed = True
while block_changed:
block_changed = help_compose_conv1d_block(block)
def _try_match_and_transform_pattern(self, expand_op: Operation, block: Block) -> bool:
"""
identify the pattern: `expand_dim` -> `conv2d` -> `squeeze`
"""
# abort composition if dummy dimension is not added as height
if expand_op.axes.rank != 1 or expand_op.axes.val[0] not in (-2, 2):
return False
# `expand_dims` -> `conv`
if not _check_child_op_type(expand_op, "conv"):
return False
conv_op = expand_op.outputs[0].child_ops[0]
# `conv` -> `squeeze`
if not _check_child_op_type(conv_op, "squeeze"):
return False
squeeze_op = conv_op.outputs[0].child_ops[0]
# abort composition if not squeezing the dummy height
if squeeze_op.axes.rank != 1 or squeeze_op.axes.val[0] not in (-2, 2):
return False
# everything looks good
return self._try_apply_transform(expand_op, conv_op, squeeze_op, block)
def _try_match_and_transform_pattern_channel_last(
self, expand_op: Operation, block: Block
) -> bool:
"""
identify the pattern: `expand_dim` -> `transpose` -> `conv2d` -> `transpose` -> `squeeze`
"""
# abort composition if dummy dimension is not added as height
if expand_op.axes.rank != 1 or expand_op.axes.val[0] not in (-3, 1):
return False
# `expand_dims` -> `transpose`
if not _check_child_op_type(expand_op, "transpose"):
return False
transpose1_op = expand_op.outputs[0].child_ops[0]
# abort composition if permutation is not (0, 3, 1, 2)
perm1 = transpose1_op.perm.val.copy()
perm1[np.where(perm1 < 0)] += 4
if np.any(perm1 != (0, 3, 1, 2)):
return False
# `transpose` -> `conv`
if not _check_child_op_type(transpose1_op, "conv"):
return False
conv_op = transpose1_op.outputs[0].child_ops[0]
# `conv` -> `transpose`
if not _check_child_op_type(conv_op, "transpose"):
return False
transpose2_op = conv_op.outputs[0].child_ops[0]
# abort composition if permutation is not (0, 2, 3, 1)
perm2 = transpose2_op.perm.val.copy()
perm2[np.where(perm2 < 0)] += 4
if np.any(perm2 != (0, 2, 3, 1)):
return False
# `transpose` -> `squeeze`
if not _check_child_op_type(transpose2_op, "squeeze"):
return False
squeeze_op = transpose2_op.outputs[0].child_ops[0]
# abort composition if not squeezing the dummy height
if squeeze_op.axes.rank != 1 or squeeze_op.axes.val[0] not in (-3, 1):
return False
# everything looks good
return self._try_apply_transform_channel_last(
expand_op, transpose1_op, conv_op, transpose2_op, squeeze_op, block
)
@staticmethod
def _try_apply_transform(
expand_op: Operation, conv_op: Operation, squeeze_op: Operation, block: Block
) -> bool:
ops_to_remove = [expand_op, conv_op, squeeze_op]
if not _check_no_output_connection(block, ops_to_remove):
return False
# prepare `conv1d`
conv_kwargs = {"name": squeeze_op.outputs[0].name, "before_op": conv_op}
# inherit `x` from `expand_dim`
conv_kwargs["x"] = expand_op.x
# inherit `pad_type`, `groups`, `bias` from `conv2d`
conv_kwargs["pad_type"] = conv_op.inputs["pad_type"].val
conv_kwargs["groups"] = conv_op.inputs["groups"].val
bias = conv_op.inputs.get("bias", None)
if bias is not None:
conv_kwargs["bias"] = bias
# squeeze `weight`, `strides`, `pad`, `dilations` from `conv2d`
conv_kwargs["weight"] = mb.squeeze(
x=conv_op.inputs["weight"], axes=(-2,), before_op=conv_op
)
conv_kwargs["strides"] = (conv_op.inputs["strides"].val[-1],)
conv_kwargs["pad"] = (conv_op.inputs["pad"].val[-2], conv_op.inputs["pad"].val[-1])
conv_kwargs["dilations"] = (conv_op.inputs["dilations"].val[-1],)
# compose `conv1d`
out = mb.conv(**conv_kwargs)
# try replacing `expand_dim` -> `conv2d` -> `squeeze` output
# with the new `conv1d` output
if squeeze_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=squeeze_op, old_var=squeeze_op.outputs[0], new_var=out
):
# remove `expand_dim` -> `conv2d` -> `squeeze`
block.remove_ops(ops_to_remove)
return True
return False
@staticmethod
def _try_apply_transform_channel_last(
expand_op: Operation,
transpose1_op: Operation,
conv_op: Operation,
transpose2_op: Operation,
squeeze_op: Operation,
block: Block,
) -> bool:
ops_to_remove = [expand_op, transpose1_op, conv_op, transpose2_op, squeeze_op]
if not _check_no_output_connection(block, ops_to_remove):
return False
# create `transpose1`
transpose1_out = mb.transpose(
x=expand_op.x, perm=(0, 2, 1), name=transpose1_op.outputs[0].name, before_op=expand_op
)
# prepare `conv1d`
conv_kwargs = {"name": conv_op.outputs[0].name, "x": transpose1_out, "before_op": conv_op}
# inherit `pad_type`, `groups`, `bias` from `conv2d`
conv_kwargs["pad_type"] = conv_op.inputs["pad_type"].val
conv_kwargs["groups"] = conv_op.inputs["groups"].val
bias = conv_op.inputs.get("bias", None)
if bias is not None:
conv_kwargs["bias"] = bias
# squeeze `weight`, `strides`, `pad`, `dilations` from `conv2d`
conv_kwargs["weight"] = mb.squeeze(
x=conv_op.inputs["weight"], axes=(-2,), before_op=conv_op
)
conv_kwargs["strides"] = (conv_op.inputs["strides"].val[-1],)
conv_kwargs["pad"] = (conv_op.inputs["pad"].val[-2], conv_op.inputs["pad"].val[-1])
conv_kwargs["dilations"] = (conv_op.inputs["dilations"].val[-1],)
# compose `conv1d`
conv_out = mb.conv(**conv_kwargs)
# create `transpose2`
transpose2_out = mb.transpose(
x=conv_out, perm=(0, 2, 1), name=squeeze_op.outputs[0].name, before_op=transpose2_op
)
# try replacing `expand_dim` -> `transpose` -> `conv2d` -> `transpose` -> `squeeze` output
# with the new `transpose` -> `conv1d` -> `transpose` output
if squeeze_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=squeeze_op, old_var=squeeze_op.outputs[0], new_var=transpose2_out
):
# remove `expand_dim` -> `transpose` -> `conv2d` -> `transpose` -> `squeeze`
block.remove_ops(ops_to_remove)
return True
return False
[docs]@register_pass(namespace="common")
class fuse_conv_batchnorm(AbstractGraphPass):
"""
Fuse the following ``batch_norm`` layer into ``conv`` and ``conv_transpose``.
That is, convert ``conv + batch_norm`` to ``conv``, by modifying the weight and bias in the ``conv`` layer.
.. code-block::
Given:
%2 = conv(%1)
...
%3 = batch_norm(%2)
...
Result:
%3 = conv(%1)
...
"""
def apply(self, prog):
for f in prog.functions.values():
block_changed = True
while block_changed:
block_changed = self._fuse_conv_batchnorm_block(f)
@staticmethod
def _try_to_transform(conv_op, bn_op):
# get parameters from batch_norm layer
gamma = bn_op.gamma.val
beta = bn_op.beta.val
mean = bn_op.mean.val
variance = bn_op.variance.val
epsilon = bn_op.epsilon.val
# get weight, bias and groups from conv layer
if conv_op.weight.val is None:
return False
conv_weight = conv_op.weight.val
conv_bias = conv_op.bias
groups = conv_op.groups.val
# get type of the conv layer
is_deconv = conv_op.op_type == "conv_transpose"
# The deconv weight transpose axes is determined by the dimension of convolution.
# Conv1d should be [1, 0, 2], Conv2d should be [1, 0, 2, 3], Conv3d should be [1, 0, 2, 3, 4]
if not 3 <= len(conv_weight.shape) <= 5:
raise AssertionError(
f"Only supports Conv1/2/3d, which means weight's dimension should"
f"between 3 and 5, but got weight with {len(conv_weight.shape)} "
f"dimensions. "
)
deconv_weight_transpose_axes = [1, 0] + [axis for axis in range(2, len(conv_weight.shape))]
# D_in denotes the spatial dimensions for conv kernel weight
# for conv_transpose, conv_weight has shape [Cin, Cout / groups, *D_in]
# for conv, conv_weight has shape [Cout, Cin / groups, *D_in]
if is_deconv:
Cout = conv_weight.shape[1] * groups
Cin = conv_weight.shape[0]
else:
Cout = conv_weight.shape[0]
Cin = conv_weight.shape[1] * groups
# get the type of the conv weight
conv_weight_type = conv_weight.dtype
# create bias for conv if not exist
if conv_bias is None:
conv_bias = np.zeros(Cout)
else:
conv_bias = conv_bias.val
if conv_bias is None:
return False
conv_bias = conv_bias.astype(conv_weight_type)
# get the original shape of weight and bias
origin_weight_shape = conv_weight.shape
origin_bias_shape = conv_bias.shape
# update the weight for conv layer
new_conv_weight = []
new_conv_bias = []
if is_deconv:
conv_weight = np.transpose(conv_weight, deconv_weight_transpose_axes)
conv_weight = np.reshape(
conv_weight, [Cout, Cin // groups] + list(conv_weight.shape[2:])
)
for i in range(Cout):
# get batch norm parameters for each channel
_gamma = gamma[i]
_beta = beta[i]
_mean = mean[i]
_variance = variance[i]
_scale = _gamma / np.sqrt(_variance + epsilon)
# get conv weight and bias for each channel
_conv_weight = conv_weight[i]
_conv_bias = conv_bias[i]
# update the conv weight and bias
_conv_weight = _conv_weight * _scale
_conv_bias = _scale * (_conv_bias - _mean) + _beta
new_conv_weight.append(_conv_weight)
new_conv_bias.append(_conv_bias)
new_conv_weight = np.array(new_conv_weight).astype(conv_weight_type)
new_conv_bias = np.array(new_conv_bias).astype(conv_weight_type)
if is_deconv:
new_conv_weight = np.reshape(
new_conv_weight, [Cout // groups, Cin] + list(new_conv_weight.shape[2:])
)
new_conv_weight = np.transpose(new_conv_weight, deconv_weight_transpose_axes)
# make sure the updated weight and bias have the same shape as the original ones
if new_conv_weight.shape != origin_weight_shape:
raise AssertionError(
"conv weight should have the same shape before and after the fuse_"
"conv_batchnorm pass. "
)
if new_conv_bias.shape != origin_bias_shape:
raise AssertionError(
"conv bias should have the same shape before and after the fuse_"
"conv_batchnorm pass. "
)
# create a new conv op with the new bias value, copying rest of the attributes
out_name = bn_op.outputs[0].name
conv_kargs = {
"weight": new_conv_weight,
"bias": new_conv_bias,
"name": out_name,
"before_op": conv_op,
}
for k, v in conv_op.inputs.items():
if k in ["weight", "bias"]:
continue
conv_kargs[k] = v
if is_deconv:
x = mb.conv_transpose(**conv_kargs)
else:
x = mb.conv(**conv_kargs)
if bn_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=bn_op,
old_var=bn_op.outputs[0],
new_var=x,
):
bn_op.enclosing_block.remove_ops([conv_op, bn_op])
return True
return False
@block_context_manager
def _fuse_conv_batchnorm_block(self, block):
def _match_pattern(op):
if op.op_type == "conv" or op.op_type == "conv_transpose":
# abort fusion if op output is also a block output
if op.outputs[0] in op.enclosing_block.outputs:
return None
# find batch_norm op
child_ops = op.outputs[0].child_ops
if len(child_ops) == 1:
bn_op_candidate = list(child_ops)[0]
if bn_op_candidate.op_type == "batch_norm":
return bn_op_candidate
return None
fusion_occurred = False
for op in list(block.operations):
for b in op.blocks:
block_changed = True
while block_changed:
block_changed = self._fuse_conv_batchnorm_block(b)
if len(op.blocks) > 0:
# This op can't be conv or conv_transpose
continue
bn_op = _match_pattern(op)
if bn_op is not None:
fusion_occurred = self._try_to_transform(op, bn_op)
# has to break as the downstream iterator is affected.
if fusion_occurred:
return fusion_occurred
return fusion_occurred
[docs]@register_pass(namespace="common")
class fuse_conv_bias(AbstractGraphPass):
"""
Fold ``add``/``sub`` into ``bias`` of ``conv`` and ``conv_transpose``.
That is, convert ``conv + add/sub`` to ``conv``, when ``add``/``sub`` is adding a constant.
Two patterns are supported:
.. code-block::
Pattern 1:
Given:
%2 = conv(%1)
...
%3 = add(%2, constant) # where constant has shape (1,C,1)/(C,1) for 1d conv, (1,C,1,1)/(C,1,1) for 2d conv etc
...
Result:
%3 = conv(%1)
...
Pattern 2:
Given:
%2 = conv(%1)
%3 = transpose(%2)
...
%4 = add(%3, constant) # where constant has a broacasable shape
...
Result:
%2 = conv(%1)
%4 = transpose(%2)
...
"""
child_op_types = ["add", "sub"]
def apply(self, prog):
for f in prog.functions.values():
block_changed = True
while block_changed:
block_changed = self._fuse_conv_bias_block(f)
def _match_pattern(self, op):
if op.op_type == "conv" or op.op_type == "conv_transpose":
# abort fusion if op output is also a block output
if op.outputs[0] in op.enclosing_block.outputs:
return None
# find add
child_ops = op.outputs[0].child_ops
if len(child_ops) == 1:
add_op_candidate = list(child_ops)[0]
if add_op_candidate.op_type in self.child_op_types:
return add_op_candidate
return None
@staticmethod
def _try_to_transform_transpose_pattern(conv_op, block):
ops_to_remove = []
# conv layer
if conv_op.op_type != "conv" and conv_op.op_type != "conv_transpose":
return False
is_deconv = conv_op.op_type == "conv_transpose"
ops_to_remove.append(conv_op)
# transpose layer
if not _check_child_op_type(conv_op, "transpose"):
return False
transpose_op = list(conv_op.outputs[0].child_ops)[0]
ops_to_remove.append(transpose_op)
# add/sub layer
if not _check_child_op_type(transpose_op, "add") and not _check_child_op_type(
transpose_op, "sub"
):
return False
add_or_sub_op = list(transpose_op.outputs[0].child_ops)[0]
ops_to_remove.append(add_or_sub_op)
# get the bias
if add_or_sub_op.x.val is None and add_or_sub_op.y.val is None:
return False
bias = add_or_sub_op.x.val if add_or_sub_op.x.val is not None else add_or_sub_op.y.val
is_first_input = add_or_sub_op.y.val is not None
is_sub = add_or_sub_op.op_type == "sub"
# get the conv bias/weight
conv_shape = conv_op.outputs[0].shape
Cout = conv_shape[1]
conv_weight = conv_op.weight.val
conv_weight_type = conv_weight.dtype
conv_bias = (
np.zeros(Cout).astype(conv_weight_type) if conv_op.bias is None else conv_op.bias.val
)
# check if the bias is compatible for fusion
is_bias_scalar = True
if isinstance(bias, np.ndarray):
if bias.shape == ():
bias = bias.tolist()
elif np.prod(bias.shape) == 1:
bias = np.squeeze(bias).tolist()
else:
is_bias_scalar = False
if not is_bias_scalar:
if np.prod(bias.shape) != Cout:
return False
rank = transpose_op.outputs[0].rank
cout_dim = transpose_op.perm.val.tolist().index(1) - rank
if bias.shape[cout_dim] != Cout:
return False
bias = np.reshape(bias, (Cout))
# compute the new bias
if is_sub:
if is_first_input:
bias = -bias
else:
conv_bias = -conv_bias
new_bias = conv_bias + bias
# compute the new weight
if is_sub and not is_first_input:
new_weight = -conv_weight
else:
new_weight = conv_weight
if not _check_no_output_connection(block, ops_to_remove):
return False
# create a new conv op with the new weight, bias value, copying rest of the attributes
conv_kargs = {"weight": new_weight, "bias": new_bias, "before_op": conv_op}
for k, v in conv_op.inputs.items():
if k in ["weight", "bias"]:
continue
conv_kargs[k] = v
if is_deconv:
x = mb.conv_transpose(**conv_kargs)
else:
x = mb.conv(**conv_kargs)
# create a new transpose op
out_name = add_or_sub_op.outputs[0].name
tranpose_kargs = {"x": x, "name": out_name, "before_op": transpose_op}
for k, v in transpose_op.inputs.items():
if k == "x":
continue
tranpose_kargs[k] = v
x = mb.transpose(**tranpose_kargs)
if add_or_sub_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=add_or_sub_op,
old_var=add_or_sub_op.outputs[0],
new_var=x,
):
add_or_sub_op.enclosing_block.remove_ops(ops_to_remove)
return True
return False
@staticmethod
def _try_to_transform(conv_op, add_op):
if add_op.op_type == "sub":
bias_var = add_op.y
else:
bias_var = add_op.x if add_op.x.val is not None else add_op.y
bias_value = bias_var.val
is_conv_op = conv_op.op_type == "conv"
# check that the bias value is a constant array or a scalar constant
if not isinstance(bias_value, (np.ndarray, np.generic)):
return False
is_bias_scalar = False
if not isinstance(bias_value, np.ndarray):
is_bias_scalar = True
# find rank of the conv input
rank = conv_op.x.rank
if rank is None:
return False
if not (rank == 3 or rank == 4 or rank == 5):
return False
# check compatibility of bias value with the rank of the conv op
# either bias value should be a scalar or:
# rank=3 ==> (B,C,D), which means bias must be (1,C,1) or (C,1)
# rank=4 ==> (B,C,D1,D2), which means bias must be (1,C,1,1) or (C,1,1)
# rank=5 ==> (B,C,D1,D2,D3), which means bias must be (1,C,1,1,1) or (C,1,1,1)
if is_bias_scalar:
bias_value = np.array([bias_value])
else:
# check that there is at most one dimension in the shape that is not 1
if len(np.squeeze(bias_value).shape) > 1:
return False
# check that addition is not happening on the batch dimension
if len(bias_value.shape) == rank:
if bias_value.shape[0] != 1:
return False
# check that last rank-2 entries in the shape vector are all 1s
if np.prod(bias_value.shape[-(rank - 2) :]) != 1:
return False
bias_value = np.squeeze(bias_value)
if add_op.op_type == "sub":
bias_value *= -1
# everything looks good, now find the new updated bias
old_bias = conv_op.inputs.get("bias", None)
old_bias_value = None
if old_bias is not None and old_bias.val is not None:
old_bias_value = old_bias.val
if old_bias is None:
# need to create a fresh numpy array for bias
if np.prod(bias_value.shape) == 1:
# its a scalar bias
# need to find the value of Cout to form a new bias
if conv_op.weight.val is None:
return False
# conv_transpose has weight format [K, C_out, spatial dims]
# conv has weight format [C_out, K, spatial dims]
Cout = conv_op.weight.val.shape[0 if is_conv_op else 1]
new_bias_value = np.broadcast_to(bias_value, (Cout,))
else:
new_bias_value = bias_value
else:
# just need to update the existing bias array
try:
new_bias_value = old_bias_value + bias_value
except:
return False
# create a new conv op with the new bias value, copying rest of the attributes
out_name = add_op.outputs[0].name
if new_bias_value.dtype != np.float32 and new_bias_value.dtype != np.float16:
# cast the bias to match the weight type
weight_np_type = types.nptype_from_builtin(
conv_op.inputs["weight"].sym_type.get_primitive()
)
logger.warning(
"conv_bias_fusion pass: casting bias "
"from {} to {} to match the dtype of the weight of the conv layer".format(
new_bias_value.dtype, weight_np_type
)
)
new_bias_value = new_bias_value.astype(weight_np_type)
new_bias_var = mb.const(val=new_bias_value, before_op=conv_op)
conv_kargs = {"bias": new_bias_var, "name": out_name, "before_op": conv_op}
for k, v in conv_op.inputs.items():
if k == "bias":
continue
conv_kargs[k] = v
if is_conv_op:
x = mb.conv(**conv_kargs)
else:
x = mb.conv_transpose(**conv_kargs)
if add_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=add_op,
old_var=add_op.outputs[0],
new_var=x,
):
add_op.enclosing_block.remove_ops([conv_op, add_op])
return True
return False
@block_context_manager
def _fuse_conv_bias_block(self, block):
fusion_status = False
for op in list(block.operations):
for b in op.blocks:
block_changed = True
while block_changed:
block_changed = self._fuse_conv_bias_block(b)
if len(op.blocks) > 0:
# This op can't be conv or conv_transpose
continue
# pattern 1 : conv + add/sub
add_op = self._match_pattern(op)
if add_op is not None:
fusion_status = self._try_to_transform(op, add_op)
# has to break as the downstream iterator is affected.
if fusion_status:
return fusion_status
# pattern 2 : conv + transpose + add/sub
fusion_status = self._try_to_transform_transpose_pattern(op, block)
if fusion_status:
return fusion_status
return fusion_status
[docs]@register_pass(namespace="common")
class fuse_conv_scale(AbstractGraphPass):
"""
Fold ``mul``/``div`` into ``conv``/``conv_transpose`` by updating the weight/bias of the convolution layers.
The scale ``const`` can be a single number (scalar) or a vector with a broadcastable shape.
For example, if the output of the ``conv``/``deconv`` layer is ``(B, Cout, H, W)``,
``const`` of shape ``(Cout, 1, 1)`` and ``(1, Cout, 1, 1)`` are allowed.
.. code-block::
Given:
%2 = conv(%1)
...
%3 = mul(%2, constant) # where constant is the scale constant
...
Result:
%3 = conv(%1)
...
"""
def apply(self, prog):
for f in prog.functions.values():
block_changed = True
while block_changed:
block_changed = self._fuse_conv_scale_block(f)
@staticmethod
def _try_to_transform(conv_op, scale_op):
# get the scale
if scale_op.x.val is None and scale_op.y.val is None:
return False
scale_var = scale_op.x if scale_op.x.val is not None else scale_op.y
scale = scale_var.val
# for the scalar case, the scalar can be either
# 1. a python int/float
# 2. a 0d numpy array
# 3. a 1d numpy array with shape (1,)
is_scalar = True
if isinstance(scale, np.ndarray):
if scale.shape == ():
scale = scale.tolist()
elif scale.shape == (1) or scale.shape == (1,):
scale = scale[0]
else:
is_scalar = False
# get weight and bias and groups from conv layer
if conv_op.weight.val is None:
return False
conv_weight = conv_op.weight.val
conv_bias = conv_op.bias
groups = conv_op.groups.val
# get type of the conv layer
is_deconv = conv_op.op_type == "conv_transpose"
is_conv_1d = len(conv_weight.shape) == 3
# D_in denotes the spatial dimensions for conv kernel weight
# for conv_transpose, conv_weight has shape [Cin, Cout / groups, *D_in]
# for conv, conv_weight has shape [Cout, Cin / groups, *D_in]
if is_deconv:
Cout = conv_weight.shape[1] * groups
Cin = conv_weight.shape[0]
else:
Cout = conv_weight.shape[0]
Cin = conv_weight.shape[1] * groups
# for the vector scale case, check if the shape is broacastable
if not is_scalar:
if not np.product(scale.shape) == Cout:
return False
if len(scale.shape) == len(conv_weight.shape):
if not scale.shape[1] == Cout:
return False
elif len(scale.shape) == len(conv_weight.shape) - 1:
if not scale.shape[0] == Cout:
return False
else:
return False
# transform the scale to 1./scale for the real_div case
if scale_op.op_type == "real_div":
scale = 1.0 / scale
# get the type of the conv weight
conv_weight_type = conv_weight.dtype
# create bias for conv if not exist
if conv_bias is None:
conv_bias = np.zeros(Cout)
else:
conv_bias = conv_bias.val
conv_bias = conv_bias.astype(conv_weight_type)
# get the original shape of weight and bias
origin_weight_shape = conv_weight.shape
origin_bias_shape = conv_bias.shape
# update the weight/bias for conv layer
if is_scalar:
new_conv_bias = np.array(conv_bias * scale).astype(conv_weight_type)
new_conv_weight = np.array(conv_weight * scale).astype(conv_weight_type)
else:
scale = np.reshape(scale, (Cout))
new_conv_bias = np.array(conv_bias * scale).astype(conv_weight_type)
new_conv_weight = []
if is_deconv:
conv_weight = np.transpose(conv_weight, [1, 0, 2] if is_conv_1d else [1, 0, 2, 3])
conv_weight = np.reshape(
conv_weight, [Cout, Cin // groups] + list(conv_weight.shape[2:])
)
for i in range(Cout):
_conv_weight = conv_weight[i] * scale[i]
new_conv_weight.append(_conv_weight)
new_conv_weight = np.array(new_conv_weight).astype(conv_weight_type)
if is_deconv:
new_conv_weight = np.reshape(
new_conv_weight, [Cout // groups, Cin] + list(new_conv_weight.shape[2:])
)
new_conv_weight = np.transpose(
new_conv_weight, [1, 0, 2] if is_conv_1d else [1, 0, 2, 3]
)
# make sure the updated weight and bias have the same shape as the original ones
assert (
new_conv_weight.shape == origin_weight_shape
), "conv weight should have the same shape before and after the fuse_conv_scale pass."
assert (
new_conv_bias.shape == origin_bias_shape
), "conv bias should have the same shape before and after the fuse_conv_scale pass."
# create a new conv op with the new weight, bias value, copying rest of the attributes
out_name = scale_op.outputs[0].name
conv_kargs = {
"weight": new_conv_weight,
"bias": new_conv_bias,
"name": out_name,
"before_op": conv_op,
}
for k, v in conv_op.inputs.items():
if k in ["weight", "bias"]:
continue
conv_kargs[k] = v
if is_deconv:
x = mb.conv_transpose(**conv_kargs)
else:
x = mb.conv(**conv_kargs)
if scale_op.enclosing_block.try_replace_uses_of_var_after_op(
anchor_op=scale_op,
old_var=scale_op.outputs[0],
new_var=x,
):
scale_op.enclosing_block.remove_ops([conv_op, scale_op])
return True
return False
@block_context_manager
def _fuse_conv_scale_block(self, block):
def _match_pattern(op):
if op.op_type == "conv" or op.op_type == "conv_transpose":
# abort fusion if op output is also a block output
if op.outputs[0] in op.enclosing_block.outputs:
return None
# find batch_norm op
child_ops = op.outputs[0].child_ops
if len(child_ops) == 1:
scale_op_candidate = list(child_ops)[0]
if scale_op_candidate.op_type in ["mul", "real_div"]:
return scale_op_candidate
return None
fusion_occurred = False
for op in list(block.operations):
for b in op.blocks:
block_changed = True
while block_changed:
block_changed = self._fuse_conv_scale_block(b)
if len(op.blocks) > 0:
# This op can't be conv or conv_transpose
continue
scale_op = _match_pattern(op)
if scale_op is not None:
fusion_occurred = self._try_to_transform(op, scale_op)
# has to break as the downstream iterator is affected.
if fusion_occurred:
return fusion_occurred
return fusion_occurred
[docs]@register_pass(namespace="common")
class fuse_pad_conv(AbstractGraphPass):
"""
When we observe ``pad -> transpose -> conv``, we move the ``pad`` to be next to ``conv``.
This allows us to meld ``pad + conv`` if possible.
.. code-block::
Given:
%1 = pad(%0, ...)
%2 = transpose(%1, ...)
%3 = conv(%2, ...)
...
Result:
%1.a = transpose(%0, ...)
$2.a = pad(%1.a, ...)
%3 = conv(%2.a)
...
"""
def apply(self, prog):
for f in prog.functions.values():
block_changed = True
while block_changed:
block_changed = self._pad_conv_connect_block(f)
@staticmethod
def _match_pattern(op):
ret = set([])
child_ops = op.outputs[0].child_ops
for child_op in child_ops:
if child_op.op_type != "transpose":
continue
skip_ops = child_op.outputs[0].child_ops
for skip_op in skip_ops:
if "conv" not in skip_op.op_type:
continue
ret.update([child_op])
return ret if len(ret) != 0 else None
@staticmethod
def _try_to_transform(pad_op, transpose_ops, block):
def _compute_new_pad_values(transpose_op):
if pad_op.inputs["pad"].val is None:
return None
pad_amounts = np.reshape(pad_op.inputs["pad"].val, [-1, 2])
transpose_axes = transpose_op.inputs["perm"].val
rank_diff = len(transpose_axes) - pad_amounts.shape[0]
pad_amounts_new = copy.deepcopy(pad_amounts)
# append "rank_diff" rows of zeros to the top
pad_amounts_new = np.concatenate(
(np.zeros((2 * rank_diff)).reshape(-1, 2), pad_amounts_new)
)
pad_amounts_new = pad_amounts_new.astype(pad_amounts.dtype)
pad_amounts = np.concatenate((np.zeros((2 * rank_diff)).reshape(-1, 2), pad_amounts))
for i, axis in enumerate(transpose_axes):
pad_amounts_new[i][0] = pad_amounts[axis][0]
pad_amounts_new[i][1] = pad_amounts[axis][1]
# get the top "rank_diff" rows
top_rows = pad_amounts_new[:rank_diff, :]
if not np.all(top_rows == 0):
return False
# cut "rank_diff" from the top
pad_amounts_new = pad_amounts_new[rank_diff:, :]
pad_amounts_new = pad_amounts_new.flatten()
return pad_amounts_new
if pad_op.outputs[0] in pad_op.enclosing_block.outputs:
return False
if len(set(pad_op.outputs[0].child_ops)) != len(transpose_ops):
return False
for transpose_op in transpose_ops:
pad_amounts_new = _compute_new_pad_values(transpose_op)
if pad_amounts_new is None:
continue
with pad_op.enclosing_block:
new_transpose_var = mb.transpose(
x=pad_op.inputs["x"],
perm=transpose_op.inputs["perm"].val,
before_op=transpose_op,
)
new_pad_inputs = {"x": new_transpose_var, "pad": pad_amounts_new}
for k, v in pad_op.inputs.items():
if k not in new_pad_inputs:
new_pad_inputs[k] = v
new_pad_var = mb.pad(before_op=transpose_op, **new_pad_inputs)
pad_op.enclosing_block.replace_uses_of_var_after_op(
anchor_op=transpose_op, old_var=transpose_op.outputs[0], new_var=new_pad_var
)
pad_op.enclosing_block.remove_ops(list(transpose_ops) + [pad_op])
return True
@block_context_manager
def _pad_conv_connect_block(self, block):
fusion_status = False
for op in list(block.operations):
for b in op.blocks:
block_changed = True
while block_changed:
block_changed = self._pad_conv_connect_block(b)
if op.op_type != "pad":
continue
transpose_ops = self._match_pattern(op)
if transpose_ops is not None:
fusion_status = self._try_to_transform(op, transpose_ops, block)
# has to break as the downstream iterator is affected.
if fusion_status:
return fusion_status
return fusion_status