MIL Graph Passes
Graph Passes supported by the Model Intermediate Language (MIL):
cleanup
- class coremltools.converters.mil.mil.passes.defs.cleanup.const_elimination[source]
Replace non-const ops whose outputs are known to be constant (const Vars) with const ops. Example:

    Given:
        %2, %3 = non_const_op(...)  # %2 is const, %3 isn't const
        %4 = other_op(%2, %3)

    Result:
        _, %3 = non_const_op(...)   # _ is the ignored output
        %2_const = const()          # %2_const name is for illustration only
        %4 = other_op(%2_const, %3)
Support options:
- skip_const_by_size: Skip folding const ops that have a larger number of elements than a threshold.
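A minimal sketch of applying this pass directly to a Builder program. The registry key "common::const_elimination" and the in-place invocation through PASS_REGISTRY are assumptions, not part of the docstring above:

    import numpy as np
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
    def prog(x):
        # The transpose consumes only a const, so its output value is known at
        # compile time and const_elimination can replace it with a const op.
        w_t = mb.transpose(x=np.random.rand(3, 2).astype(np.float32), perm=[1, 0])
        return mb.add(x=x, y=w_t)

    print(prog)                                       # contains a transpose op
    PASS_REGISTRY["common::const_elimination"](prog)
    print(prog)                                       # the transpose is folded into a const

The skip_const_by_size threshold would be supplied through the converter's pass-pipeline options rather than here; see the add_fp16_cast sketch at the end of this page for that pattern.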
- class coremltools.converters.mil.mil.passes.defs.cleanup.dead_code_elimination[source]
Eliminate unused ops in the program. Ops whose outputs do not contribute to the final outputs are deleted.
    # Before dead_code_elimination pass.
    main(%x: (2, 4, fp32)) {
      block0() {
        %const_2: (4, 2, fp32)* = const(val=[...])
        %const_3: (4, fp32)* = const(val=[...])
        %tx_0: (bool)* = const(val=False)
        %ty_0: (bool)* = const(val=False)
        %matmul_0: (2, 2, fp32) = matmul(x=%x, y=%const_2, transpose_x=%tx_0, transpose_y=%ty_0)
        %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
      } -> (%linear_0)
    }

    # After dead_code_elimination pass.
    main(%x: (2, 4, fp32)) {
      block0() {
        %const_2: (4, 2, fp32)* = const(val=[...])
        %const_3: (4, fp32)* = const(val=[...])
        %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
      } -> (%linear_0)
    }
In the example above, %matmul_0 is an op that is not used in the computation. This op and its input ops (%tx_0 and %ty_0) are eliminated in this pass.
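A small runnable sketch of the same situation, assuming the pass is registered as "common::dead_code_elimination" in PASS_REGISTRY:

    import numpy as np
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        # This matmul feeds nothing downstream, so it and its const input are dead code.
        _unused = mb.matmul(x=x, y=np.random.rand(4, 2).astype(np.float32))
        return mb.linear(
            x=x,
            weight=np.random.rand(3, 4).astype(np.float32),
            bias=np.random.rand(3).astype(np.float32),
        )

    PASS_REGISTRY["common::dead_code_elimination"](prog)
    print(prog)  # the matmul and its const input no longer appear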
- class coremltools.converters.mil.mil.passes.defs.cleanup.dedup_op_and_var_names[source]
For each function, this pass renames ops and variables that share a name with any preceding op/variable across all scopes in the given function, where the precedence is implementation-specific. Note that op names and variable names are tracked separately, so an op may have the same name as a variable.
The pass preserves input and output names, and raises a ValueError if deduplication is not possible without changing the input/output var names.
    def prog(x):
        x = mb.cast(x=x, dtype="fp16", name="castop")
        x = mb.cast(x=x, dtype="fp32", name="castop")
        x = mb.square(x=x, name="square_last")
        return x

    # Before the dedup pass, the op names are ["castop", "castop", "square_last"].
    # After the dedup pass, the op names are ["castop", "castop_1", "square_last"].
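The docstring example above, made runnable (the registry key "common::dedup_op_and_var_names" is an assumption):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        x = mb.cast(x=x, dtype="fp16", name="castop")
        x = mb.cast(x=x, dtype="fp32", name="castop")   # duplicate op name
        return mb.square(x=x, name="square_last")

    PASS_REGISTRY["common::dedup_op_and_var_names"](prog)
    # Op names are now ["castop", "castop_1", "square_last"].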
- class coremltools.converters.mil.mil.passes.defs.cleanup.fuse_reduce_mean[source]
Detect the reduce_sum --> mul/real_div pattern that can be mapped to reduce_mean. That is, the operation reduce_sum/count == reduce_mean.

Input graph

                             const (scalar)
                                   |
input ----> reduce_sum ----> mul/real_div -----------> output
Output graph
input --------> reduce_mean ---------> output
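A sketch of the pattern in Builder form, assuming the registry key "common::fuse_reduce_mean"; the divisor must equal the number of reduced elements for the fusion to apply:

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 3, 4))])
    def prog(x):
        # Sum over the last two axes, then divide by the reduced element count (3 * 4 = 12).
        s = mb.reduce_sum(x=x, axes=[1, 2], keep_dims=False)
        return mb.real_div(x=s, y=12.0)

    PASS_REGISTRY["common::fuse_reduce_mean"](prog)
    print(prog)  # reduce_sum + real_div becomes a single reduce_mean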
- class coremltools.converters.mil.mil.passes.defs.cleanup.loop_invariant_elimination[source]
When a block does not modify a block input var, eliminate that block input var and use the corresponding var in the outer scope. Example:
    # Before loop_invariant_elimination pass.
    # Notice that %b.x is constant through the while loop iterations.
    main(%a: (1, 2, fp32), %b: (1, 2, fp32)) {
      block0() {
        %loop:0: (1, 2, fp32), %loop:1: (1, 2, fp32) = while_loop(loop_vars=(%a, %b))
          loop_cond(%a.x, %b.x) {
            %cond_var: (bool) = some_op(x=%a.x, y=%b.x)
          } -> (%cond_var)
          loop_body(%a.x, %b.x) {
            %add_0: (1, 2, fp32) = add(x=%a.x, y=%b.x)
          } -> (%add_0, %b.x)
      } -> (%loop:0, %loop:1)
    }

    # After loop_invariant_elimination pass.
    main(%a: (1, 2, fp32), %b: (1, 2, fp32)) {
      block0() {
        %loop:1: (1, 2, fp32) = identity(x=%b)
        %loop:0: (1, 2, fp32) = while_loop(loop_vars=(%a))
          loop_cond(%a.x) {
            %cond_var: (bool) = some_op(x=%a.x, y=%b)
          } -> (%cond_var)
          loop_body(%a.x) {
            %add_0: (1, 2, fp32) = add(x=%a.x, y=%b)
          } -> (%add_0)
      } -> (%loop:0, %loop:1)
    }
Here we eliminate the loop invariant %b.x from the while_loop, which now returns 1 instead of 2 outputs. We also preserve the return var names with an identity op.
- class coremltools.converters.mil.mil.passes.defs.cleanup.noop_elimination[source]
Remove ops that have no effect.
    Given:
        %1 (1, 96, 128, 64, fp32) = ...
        %2 (1, 96, 128, 64, fp32) = reshape(%1)
        ...
        %3 (1, 96, 128, 64, fp32) = add(%2, constant)
        ...

    Result:
        %1 (1, 96, 128, 64, fp32) = ...
        %3 (1, 96, 128, 64, fp32) = add(%1, constant)
        ...
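For instance, a reshape to the tensor's existing shape is a no-op and is dropped (sketch; the registry key "common::noop_elimination" is assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(1, 96, 128, 64))])
    def prog(x):
        y = mb.reshape(x=x, shape=[1, 96, 128, 64])  # same shape as the input: a no-op
        return mb.add(x=y, y=1.0)

    PASS_REGISTRY["common::noop_elimination"](prog)
    print(prog)  # the reshape is gone; add consumes the input directly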
- class coremltools.converters.mil.mil.passes.defs.cleanup.remove_redundant_ops[source]
If there are multiple ops with “identical” inputs, then they are redundant and all but one of them can be removed. This pass checks and removes such ops.
Since all inputs to ops in MIL are named, two ops with the same op_type can be compared by comparing their correspondingly named inputs. Inputs are treated as identical if one of the following is true:

- The input is a constant var, in which case its value should have the same dtype and numerical value.
- The input is a non-constant var, in which case it should be the same var object.
This pass iterates over the ops, takes the first output var of each op, and then builds a candidate op list from the child ops of this var. This candidate op list contains ops of the same op_type, arranged in topological order. The second, third, and subsequent ops in the list are pairwise compared with the first op, and if identical to it, they are removed. For example:

    Input:
        %0 = op0(...)
        %1 = op1(...)
        %2 = const(val=4.5)
        %3 = const(val=4.5)
        %4 = op2(%1, %0, %2)
        %5 = op3(%1, %0, %3)

    Output:
        %0 = op0(...)
        %1 = op1(...)
        %2 = const(val=4.5)
        %3 = const(val=4.5)  # this will get removed later by the dead_code_elimination pass
        %4 = op2(%1, %0, %2)
In the example above, op3 is removed and all uses of %5 are replaced by %4. For more examples, see TestRemoveRedundantOpsPass.
- class coremltools.converters.mil.mil.passes.defs.cleanup.remove_symbolic_reshape[source]
Convert symbolic shapes in reshape to integers.

Note: This does not perform any optimization, but simply replaces symbols with positive integers if they can be solved from the volumetric constraint, or with -1. Therefore, this pass fails if more than one symbol needs to be resolved to -1.
    # Before remove_symbolic_reshape pass.
    main(%x: (s0, 4, fp32)) {
      block0() {
        %reshape_0_shape_0: (3,i32)^ = const(val=(s0, s1, 2))
        %reshape_0: (s0, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0)
      } -> (%reshape_0)
    }

    # After remove_symbolic_reshape pass.
    main(%x: (s0, 4, fp32)) {
      block0() {
        %reshape_0_shape_0x: (3,i32)* = const(val=[-1, 2, 2])
        %reshape_0: (-1, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0x)
      } -> (%reshape_0)
    }
TODO (rdar://59165842): Use expand_dims, squeeze etc to use 0 instead of dynamic reshape with -1.
- class coremltools.converters.mil.mil.passes.defs.cleanup.topological_reorder[source]
Topologically re-orders the list of operations in a program by placing each operation closer to its first use, or at the end if it is not consumed by any other operation.

Currently, this pass re-orders only transpose and cast operations.
    # Example: input program
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x2 = mb.cast(x=x1_t, dtype="fp32")
        x3 = mb.log(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
    } -> x2, x4, x7, x8

    # After moving the `cast` ops:
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x3 = mb.log(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x2 = mb.cast(x=x1_t, dtype="fp32")
    } -> x2, x4, x7, x8

    # After moving the `transpose` ops:
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x3 = mb.log(x=x)
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x2 = mb.cast(x=x1_t, dtype="fp32")
    } -> x2, x4, x7, x8
optimize_activation
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_exact[source]
Identify the pattern that corresponds to the exact version of gelu, and replace it with a single gelu layer with mode=EXACT. The pattern is y = 0.5 * x * (1 + erf(x / sqrt(2))), which can be represented by one of the following:

    (1)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul (0.5) ---> mul ---> [...]
          |                                                                  ^
          |                                                                  |
          |-------------------------------------------------------------------

    (2)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ---> mul (0.5) ---> [...]
          |                                                   ^
          |                                                   |
          |----------------------------------------------------

    (3)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ------> [...]
          |                                                   ^
          |                                                   |
          |---------------> mul (0.5) --------------------------

    All of them are converted to:

        [...] ----> gelu (mode=EXACT) ---> [...]
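A sketch of pattern (1) built with the Builder (registry key "common::fuse_gelu_exact" assumed; this exact construction is only illustrative):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 8))])
    def prog(x):
        # y = 0.5 * x * (1 + erf(x / sqrt(2))), spelled out with primitive ops.
        d = mb.real_div(x=x, y=2.0 ** 0.5)
        e = mb.erf(x=d)
        a = mb.add(x=e, y=1.0)
        h = mb.mul(x=a, y=0.5)
        return mb.mul(x=x, y=h)

    PASS_REGISTRY["common::fuse_gelu_exact"](prog)
    print(prog)  # a single gelu op with mode=EXACT should remain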
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_tanh_approximation[source]
Identify the pattern that corresponds to the tanh approximate version of gelu, and replace it with a single gelu layer with mode=TANH_APPROXIMATION.

The implementation of this pass uses the generic graph pattern matching and transform algorithm implemented in coremltools.converters.mil.experimental.passes.generic_pass_infrastructure and documented in coremltools/converters/mil/experimental/passes/readme.md.

Graph for get_gelu_pattern1()

y = x * (0.5 * (tanh(((.0447)x^3 + x) * sqrt(2/pi)) + 1))

    [...] -----> pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) ----> mul (0.5) -----> mul ---> [...]
      |                                            ^                                                                          ^
      |                                            |                                                                          |
      |------------------------------------------------------------------------------------------------------------------------
Graph for get_gelu_pattern2()

y = (0.5 * x) * (tanh(((.0447)x^3 + x) * sqrt(2/pi)) + 1)

                 |-------------------------------------------------------------------------------------------------------|
                 |                                                                                                        V
    [...] -----> mul (0.5)    pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) ----> mul ---> [...]
      |                         ^                               ^
      |                         |                               |
      |----------------------------------------------------------
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_leaky_relu[source]
Detect the mul --> max pattern that can be mapped to leaky_relu.

In code form - Input

    %2 = const(value = alpha)   # where 0 <= alpha <= 1
    %3 = mul(%1, %2)            # alpha * x
    %4 = max(%3, %1)            # max(alpha * x, x)
In code form - Output
%4 = leaky_relu(x=%1, alpha=%2)
In graphical form - Input graph
         const (val = alpha)
              |
input ----> mul ---------------> maximum -----------> output
  |                                 |
  |                                 |
  |----------------------------------
In graphical form - Output graph
input --------> leaky_relu ---------> output
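The same pattern in Builder form (sketch; the registry key "common::fuse_leaky_relu" is assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 8))])
    def prog(x):
        scaled = mb.mul(x=x, y=0.01)       # alpha * x, with 0 <= alpha <= 1
        return mb.maximum(x=scaled, y=x)   # max(alpha * x, x)

    PASS_REGISTRY["common::fuse_leaky_relu"](prog)
    print(prog)  # replaced by leaky_relu(x, alpha=0.01)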
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_prelu[source]
Detect the following patterns that can be mapped to a prelu op. Essentially, the prelu op can be broken down into the following ops: y = a * relu(-1 * x) + relu(x).
Pattern 1
            | ------------> relu --------------------|
            |                                        V
x (BCHW) ---|                                       add -----> y (BCHW)
            |                                        ^
            --------> mul -------> relu -----> mul---|
                       ^                        ^
                       |                        |
                 Const(val=-1)      Const(name=a, shape=(C,1,1) or (1,C,1,1))
This will be mapped to:
x (BCHW) ------> prelu(alpha=a, shape=(C,)) ---------> y (BCHW)
Pattern 2
                              | ------------> relu --------------------|
                              |                                        V
x (BCHW) --> transpose(BHWC)--|                                       add -----> y (BHWC)
                              |                                        ^
                              --------> mul -------> relu -----> mul---|
                                         ^                        ^
                                         |                        |
                                   Const(val=-1)      Const(shape=(C,) or (1,C) or (1,1,C) or (1,1,1,C))
This will be mapped to:
x (BCHW) ------> prelu ---------> transpose ------> y (BHWC)
optimize_conv
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.add_conv_transpose_output_shape[source]
The conv_transpose input output_shape is an optional input. Since we can infer the output shape from type_inference, we add the output_shape input whenever it is known to be constant at compile time. For example:

    Given:
        %1: (1, 5, 39, fp32) = conv_transpose(...)  # no output_shape input.

    Result:
        %2: (3, i32) = const(val=[1,5,39])
        %3: (1, 5, 39, fp32) = conv_transpose(..., output_shape=%2)
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.compose_conv1d[source]
In TensorFlow, tf.keras.layers.Conv1D is a composite op: expand a dummy dim -> Conv2D -> squeeze the dummy dim. In PyTorch, this is also true for some backends (mkldnn and xpu).

This decomposition wrecks the coremltools conv1d graph passes, so we recompose the fragments back into MIL conv, which natively supports conv1d:

    Pattern 1:
        Given:
            %2 = expand_dims(%1, axes=-2) or expand_dims(%1, axes=2), %1.rank = 3
            %3 = conv(%2)
            %4 = squeeze(%3, axes=-2) or squeeze(%3, axes=2)
            ...
        Result:
            %4 = conv(%1)
            ...

    Pattern 2 (TensorFlow channel_last):
        Given:
            %2 = expand_dims(%1, axes=-3) or expand_dims(%1, axes=1), %1.rank = 3
            %3 = transpose(%2, perm=(0, 3, 1, 2))
            %4 = conv(%3)
            %5 = transpose(%4, perm=(0, 2, 3, 1))
            %6 = squeeze(%5, axes=-3) or squeeze(%5, axes=1)
            ...
        Result:
            %3 = transpose(%1, perm=(0, 2, 1))
            %4 = conv(%3)
            %6 = transpose(%4, perm=(0, 2, 1))
            ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_batchnorm[source]
Fuse the following batch_norm layer into conv and conv_transpose. That is, convert conv + batch_norm to conv, by modifying the weight and bias in the conv layer.

    Given:
        %2 = conv(%1)
        ...
        %3 = batch_norm(%2)
        ...

    Result:
        %3 = conv(%1)
        ...
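A sketch of the conv + batch_norm pattern (registry key "common::fuse_conv_batchnorm" assumed); after the pass, the batch_norm statistics are folded into the conv weight and bias:

    import numpy as np
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    C_IN, C_OUT = 3, 8

    @mb.program(input_specs=[mb.TensorSpec(shape=(1, C_IN, 16, 16))])
    def prog(x):
        y = mb.conv(x=x, weight=np.random.rand(C_OUT, C_IN, 3, 3).astype(np.float32))
        return mb.batch_norm(
            x=y,
            mean=np.random.rand(C_OUT).astype(np.float32),
            variance=np.random.rand(C_OUT).astype(np.float32),
            gamma=np.random.rand(C_OUT).astype(np.float32),
            beta=np.random.rand(C_OUT).astype(np.float32),
            epsilon=1e-5,
        )

    PASS_REGISTRY["common::fuse_conv_batchnorm"](prog)
    print(prog)  # only a conv op (now carrying a bias) remains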
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_bias[source]
Fold add/sub into the bias of conv and conv_transpose. That is, convert conv + add/sub to conv, when add/sub is adding a constant.

Two patterns are supported:

    Pattern 1:
        Given:
            %2 = conv(%1)
            ...
            %3 = add(%2, constant)  # where constant has shape (1,C,1)/(C,1) for 1d conv,
                                    # (1,C,1,1)/(C,1,1) for 2d conv, etc.
            ...
        Result:
            %3 = conv(%1)
            ...

    Pattern 2:
        Given:
            %2 = conv(%1)
            %3 = transpose(%2)
            ...
            %4 = add(%3, constant)  # where constant has a broadcastable shape
            ...
        Result:
            %2 = conv(%1)
            %4 = transpose(%2)
            ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_scale[source]
Fold mul/div into conv/conv_transpose by updating the weight/bias of the convolution layers.

The scale const can be a single number (scalar) or a vector with a broadcastable shape. For example, if the output of the conv/deconv layer is (B, Cout, H, W), a const of shape (Cout, 1, 1) or (1, Cout, 1, 1) is allowed.

    Given:
        %2 = conv(%1)
        ...
        %3 = mul(%2, constant)  # where constant is the scale constant
        ...
    Result:
        %3 = conv(%1)
        ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_pad_conv[source]
When we observe pad -> transpose -> conv, we move the pad to be next to conv. This allows us to meld pad + conv if possible.

    Given:
        %1 = pad(%0, ...)
        %2 = transpose(%1, ...)
        %3 = conv(%2, ...)
        ...
    Result:
        %1.a = transpose(%0, ...)
        %2.a = pad(%1.a, ...)
        %3 = conv(%2.a)
        ...
optimize_elementwise_binary
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.divide_to_multiply[source]
Convert divide into multiply if the divisor is const.
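For example (sketch; the registry key "common::divide_to_multiply" is assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        return mb.real_div(x=x, y=4.0)   # division by a compile-time constant

    PASS_REGISTRY["common::divide_to_multiply"](prog)
    print(prog)  # rewritten as mul(x, 0.25)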
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.fuse_elementwise_to_batchnorm[source]
Fold mul + add into a batchnorm if the const feeding into the mul/add is of shape (1,C,1,1) or (C,1,1) and the input to the mul is of rank 4.

    Given:
              [Const]   [Const]
                 |         |
                 V         V
        [...] --> [Mul] --> [Add] --> [...]

        That is,
            %2 = op1(%1)
            %3 = mul(%2, constant)
            %4 = add(%3, constant)
            %5 = op2(%4)
            ...

    Result:
        [...] --> [BatchNorm] --> [...]

        That is,
            %2 = op1(%1)
            %4 = batchnorm(%2)
            %5 = op2(%4)
            ...
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.rank0_expand_dims_swap[source]
Identify the pattern of a rank-0 binary elementwise operation followed by an expand_dims op. In the MIL backend, the output of the elementwise op becomes rank 1. Hence, an expand_dims op should be added after both of the rank-0 tensors, and the final expand_dims should be removed. If the output var of the binary elementwise op is consumed by more than one op, a squeeze op is inserted.

Input

    [...](rank-0) --> sub --> expand_dims (axes=[0]) --> [...]
                       ^   |
                       |   |--> op2
                       |   |
                       |   |--> op3
                       |
                 [scalar const]
Output
    [...](rank-0) --> expand_dims (axes=[0]) --> sub --> [...]
                                                  ^   |
                                                  |   |--> squeeze ---> op2
                                                  |   |
                                                  |   |--> op3
                                                  |
                                      expand_dims (axes=[0])
                                                  ^
                                                  |
                                                  |
                                          [scalar const]
optimize_linear
- class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_linear_bias[source]
Convert linear + add/sub to a single linear by updating the weight and bias of the linear layer.

    Example 1:
        Original:
            %4 = linear(x=%1, weight=%2, bias=%3)  # %2 is a rank-2 const tensor (weight)
                                                   # %3 is a rank-1 const tensor (bias)
            ...
            %6 = add(x=%4, y=%5)  # %5 is a const tensor with the same shape as %3

        Result:
            %8 = linear(x=%1, weight=%2, bias=%7)  # where %7 is a new const tensor with value
                                                   # %7 = %3 + %5

    Example 2:
        Original:
            %4 = linear(x=%1, weight=%2, bias=%3)  # %2 is a rank-2 const tensor (weight)
                                                   # %3 is a rank-1 const tensor (bias)
            ...
            %6 = sub(x=%5, y=%4)  # %5 is a const tensor with a shape broadcastable to %3,
                                  # i.e. if %3 has shape (Dout), %5 could be (1, Dout).

        Result:
            %9 = linear(x=%1, weight=%7, bias=%8)  # where %7 is a new const tensor with value %7 = -%2
                                                   # and %8 = %5 - %3
- class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_matmul_weight_bias[source]
Convert matmul + add/sub to linear whenever possible.

    Given:
        %3 = matmul(x=%1, y=%2)  # %1 or %2 is const and rank 2 (weight)
        ...
        %5 = add(x=%3, y=%4)  # %4 is const. add(x=%4, y=%3) is equivalent.
                              # sub is similar.

    Result:
        # Assuming %2 above is const and rank 2
        %5 = linear(x=%1, weight=%2, bias=%4)
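A sketch of the matmul + add pattern (registry key "common::fuse_matmul_weight_bias" assumed):

    import numpy as np
    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        weight = np.random.rand(4, 3).astype(np.float32)   # const, rank 2
        bias = np.random.rand(3).astype(np.float32)        # const
        y = mb.matmul(x=x, y=weight)
        return mb.add(x=y, y=bias)

    PASS_REGISTRY["common::fuse_matmul_weight_bias"](prog)
    print(prog)  # matmul + add is replaced by a single linear op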
optimize_normalization
- class coremltools.converters.mil.mil.passes.defs.optimize_normalization.fuse_layernorm_or_instancenorm[source]
A graph optimization pass on PyMIL to detect and fuse several variants of layer_norm or instance_norm. Pattern 1 corresponds to either layer_norm or instance_norm. Patterns 2-4 are instance_norm.

Pattern 1

Identify the pattern:

y = gamma * (x - mean) / sqrt(variance + epsilon) + beta

y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])
    x --> reduce_mean --> sub --> square --> reduce_mean --> add(epsilon) --> rsqrt
    |       |              ^                                                    |
    |       |              |                                                    V
    |-----------------------                                               mul (gamma)
    |       |                                                                   |
    |       |                                                           --------|---------
    |       |                                                           |                |
    |       |                                                           |                V
    |       |----------------------------------------------------------------------->  mul
    |                                                                   |                |
    |                                                                   V                |
    |-----------------------------------------------------------------> mul             |
                                                                         |               V
                                                                         |      sub (beta) --> add --> [...]
                                                                         |                      ^
                                                                         |----------------------|
This pattern corresponds to either layer_norm or instance_norm.

- It is instance_norm if all of the following are true:
  - input is rank 4.
  - axes of reduce_mean is [-2, -1] or [-3, -2] (when [-3, -2], a channel first to channel last transpose would be inserted).
  - gamma and beta are rank 1, after squeeze.

- It is layer_norm if all of the following are true:
  - axes is either [-1], [-1, -2], or [-1, -2, -3], and so on.
  - rank of gamma and beta is equal to the length of the axes.
Pattern 2
Identify the pattern:

y = (x - mean) / pow(variance + epsilon) * gamma + beta

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
- input is a rank 4 tensor.
- reduce operates on the spatial dimensions axes=[-2, -1] or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).
- gamma and beta are both of shape (C,) after squeeze, where C is the number of channels.
    |----> sub -----|                           const (0.5)
    |       ^       |                               |
    |       |       V                               V
    x ---> mean   square --> mean1 --> add_eps ---> pow     const_gamma   const_beta
    |       |                                        |           |             |
    |       |                                        V           V             V
    |----> sub1 --------------------------------> real_div --> mul_gamma --> add_beta --> ...
Pattern 3
Detect the InstanceNorm pattern in TensorFlow-Addons.

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
- input is a rank 4 tensor.
- reduce operates on the spatial dimensions axes=[-2, -1] or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).
- gamma and beta are absent. Default values for gamma and beta would be used.
          |--------------------------------------------------|
          |                                                   |
          |                                                   V
    x --> mean    square --> mean1 --> add_eps --> rsqrt --> mul2 --> mul_sub
    |      |         ^                               |                    |
    |      |         |                               V                    |
    |      |--> sub--|                               |                    |
    |      |                                         |                    V
    |------------------------------------------> mul1 --------------> add --> ...
Pattern 4
Identify the pattern:

y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
- input is a rank 4 tensor.
- reduce operates on the spatial dimensions axes=[-2, -1] or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).
- gamma and beta are both of shape (C,) after squeeze, where C is the number of channels.
      |-----------|
      |           V
      |------> mul_square1 -----> sum1 -----> mul_mean1
      |                                           |
      |                                           V
    x --> sum --> mul_mean ==> mul_square --> sub_variance --> add_eps --> rsqrt
    |                |                                                       |
    |                |                                                       V
    |                |                                                   mul_gamma
    |                |                                                       |
    |                |                                      |----------------|
    |                |                                      |                |
    |                |                                      V                |
    |                |--------------------------------------+-------------> mul2
    |                                                       |                |
    |                                                       V                |
    |------------------------------------------------------> mul1            |
                                                             |                V
                                                             |      sub_beta --> add --> [...]
                                                             |                    ^
                                                             |--------------------|
optimize_repeat_ops
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.cast_optimization[source]
This optimization pass performs the following:
- Removes redundant cast ops; that is, cast ops where the source and destination tensors have the same dtype.
- Either cancels or fuses any two consecutive cast ops, repeatedly.

After this pass, there can't be any consecutive cast ops present in the program. For examples, see TestCastOptimization. This is a non-algebraic translation which assumes that the upcasting doesn't change the user's intent. For example:

    Input graph:
        input -----> cast(dtype="fp16") -----> cast(dtype="fp32") ----> square ---> out

    Output graph:
        input -----> square -----> out
The input graph has a maximum precision of fp16 while the output graph has fp32 precision.
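The example above in Builder form (sketch; the registry key "common::cast_optimization" is assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        y = mb.cast(x=x, dtype="fp16")
        z = mb.cast(x=y, dtype="fp32")   # cancels the previous cast
        return mb.square(x=z)

    PASS_REGISTRY["common::cast_optimization"](prog)
    print(prog)  # only the square op remains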
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_paddings[source]
Identify two consecutive pad layers which could be merged into a single pad layer.

This is possible only if one of the following conditions is satisfied:
- The paddings are "constant" and have the same constant_val.
- The paddings act along different axes.

    Input graph:
        input(1, 2, 6, 8) ------> pad([1, 1], mode='reflect') -----> pad([1, 1, 0, 0], mode='reflect') ---> out(1, 2, 8, 10)

    Output graph:
        input(1, 2, 6, 8) ------> pad([1, 1, 1, 1], mode='reflect') ---> out(1, 2, 8, 10)
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_relus[source]
Identify consecutive relu layers which could be merged into a single relu layer.

    Input graph:
        input ------> relu -----> 1 or more relu layers ---> out

    Output graph:
        input ------> relu ---> out
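Sketch (registry key "common::merge_consecutive_relus" assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
    def prog(x):
        y = mb.relu(x=x)
        y = mb.relu(x=y)     # redundant: relu is idempotent
        return mb.relu(x=y)  # redundant

    PASS_REGISTRY["common::merge_consecutive_relus"](prog)
    print(prog)  # a single relu remains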
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_reshapes[source]
Identify consecutive reshape ops which could be merged into a single reshape.

    Input graph:
        input -> reshape -> 1 or more reshapes -> output

    Output graph:
        input -> reshape -> output
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_transposes[source]
Identify consecutive ‘transpose’ layers which could be merged into a single ‘transpose’ layer.
    Input graph:
        input ------> transpose -----> 1 or more transpose layers ---> out

    Output graph:
        input ------> transpose ---> out
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.reduce_transposes[source]
Reduce transposes when it is applicable. For example:
    # Example 1
    Input graph:
        input -----> transpose(axis=[1,0]) -----> transpose(axis=[1,0]) ---> out
    Output graph:
        input -----> identity -----> out

    # Example 2
    Input graph:
        input ---> transpose(axis=[0,3,1,2]) ---> relu ---> transpose(axis=[0,2,3,1]) ---> out
    Output graph:
        input ---> relu ---> out

    # Example 3
    Input graph:
        input(shape=10,2,3,5) ---> transpose(axis=[0,2,3,1]) ---> relu ---> pool ---> out1
                                             |
                                             |---> relu ---> log ---> transpose(axis=[0,3,1,2]) ---> out2
    Output graph:
        input(shape=10,2,3,5) ---> relu ---> transpose(axis=[0,2,3,1]) ---> pool ---> out1
              |
              |---> relu ---> log ---> out2
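Example 2 above in Builder form (sketch; the registry key "common::reduce_transposes" is assumed):

    from coremltools.converters.mil import Builder as mb
    from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY

    @mb.program(input_specs=[mb.TensorSpec(shape=(1, 2, 3, 4))])
    def prog(x):
        t = mb.transpose(x=x, perm=[0, 3, 1, 2])
        r = mb.relu(x=t)                             # transposes pass through unary ops
        return mb.transpose(x=r, perm=[0, 2, 3, 1])  # inverse permutation: cancels the first

    PASS_REGISTRY["common::reduce_transposes"](prog)
    print(prog)  # only the relu remains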
Please see TransposeOptimizationPass for more details.

Notes
This pass is divided into 3 phases:
1st phase: Information gathering.
Plug in Identity ops for all output nodes. This allows us to treat all ops uniformly during traversal.
Block is traversed in the topological order, starting from the ops connected to the inputs.
During the traversal, a value is associated with every var in the block. This value can be either of type _HypotheticalValue or _LazyTransposeHypotheticalValue. The main purpose of type _HypotheticalValue is to indicate that it is not of type _LazyTransposeHypotheticalValue.

_LazyTransposeHypotheticalValue represents either one or multiple transpose ops with the same perm value. This information is stored in this class. It also wraps a _HypotheticalValue that was the last hypothetical value which was generated prior to the origin of the _LazyTransposeHypotheticalValue.

Each op decides which type of hypothetical value to associate with its output vars, based on its op type, attributes, and the types of the hypothetical values of its input vars.
Ops are classified into 4 categories: unary like, axis update, transpose, and materialize (for all the rest).
- Transpose ops are the ops from which a _LazyTransposeHypotheticalValue originates.
  - If the input to it is a _HypotheticalValue, its output will be a _LazyTransposeHypotheticalValue, indicating that this transpose op is available to get cancelled downstream.
  - If the input to it is a _LazyTransposeHypotheticalValue, then it is checked whether this op cancels it or not. If the op cancels it, a _HypotheticalValue value is generated at the output and the information about this transpose cancellation is recorded in the dictionary transpose_op_to_cancel_ops. If the op does not cancel, the current transpose op is categorized as a materialize op. Therefore, the information in the dictionary transpose_op_to_materialize_ops is updated accordingly. The output of the op is now mapped to a _HypotheticalValue.
Unary like ops: These simply transfer their input hypothetical value type to the output.
Axis update ops: If a transpose can pass through them, they are treated like a unary op and the dictionary transpose_op_to_axis_update_ops is updated. If the op cannot be updated in any manner to allow a transpose to pass through, this op is then categorized as a materialize op and handled accordingly.

Materialize ops: All _LazyTransposeHypotheticalValue input vars, if present, materialize here. The output of this op is always of type _HypotheticalValue. If the input is a _LazyTransposeHypotheticalValue, update the dictionary transpose_op_to_materialize_ops.

To treat an op like a unary op, add its type to _UNARY_LIKE_OP_TYPES. In future changes we want to make this process automatic by detecting an op as unary like by its "traits".

To treat an op like an axis update op, add a class specific to the op implementing the class TransformAxisUpdateOps. For examples, see the classes _TransformConcat, _TransformPad, and so on. The dictionary AXIS_UPDATE_OPS is automatically filled in by the decorator _TransposeOptimization.register_axis_update_op.
2nd phase: Determining which transpose ops to remove from the graph.
ops to remove from the graph.All
transpose
ops that have a corresponding compliment op in dicttranspose_op_to_cancel_ops
is a candidate. However, you need to ensure the following:If a
transpose
op is removed, then all of itscancel
ops intranspose_op_to_cancel_ops
must also be removed, to ensure correctness of the graph. The same is true in the reverse direction as well; that is, for everycancel
op that is removed, all its parenttranspose
ops upstream must also be removed.transpose
ops should be removed only if the number ofcancel
ops is greater than the number oftranspose
ops that would get freshly introduced to the block as a result of materialization ops. Currently in the algorithm, each materialization op/output var (dictstranspose_op_to_materialize_ops
/old_output_vars
) results in one moretranspose
op, although this can be further optimized in the future.
To resolve this, we recognize that the nodes consisting of sets (a) and (b) form a bipartite graph, where:

- (a) == the starting transpose ops (originators of _LazyTransposeHypotheticalValue), and
- (b) == the set of transpose cancel ops and materialize ops.

In this bipartite graph, we find all the connected components. For each connected component, either the entire set of transpose ops in it is removed/materialized, or none of them is touched.

Thus for each set, a determination is made based on counting the number of cancel ops and materialize ops.

Based on this determination, the final set of transpose ops to be removed is updated.
3rd phase: Transforming the graph.
- The starting transpose ops and the cancel ops are removed.
- Axis update ops, affected by these transpose ops, are updated.
- Transposes are materialized; that is, added just before the materialize ops, which are linked to the starting transpose ops. The starting transpose op can be materialized (inserted) multiple times, before each of the materialize ops downstream.
- Block outputs are handled in a similar fashion as the materialize ops.
- Type inference on all ops is invoked after all the transformations.
- All Identity ops that are plugged into the graph to treat outputs as materialized are removed.
Debugging

If the debug flag is set to True, the block before and after the transformation is plotted, with transpose nodes highlighted.
optimize_tensor_operation
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.concat_to_pixel_shuffle[source]
Identify nested, interleaved concat ops which can be replaced by a single concat and a pixel shuffle layer.

This pattern occurs with the faster up-convolution from the FCRN model (Laina et al., 2016).
    # Before the concat_to_pixel_shuffle pass.
    input(N, C, H, W) -------------------
                                        |
                                        v
    input(N, C, H, W) -----> concat(axis=2, interleave=True) -----> concat(axis=3, interleave=True) ----> output
                                                                                    ^
                                                                                    |
    input(N, C, H, W) -----> concat(axis=2, interleave=True) ------------------------
                                        ^
                                        |
    input(N, C, H, W) -------------------

    # After the concat_to_pixel_shuffle pass.
    input(N, C, H, W) ---------------
                                    |
                                    v
    input(N, C, H, W) -----> concat(axis=1, interleave=True) -----> pixel_shuffle(upscale_factor=2) ----> output
                                    ^
                                    |
    input(N, C, H, W) --------------|
                                    |
                                    |
    input(N, C, H, W) ---------------
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.detect_concat_interleave[source]
Detect the pattern concat --> reshape ---> transpose ---> reshape, where concat is along the channel axis (axis=-3), and map this pattern to the concat with interleave op.

This pattern occurs, for example, in the shufflenet model in torchvision.

    Given:
        %3 = concat(%1.a, %1.b, ..., axis=-3, interleave=False)  # shape = (B, n*C, H, W)
        %4 = reshape(%3)                                         # shape = (B, n, C, H, W)
        %5 = transpose(%4, perm=[0, 2, 1, 3, 4])                 # shape = (B, C, n, H, W)
        %6 = reshape(%5)                                         # shape = (B, C*n, H, W)

    Result:
        %6 = concat(%1.a, %1.b, ..., axis=-3, interleave=True)
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.fuse_onehot_matmul_to_gather[source]
Detect if onehot (axis=-1, on_value=1, off_value=0) is followed by a matmul op (no bias). If so, they can be replaced by a gather op.

    Input:
        %2 = one_hot(%1, on_value=1, off_value=0, axis=-1)
        %3 = const()  # rank 2
        %4 = matmul(%2, %3)

    Output:
        %4 = gather(%3, %2, axis=0)
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.replace_stack_reshape[source]
A stack followed by a reshape layer can be replaced by a concat if the reshape simply removes the new axis and doubles the size of one of the axes next to it.

If the new axis is reshaped to the "right" (that is, the axis just after it is doubled), then we can use a concat. If it is reshaped to the "left" (the axis just before it is doubled), then the concat needs to set the interleaved flag.

Examples:
    Given:
        %1 = tensor(1, 5, 3, 4)
        %2 = tensor(1, 5, 3, 4)
        %3 = stack((%1, %2), axis=2)  # shape = (1, 5, 2, 3, 4)
        %4 = reshape(%3, shape=[1, 10, 3, 4])

    Result:
        %1 = tensor(1, 5, 3, 4)
        %2 = tensor(1, 5, 3, 4)
        %4 = concat((%1, %2), axis=1, interleave=True)  # shape = (1, 10, 3, 4)

    Given:
        %1 = tensor(1, 5, 3, 4)
        %2 = tensor(1, 5, 3, 4)
        %3 = stack((%1, %2), axis=1)  # shape = (1, 2, 5, 3, 4)
        %4 = reshape(%3, shape=[1, 10, 3, 4])

    Result:
        %1 = tensor(1, 5, 3, 4)
        %2 = tensor(1, 5, 3, 4)
        %4 = concat((%1, %2), axis=1)  # shape = (1, 10, 3, 4)
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.use_reflection_padding[source]
Identify a reflection padding layer composed out of slices and concats.
    Input graph:

    -----------------------------------------------------------------------------------------------------
    |                                                                                                    |
    |                                                                                                    v
    input(1, 2, 6, 8) ------> slice_by_index(begin=[0, 0, 0, 1], end=[0, 0, 0, 2]) -----> concat(axis=3) ---> out(1, 2, 6, 10)
    |                                                                                                    ^
    ----------------> slice_by_index(begin=[0, 0, 0, -2], end=[0, 0, 0, -1]) ----------------------------|

    Output graph:

    input(1, 2, 6, 8) -----> pad(mode=reflect, size=[0, 0, 1, 1]) -----> out(1, 2, 6, 10)
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.expand_high_rank_reshape_and_transpose[source]
Detect the pattern reshape_1 --> transpose --> reshape_2, where reshape_1 outputs a tensor with rank >= 6 and reshape_2 produces a tensor with rank <= 5.

In general, we can expand this pattern into a sequence of rank-4 reshape and transpose ops, which is supported by the Core ML runtime.

    Given:
        %1 = reshape(%x, shape=(d1, d2, d3, d4, ..., dn))
        %2 = transpose(%1, perm=(p1, p2, ..., pn))
        %3 = reshape(%2, shape=(o1, o2, o3, o4, o5))

    Result:
        %t1 = reshape(%x, shape=(y11, y12, y13, y14))
        %h1 = transpose(%t1, perm=[0, 2, 1, 3])
        %t2 = reshape(%h1, shape=(y21, y22, y23, y24))
        %h2 = transpose(%t2, perm=[0, 2, 1, 3])
        ....
        %hn = transpose(%tn, perm=[0, 2, 1, 3])
        %3 = reshape(%hn, shape=(o1, o2, o3, o4, o5))
preprocess
- class coremltools.converters.mil.mil.passes.defs.preprocess.image_input_preprocess[source]
Plug in a transpose to convert image inputs in NHWC format to NCHW format.

Follow these steps:
Check whether there are any inputs that the users specify as ImageType.
Check the channel’s dimension for all inputs that are ImageType.
- channel_first == True: We do not modify this input, since channel_first is the intended behaviour for feeding images for optimal performance.
- channel_first == False: We convert the input into a "channel_first" input, and plug in a transpose for the input to maintain the remaining graph's dimensionality.
- class coremltools.converters.mil.mil.passes.defs.preprocess.sanitize_input_output_names[source]
Sanitize the names of model input and output vars to make sure that they are of the format as described in the NameSanitizer class; that is, of the format [a-zA-Z_][a-zA-Z0-9_]*.
- class coremltools.converters.mil.mil.passes.defs.preprocess.update_output_dtypes[source]
Update the dtypes of the output vars of the main block to match the dtypes provided in prog.main_output_types, which in turn is populated by the outputs argument provided by the user in the coremltools.convert() API. This graph pass assumes that the list of outputs in prog.main_output_types (if not None) is in the same order as the output vars.
quantization
- class coremltools.converters.mil.mil.passes.defs.quantization.add_fp16_cast(op_selector=None)[source]
For each input of dtype float32, inject a cast op to change it to float16 dtype.

For each output of dtype float16, inject a cast op to change it back to float32.

This pass is the registered interface for FP16ComputePrecision, which makes it consistent with other passes' interfaces.
Support options:
- skip_ops_by_type: Skip op types specified by a comma-separated string; for example, "mul,const".
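A sketch of setting the skip_ops_by_type option through the converter's pass pipeline. The ct.PassPipeline API, its set_options signature, and the registry key "common::add_fp16_cast" are assumptions; check them against your coremltools version:

    import coremltools as ct

    pipeline = ct.PassPipeline()
    # Leave mul and const ops in float32 when fp16 casts are injected.
    pipeline.set_options("common::add_fp16_cast", {"skip_ops_by_type": "mul,const"})

    # The pipeline would then be handed to the converter, e.g.:
    # mlmodel = ct.convert(source_model, pass_pipeline=pipeline)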