MIL Graph Passes

Graph Passes supported by the Model Intermediate Language (MIL):

cleanup

class coremltools.converters.mil.mil.passes.defs.cleanup.const_elimination

Replace non-const ops that have const Var outputs: each output that is known to be constant is replaced with a const op. Example:

Given:
    %2, %3 = non_const_op(...)  # %2 is const, %3 isn't const
    %4 = other_op(%2, %3)

Result:
    _, %3 = non_const_op(...)  # _ is the ignored output
    %2_const = const()         # %2_const name is for illustration only
    %4 = other_op(%2_const, %3)

Supported options:

  • skip_const_by_size: Skip folding const ops that have more elements than a threshold.

class coremltools.converters.mil.mil.passes.defs.cleanup.dead_code_elimination

Eliminate unused ops in the program. Ops whose outputs do not contribute to the final outputs are deleted.

# Before dead_code_elimination pass.
main(%x: (2, 4, fp32)) {
  block0() {
    %const_2: (4, 2, fp32)* = const(val=[...])
    %const_3: (4, fp32)* = const(val=[...])
    %tx_0: (bool)* = const(val=False)
    %ty_0: (bool)* = const(val=False)
    %matmul_0: (2, 2, fp32) = matmul(x=%x, y=%const_2, transpose_x=%tx_0, transpose_y=%ty_0)
    %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
  } -> (%linear_0)
}

# After dead_code_elimination pass.
main(%x: (2, 4, fp32)) {
  block0() {
    %const_2: (4, 2, fp32)* = const(val=[...])
    %const_3: (4, fp32)* = const(val=[...])
    %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
  } -> (%linear_0)
}

In the example above, %matmul_0 is an op that is not used in the computation. This op and its input ops (%tx_0 and %ty_0) are eliminated in this pass.
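
The elimination can be viewed as a backward reachability walk from the block outputs. The sketch below is a minimal illustration on a hypothetical dict-based graph (each op's output name mapped to the names it consumes); it is not the coremltools data structure or implementation.

```python
def dead_code_elimination(ops, outputs):
    """Return the names of ops that contribute to `outputs`, in original order.

    `ops` maps each op's output name to the list of names it consumes.
    Names that are not keys (such as function inputs) are treated as leaves.
    """
    live = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in ops:
            continue
        live.add(name)
        stack.extend(ops[name])  # everything a live op reads is live too
    # Keep the original (topological) order of the surviving ops.
    return [name for name in ops if name in live]


# The program from the example above; "x" is a function input.
program = {
    "const_2": [],
    "const_3": [],
    "tx_0": [],
    "ty_0": [],
    "matmul_0": ["x", "const_2", "tx_0", "ty_0"],
    "linear_0": ["x", "const_2", "const_3"],
}

print(dead_code_elimination(program, outputs=["linear_0"]))
# ['const_2', 'const_3', 'linear_0']
```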

class coremltools.converters.mil.mil.passes.defs.cleanup.dedup_op_and_var_names

For each function, this pass renames ops and variables that have the same name as any preceding ops/variables across all scopes in the given function, where the precedence is implementation-specific. Note that op names and variable names are tracked separately, so an op may have the same name as a variable.

The pass preserves input and output names. It raises ValueError if it cannot dedup without changing the input/output var names.

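from coremltools.converters.mil import Builder as mb

@mb.program(input_specs=[mb.TensorSpec(shape=(2, 4))])
def prog(x):
    x = mb.cast(x=x, dtype="fp16", name="castop")
    x = mb.cast(x=x, dtype="fp32", name="castop")
    x = mb.square(x=x, name="square_last")
    return x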

# Before dedup pass, the op names are ["castop", "castop", "square_last"].
# After dedup pass, the op names are ["castop", "castop_1", "square_last"].
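
The renaming in the comments above can be reproduced with a simple suffixing rule. This is a hypothetical scheme chosen to match the example (later duplicates get `_1`, `_2`, ... while skipping names already in use); the actual precedence used by the pass is implementation-specific.

```python
def dedup_names(names):
    """Rename later duplicates by appending _1, _2, ... (hypothetical scheme)."""
    taken = set(names)  # never generate a name that already exists
    seen = set()
    result = []
    for name in names:
        if name not in seen:
            seen.add(name)
            result.append(name)
            continue
        i = 1
        while f"{name}_{i}" in taken:
            i += 1
        new_name = f"{name}_{i}"
        taken.add(new_name)
        seen.add(new_name)
        result.append(new_name)
    return result


print(dedup_names(["castop", "castop", "square_last"]))
# ['castop', 'castop_1', 'square_last']
```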

class coremltools.converters.mil.mil.passes.defs.cleanup.fuse_reduce_mean

Detect the reduce_sum --> mul/real_div pattern that can be mapped to reduce_mean. That is, the operation reduce_sum/count == reduce_mean.

Input graph

                            const (scalar)
                                |
input ----> reduce_sum ----> mul/real_div -----------> output

Output graph

input --------> reduce_mean ---------> output
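
The fusion is only valid when the scalar const matches the number of reduced elements. A minimal sketch of that condition, using a hypothetical helper `is_reduce_mean` on plain Python lists (the tolerance is an assumption):

```python
def reduce_mean(values):
    return sum(values) / len(values)


def is_reduce_mean(op_type, scalar, count, tol=1e-6):
    """Return True if reduce_sum followed by op_type(scalar) computes a mean.

    `count` is the number of reduced elements (product of the reduced dims).
    real_div must divide by count; mul must multiply by 1 / count.
    """
    if op_type == "real_div":
        return abs(scalar - count) <= tol
    if op_type == "mul":
        return abs(scalar - 1.0 / count) <= tol
    return False


values = [1.0, 2.0, 3.0, 6.0]
print(sum(values) * 0.25, reduce_mean(values))  # 3.0 3.0
print(is_reduce_mean("mul", 0.25, count=len(values)))  # True
print(is_reduce_mean("real_div", 2.0, count=len(values)))  # False
```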

class coremltools.converters.mil.mil.passes.defs.cleanup.loop_invariant_elimination

When a block does not modify a block input var, eliminate that block input var and use the corresponding var in the outer scope. Example:

# Before loop_invariant_elimination pass.
# Notice that %b.x is constant across while_loop iterations.
main(%a: (1, 2, fp32),
     %b: (1, 2, fp32)) {
  block0() {
    %loop:0: (1, 2, fp32), %loop:1: (1, 2, fp32) = while_loop(loop_vars=(%a, %b))
      loop_cond(%a.x, %b.x) {
        %cond_var: (bool) = some_op(x=%a.x, y=%b.x)
      } -> (%cond_var)
      loop_body(%a.x, %b.x) {
        %add_0: (1, 2, fp32) = add(x=%a.x, y=%b.x)
      } -> (%add_0, %b.x)
  } -> (%loop:0, %loop:1)
}

# After loop_invariant_elimination pass.
main(%a: (1, 2, fp32),
     %b: (1, 2, fp32)) {
  block0() {
    %loop:1: (1, 2, fp32) = identity(x=%b)
    %loop:0: (1, 2, fp32) = while_loop(loop_vars=(%a))
      loop_cond(%a.x) {
        %cond_var: (bool) = some_op(x=%a.x, y=%b)
      } -> (%cond_var)
      loop_body(%a.x) {
        %add_0: (1, 2, fp32) = add(x=%a.x, y=%b)
      } -> (%add_0)
  } -> (%loop:0, %loop:1)
}

Here the loop invariant %b.x is eliminated from while_loop, which now returns 1 output instead of 2. The return var names are preserved with identity.

class coremltools.converters.mil.mil.passes.defs.cleanup.noop_elimination

Remove ops that have no effect.

Given:
    %1 (1, 96, 128, 64, fp32) = ...
    %2 (1, 96, 128, 64, fp32) = reshape(%1)
    ...
    %3 (1, 96, 128, 64, fp32) = add(%2, constant)
    ...

Result:
    %1 (1, 96, 128, 64, fp32) = ...
    %3 (1, 96, 128, 64, fp32) = add(%1, constant)
    ...

class coremltools.converters.mil.mil.passes.defs.cleanup.remove_redundant_ops

If there are multiple ops with “identical” inputs, then they are redundant and all but one of them can be removed. This pass checks and removes such ops.

Since all inputs to ops in MIL are named, two ops with the same op_type can be compared by comparing their correspondingly named inputs. Inputs are treated as identical if one of the following is true:

  • The input is a constant var, in which case its value should have the same dtype and numerical value.

  • The input is a non-constant var, in which case it should be the same var object.

This pass iterates over the ops, takes each op's first output var, and builds a candidate op list from the child ops of this var. Each candidate list groups child ops of the same op_type, arranged in topological order. Within a candidate list, the second, third, and subsequent ops are compared pairwise with the first op and removed if identical to it. For example:

Input:
    %0 = op0(...)
    %1 = op1(...)
    %2 = const(val=4.5)
    %3 = const(val=4.5)
    %4 = op2(%1, %0, %2)
    %5 = op2(%1, %0, %3)

Output:
    %0 = op0(...)
    %1 = op1(...)
    %2 = const(val=4.5)
    %3 = const(val=4.5) # this will get removed later by dead code elimination pass
    %4 = op2(%1, %0, %2)

In the example above, the second op2 (which produces %5) is removed, and all uses of %5 are replaced by %4. For more examples, see “TestRemoveRedundantOpsPass”.
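
The identity test on named inputs can be sketched on a toy representation: ops as `(output, op_type, {input_name: var})` tuples and consts as `(dtype, value)` pairs. This is a hypothetical, simplified version (it compares against every earlier op rather than building per-var candidate lists) meant only to illustrate the comparison rules listed above.

```python
def same_input(a, b, consts):
    """Inputs match if they are the same var, or consts with equal (dtype, value)."""
    if a == b:
        return True
    return a in consts and b in consts and consts[a] == consts[b]


def find_redundant(ops, consts):
    """Map each redundant op's output to the earlier equivalent output.

    `ops` is a topologically ordered list of (output, op_type, {input_name: var}).
    Simplification: every earlier op is a candidate, not just the child ops
    of a shared var.
    """
    replacements = {}
    kept = []
    for out, op_type, inputs in ops:
        # Rewire inputs through earlier replacements so chains are caught.
        inputs = {n: replacements.get(v, v) for n, v in inputs.items()}
        for k_out, k_type, k_inputs in kept:
            if (k_type == op_type and k_inputs.keys() == inputs.keys()
                    and all(same_input(inputs[n], k_inputs[n], consts) for n in inputs)):
                replacements[out] = k_out
                break
        else:
            kept.append((out, op_type, inputs))
    return replacements


# Two same-type ops whose only differing input is an equal-valued const.
consts = {"%2": ("fp32", 4.5), "%3": ("fp32", 4.5)}
ops = [
    ("%4", "op2", {"x": "%1", "y": "%0", "z": "%2"}),
    ("%5", "op2", {"x": "%1", "y": "%0", "z": "%3"}),
]
print(find_redundant(ops, consts))  # {'%5': '%4'}
```

Changing the dtype of either const (for example `("fp16", 4.5)`) makes the inputs differ, and nothing is removed.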

class coremltools.converters.mil.mil.passes.defs.cleanup.remove_symbolic_reshape

Convert symbolic shape in reshape to integers.

Note: This does not perform any optimization, but simply replaces symbols with positive integers if they can be solved from the volumetric constraint, or with -1 otherwise. Therefore, this pass fails if more than one symbol needs to be resolved to -1.

# Before remove_symbolic_reshape pass.
main(%x: (s0, 4, fp32)) {
  block0() {
    %reshape_0_shape_0: (3,i32)^ = const(val=(s0, s1, 2))
    %reshape_0: (s0, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0)
  } -> (%reshape_0)
}

# After remove_symbolic_reshape pass.
main(%x: (s0, 4, fp32)) {
  block0() {
    %reshape_0_shape_0x: (3,i32)* = const(val=[-1, 2, 2])
    %reshape_0: (-1, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0x)
  } -> (%reshape_0)
}

TODO (rdar://59165842): Use expand_dims, squeeze etc to use 0 instead of dynamic reshape with -1.
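
The resolution rule in the note above can be illustrated with a toy resolver in which symbols are plain strings and only identical symbols cancel between the input shape and the target shape. This is an assumption about how such a resolution could work, not the coremltools solver; `resolve_reshape` is a hypothetical name.

```python
from math import prod


def resolve_reshape(in_shape, target):
    """Replace symbols (str) in `target` with ints or -1.

    Symbols shared with `in_shape` cancel out of the volume constraint and
    become -1; a single remaining unknown is solved from the known dims.
    Raises ValueError if more than one -1 would be needed.
    """
    leftover = [d for d in in_shape if isinstance(d, str)]
    unknown = []
    for d in target:
        if isinstance(d, str):
            if d in leftover:
                leftover.remove(d)  # symbol appears on both sides
            else:
                unknown.append(d)
    solved = {}
    if not leftover and len(unknown) == 1:
        known_in = prod(d for d in in_shape if isinstance(d, int))
        known_out = prod(d for d in target if isinstance(d, int))
        if known_out and known_in % known_out == 0:
            solved[unknown[0]] = known_in // known_out
    result = [d if isinstance(d, int) else solved.get(d, -1) for d in target]
    if result.count(-1) > 1:
        raise ValueError(f"cannot resolve {target}: more than one -1")
    return result


# The example above: x is (s0, 4) and the target shape is (s0, s1, 2).
print(resolve_reshape(["s0", 4], ["s0", "s1", 2]))  # [-1, 2, 2]
```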

class coremltools.converters.mil.mil.passes.defs.cleanup.topological_reorder

Topologically re-orders the list of operations in a program by placing each operation closer to its first use, or at the end if it’s not consumed by any other operation.

Currently, this pass re-orders only transpose and cast operations.

# Example: input program
main(x: (2, 4, fp32)) {
    x = mb.cast(x=x, dtype="fp16")
    x1 = mb.square(x=x)
    x1_t = mb.transpose(x=x1, perm=[1, 0])
    x2 = mb.cast(x=x1_t, dtype="fp32")
    x3 = mb.log(x=x)
    x3_t = mb.transpose(x=x3, perm=[1, 0])
    x4 = mb.cast(x=x3_t, dtype="fp32")
    x5 = mb.relu(x=x)
    x6 = mb.cast(x=x5, dtype="fp32")
    x7 = mb.relu(x=x6)
    x8 = mb.relu(x=x)
} -> x2, x4, x7, x8

# After moving `cast` ops, the program becomes
main(x: (2, 4, fp32)) {
    x = mb.cast(x=x, dtype="fp16")
    x1 = mb.square(x=x)
    x1_t = mb.transpose(x=x1, perm=[1, 0])
    x3 = mb.log(x=x)
    x3_t = mb.transpose(x=x3, perm=[1, 0])
    x5 = mb.relu(x=x)
    x6 = mb.cast(x=x5, dtype="fp32")
    x7 = mb.relu(x=x6)
    x8 = mb.relu(x=x)
    x4 = mb.cast(x=x3_t, dtype="fp32")
    x2 = mb.cast(x=x1_t, dtype="fp32")
} -> x2, x4, x7, x8

# After moving `transpose` ops, the program becomes
main(x: (2, 4, fp32)) {
    x = mb.cast(x=x, dtype="fp16")
    x1 = mb.square(x=x)
    x3 = mb.log(x=x)
    x5 = mb.relu(x=x)
    x6 = mb.cast(x=x5, dtype="fp32")
    x7 = mb.relu(x=x6)
    x8 = mb.relu(x=x)
    x3_t = mb.transpose(x=x3, perm=[1, 0])
    x4 = mb.cast(x=x3_t, dtype="fp32")
    x1_t = mb.transpose(x=x1, perm=[1, 0])
    x2 = mb.cast(x=x1_t, dtype="fp32")
} -> x2, x4, x7, x8

optimize_activation

class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_exact

Identify the pattern that corresponds to the exact version of gelu, and replace it with a single gelu layer with mode=EXACT. The pattern is y = 0.5 * x * (1 + erf(x / sqrt(2))), which can be represented by one of the following:

(1)
    [...] ----> div (1.414) ---> erf ---> add (1) -----> mul (0.5) ---> mul ---> [...]
      |                                                                  ^
      |                                                                  |
      |-------------------------------------------------------------------

(2)
    [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ---> mul (0.5) ---> [...]
      |                                                   ^
      |                                                   |
      |----------------------------------------------------

(3)
    [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ------> [...]
      |                                                   ^
      |                                                   |
      |---------------> mul(0.5) --------------------------

All of them are converted to:
    [...] ----> gelu (mode=EXACT) ---> [...]
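
The three patterns differ only in where the 0.5 and x factors are multiplied in, so they agree with exact gelu up to floating-point rounding. A quick numerical check in plain Python; here the divisor is the full-precision sqrt(2), of which the 1.414 in the diagrams is a rounded value:

```python
import math

SQRT2 = math.sqrt(2.0)  # shown rounded to 1.414 in the diagrams


def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / SQRT2))


def pattern_1(x):  # ... add(1) -> mul(0.5) -> mul(x)
    return ((1.0 + math.erf(x / SQRT2)) * 0.5) * x


def pattern_2(x):  # ... add(1) -> mul(x) -> mul(0.5)
    return ((1.0 + math.erf(x / SQRT2)) * x) * 0.5


def pattern_3(x):  # ... add(1) -> mul(0.5 * x)
    return (1.0 + math.erf(x / SQRT2)) * (0.5 * x)


xs = [i / 4 for i in range(-20, 21)]
worst = max(abs(p(x) - gelu_exact(x)) for p in (pattern_1, pattern_2, pattern_3) for x in xs)
print(worst)  # only rounding-level differences are expected
```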

class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_tanh_approximation

Identify the pattern that corresponds to the tanh approximate version of gelu, and replace it with a single gelu layer with mode=TANH_APPROXIMATION.

The implementation of this pass uses the generic graph pattern matching and transform algorithm implemented in coremltools.converters.mil.experimental.passes.generic_pass_infrastructure and documented in coremltools/converters/mil/experimental/passes/readme.md.

Graph for get_gelu_pattern1()

y = x * (0.5 * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1))

[...] -----> pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) ----> mul (0.5) -----> mul ---> [...]
  |                                            ^                                                                          ^
  |                                            |                                                                          |
  |------------------------------------------------------------------------------------------------------------------------

Graph for get_gelu_pattern2()

y = (0.5 * x) * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1)

               --------------------------------------------------------------------------------------------------------
               ^                                                                                                      |
               |                                                                                                      V
[...] -----> mul(0.5)    pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) -----> mul ---> [...]
  |                        ^                               ^
  |                        |                               |
  |---------------------------------------------------------
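
Both patterns compute the same tanh approximation, which stays close to the exact erf-based gelu. The sketch below evaluates them in plain Python; it prints the measured gaps rather than asserting a specific error figure.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)


def gelu_tanh_pattern1(x):
    # x * (0.5 * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1))
    return x * (0.5 * (math.tanh((0.044715 * x ** 3 + x) * SQRT_2_OVER_PI) + 1.0))


def gelu_tanh_pattern2(x):
    # (0.5 * x) * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1)
    return (0.5 * x) * (math.tanh((0.044715 * x ** 3 + x) * SQRT_2_OVER_PI) + 1.0)


def gelu_exact(x):
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


xs = [i / 10 for i in range(-60, 61)]
pattern_gap = max(abs(gelu_tanh_pattern1(x) - gelu_tanh_pattern2(x)) for x in xs)
approx_err = max(abs(gelu_tanh_pattern1(x) - gelu_exact(x)) for x in xs)
print(f"pattern1 vs pattern2: {pattern_gap:.1e}, tanh approx vs exact: {approx_err:.1e}")
```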

class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_leaky_relu

Detect the mul --> maximum pattern that can be mapped to leaky_relu.

In code form - Input

%2 = const(value = alpha) # where 0 <= alpha <= 1
%3 = mul(%1, %2) # alpha * x
%4 = maximum(%3, %1) # max(alpha * x, x)

In code form - Output

%4 = leaky_relu(x=%1, alpha=%2)

In graphical form - Input graph

         const (val = alpha)
             |
input ----> mul ---------------> maximum -----------> output
  |                                 |
  |----------------------------------

In graphical form - Output graph

input --------> leaky_relu ---------> output
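
The constraint 0 <= alpha <= 1 is what makes the rewrite exact: for x >= 0, alpha * x <= x, and for x < 0, alpha * x >= x. A scalar check in plain Python:

```python
def mul_then_maximum(x, alpha):
    # %3 = mul(x, alpha); %4 = maximum(%3, x)
    return max(alpha * x, x)


def leaky_relu(x, alpha):
    return x if x >= 0 else alpha * x


xs = [-3.0, -0.5, 0.0, 0.5, 3.0]
for alpha in (0.0, 0.1, 1.0):
    assert all(mul_then_maximum(x, alpha) == leaky_relu(x, alpha) for x in xs)

# Outside 0 <= alpha <= 1 the rewrite would be wrong, hence the constraint:
print(mul_then_maximum(3.0, 2.0), leaky_relu(3.0, 2.0))  # 6.0 3.0
```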

class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_prelu

Detect the following patterns that can be mapped to a prelu op. Essentially, the prelu op can be broken down into the following ops:

y = a * relu(-1 * x) + relu(x)

Pattern 1

               | ------------> relu --------------------|
               |                                        V
x (BCHW) ------|                                       add -----> y (BCHW)
               |                                        ^
               --------> mul -------> relu -----> mul---|
                          ^                        ^
                          |                        |
                     Const(val=-1)            Const(name=a, shape=(C,1,1) or (1,C,1,1))

This will be mapped to:

x (BCHW) ------> prelu(alpha=-a, shape=(C,)) ---------> y (BCHW)

Pattern 2

                                | ------------> relu --------------------|
                                |                                        V
x (BCHW) -->transpose(BHWC)---->|                                       add -----> y (BHWC)
                                |                                        ^
                                --------> mul -------> relu -----> mul---|
                                           ^                        ^
                                           |                        |
                                  Const(val=-1)    Const(shape=(C,) or (1,C) or (1,1,C) or (1,1,1,C))

This will be mapped to:

x (BCHW) ------> prelu ---------> transpose ------> y (BHWC)
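
For x < 0 the decomposition gives a * (-x), so the matched const a corresponds to a prelu slope of -a. A scalar check in plain Python (the real graphs apply one such const per channel):

```python
def relu(v):
    return max(v, 0.0)


def prelu_pattern(x, a):
    # y = a * relu(-1 * x) + relu(x), with a the const fed into the second mul
    return a * relu(-1.0 * x) + relu(x)


def prelu(x, alpha):
    # PReLU: x for x > 0, alpha * x otherwise
    return x if x > 0 else alpha * x


xs = [-4.0, -1.5, -0.25, 0.0, 0.25, 1.5, 4.0]
for a in (-0.25, 0.1, 0.3):
    # The matched const a corresponds to a prelu slope of -a.
    assert all(abs(prelu_pattern(x, a) - prelu(x, -a)) < 1e-12 for x in xs)
```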

class coremltools.converters.mil.mil.passes.defs.optimize_activation.prelu_to_lrelu

If prelu has the same leakage factor across all channels, it will be converted to leaky_relu.

optimize_conv

class coremltools.converters.mil.mil.passes.defs.optimize_conv.add_conv_transpose_output_shape

The output_shape input of conv_transpose is optional. Since the output shape can be inferred during type inference, this pass adds the output_shape input whenever it is known to be constant at compile time. For example:

Given:
  %1: (1, 5, 39, fp32) = conv_transpose(...) # no output_shape input.

Result:
  %2: (3, i32) = const(val=[1,5,39])
  %3: (1, 5, 39, fp32) = conv_transpose(..., output_shape=%2)

class coremltools.converters.mil.mil.passes.defs.optimize_conv.compose_conv1d

In TensorFlow, tf.keras.layers.Conv1D is a composite op:

expand a dummy dim -> Conv2D -> squeeze the dummy dim

In PyTorch, this is also true for some backends (mkldnn and xpu).

This decomposition breaks the coremltools conv1d graph passes, so the fragments should be recomposed back into a MIL conv, which natively supports conv1d:

Pattern 1:
    Given:
        %2 = expand_dims(%1, axes=-2) or expand_dims(%1, axes=2), %1.rank = 3
        %3 = conv(%2)
        %4 = squeeze(%3, axes=-2) or squeeze(%3, axes=2)
        ...

    Result:
        %4 = conv(%1)
        ...

Pattern 2 (TensorFlow channel_last):
    Given:
        %2 = expand_dims(%1, axes=-3) or expand_dims(%1, axes=1), %1.rank = 3
        %3 = transpose(%2, perm=(0, 3, 1, 2))
        %4 = conv(%3)
        %5 = transpose(%4, perm=(0, 2, 3, 1))
        %6 = squeeze(%5, axes=-3) or squeeze(%5, axes=1)
        ...

    Result:
        %3 = transpose(%1, perm=(0, 2, 1))
        %4 = conv(%3)
        %6 = transpose(%4, perm=(0, 2, 1))
        ...

class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_batchnorm

Fuse a batch_norm layer that follows a conv or conv_transpose into that layer. That is, convert conv + batch_norm to a single conv by modifying the weight and bias of the conv layer.

Given:
    %2 = conv(%1)
    ...
    %3 = batch_norm(%2)
    ...

Result:
    %3 = conv(%1)
    ...
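
The weight and bias update follows from expanding batch_norm: with scale = gamma / sqrt(var + eps) per output channel, W' = W * scale and b' = (b - mean) * scale + beta. The plain-Python sketch below checks this on a 1x1 conv at a single spatial position; it assumes the output channel is the leading weight axis (the conv layout), whereas conv_transpose stores its weight with a different axis order.

```python
import math


def fold_batch_norm(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (weight, bias) of a conv equivalent to conv followed by batch_norm.

    weight[c] is the flattened kernel of output channel c (conv layout assumed).
    """
    new_weight, new_bias = [], []
    for c in range(len(weight)):
        scale = gamma[c] / math.sqrt(var[c] + eps)
        new_weight.append([w * scale for w in weight[c]])
        new_bias.append((bias[c] - mean[c]) * scale + beta[c])
    return new_weight, new_bias


def conv_1x1(x, weight, bias):
    # A 1x1 conv at one spatial position: a dot product per output channel.
    return [sum(w * xi for w, xi in zip(weight[c], x)) + bias[c] for c in range(len(weight))]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    return [gamma[c] * (y[c] - mean[c]) / math.sqrt(var[c] + eps) + beta[c] for c in range(len(y))]


weight = [[0.5, -1.0], [2.0, 0.25]]  # 2 output channels, 2 input channels
bias = [0.1, -0.2]
gamma, beta = [1.5, 0.8], [0.0, 0.3]
mean, var = [0.2, -0.1], [0.9, 1.1]
x = [1.0, 2.0]

reference = batch_norm(conv_1x1(x, weight, bias), gamma, beta, mean, var)
fused = conv_1x1(x, *fold_batch_norm(weight, bias, gamma, beta, mean, var))
assert all(abs(r - f) < 1e-9 for r, f in zip(reference, fused))
```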

class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_bias

Fold add/sub into the bias of conv and conv_transpose. That is, convert conv + add/sub to conv when the add/sub adds a constant.

Two patterns are supported:

Pattern 1:
Given:
    %2 = conv(%1)
    ...
    %3 = add(%2, constant) # where constant has shape (1,C,1)/(C,1) for 1d conv, (1,C,1,1)/(C,1,1) for 2d conv etc
    ...

Result:
    %3 = conv(%1)
    ...

Pattern 2:
Given:
    %2 = conv(%1)
    %3 = transpose(%2)
    ...
    %4 = add(%3, constant) # where constant has a broadcastable shape
    ...

Result:
    %2 = conv(%1)
    %4 = transpose(%2)
    ...

class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_scale

Fold mul/div into conv/conv_transpose by updating the weight/bias of the convolution layers.

The scale const can be a single number (scalar) or a vector with a broadcastable shape. For example, if the output of the conv/deconv layer is (B, Cout, H, W), consts of shape (Cout, 1, 1) and (1, Cout, 1, 1) are allowed.

Given:
    %2 = conv(%1)
    ...
    %3 = mul(%2, constant) # where constant is the scale constant
    ...

Result:
    %3 = conv(%1)
    ...

class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_pad_conv

When we observe pad -> transpose -> conv, we move the pad to be next to the conv. This allows pad + conv to be melded if possible.

Given:
    %1 = pad(%0, ...)
    %2 = transpose(%1, ...)
    %3 = conv(%2, ...)
    ...

Result:
    %1.a = transpose(%0, ...)
    %2.a = pad(%1.a, ...)
    %3 = conv(%2.a)
    ...

optimize_elementwise_binary

class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.divide_to_multiply

Convert divide into multiply if the divisor is const.

class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.fuse_elementwise_to_batchnorm

Fold mul + add into a batchnorm if the consts feeding into the mul/add are of shape (1,C,1,1) or (C,1,1) and the input to mul is of rank 4.

Given:
         [Const]   [Const]
            |         |
            V         V
[...] --> [Mul] --> [Add] --> [...]

That is,

    %2 = op1(%1)
    %3 = mul(%2, constant)
    %4 = add(%3, constant)
    %5 = op2(%4)
    ...

Result:

[...] --> [BatchNorm] --> [...]

That is,
    %2 = op1(%1)
    %4 = batchnorm(%2)
    %5 = op2(%4)
    ...

class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.rank0_expand_dims_swap

Identify the pattern of a rank-0 binary elementwise operation followed by an expand_dims op. In the MIL backend, the output of the elementwise op becomes rank 1. Hence, an expand_dims op should be added after both of the rank-0 tensors, and the final expand_dims should be removed. If the output var of the binary elementwise op is consumed by more than one op, a squeeze op is inserted.

Input

[...](rank-0) --> sub --> expand_dims (axes=[0]) --> [...]
                   ^   |
                   |   |--> op2
                   |   |
                   |   |--> op3
                   |
             [scalar const]

Output

[...](rank-0) --> expand_dims (axes=[0]) --> sub --> [...]
                                              ^   |
                                              |   |--> squeeze ---> op2
                                              |                |
                                              |                |--> op3
                                              |
                                        expand_dims (axes=[0])
                                              ^
                                              |
                                              |
                                        [scalar const]

optimize_linear

class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_linear_bias

Convert linear + add/sub to a single linear by updating the weight and bias of the linear layer.

Example 1:
    Original:
        %4 = linear(x=%1, weight=%2, bias=%3) # %2 is a rank-2 const tensor (weight)
                                              # %3 is a rank-1 const tensor (bias)
        ...
        %6 = add(x=%4, y=%5) # %5 is a const tensor with same shape as %3

    Result:
        %8 = linear(x=%1, weight=%2, bias=%7) # where %7 is a new const tensor with value
                                              # %7 = %3 + %5

Example 2:
    Original:
        %4 = linear(x=%1, weight=%2, bias=%3) # %2 is a rank-2 const tensor (weight)
                                              # %3 is a rank-1 const tensor (bias)
        ...
        %6 = sub(x=%5, y=%4) # %5 is a const tensor with a shape broadcastable with %3,
                             # i.e. if %3 has shape (Dout), %5 could be (1, Dout).

    Result:
        %9 = linear(x=%1, weight=%7, bias=%8) # where %7 is a new const tensor with value %7 = -%2
                                              # %8 = %5 - %3
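
Both examples follow from linearity: (W x + b) + c = W x + (b + c), and c - (W x + b) = (-W) x + (c - b). A plain-Python check, assuming MIL's linear convention y = x · weightᵀ + bias with weight of shape (Dout, Din):

```python
def linear(x, weight, bias):
    # y = x @ weight^T + bias, with weight of shape (Dout, Din)
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


W = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]  # %2, shape (Dout=3, Din=2)
b = [0.1, 0.2, 0.3]                         # %3
c = [10.0, 20.0, 30.0]                      # %5, the const operand of add/sub
x = [2.0, -1.0]                             # %1

y = linear(x, W, b)

# Example 1: add(x=%4, y=%5)  ->  bias %3 + %5
add_fused = linear(x, W, [bi + ci for bi, ci in zip(b, c)])
# Example 2: sub(x=%5, y=%4)  ->  weight -%2, bias %5 - %3
sub_fused = linear(x, [[-w for w in row] for row in W], [ci - bi for ci, bi in zip(c, b)])

assert all(abs((yi + ci) - f) < 1e-9 for yi, ci, f in zip(y, c, add_fused))
assert all(abs((ci - yi) - f) < 1e-9 for yi, ci, f in zip(y, c, sub_fused))
```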

class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_matmul_weight_bias

Convert matmul + add/sub to linear whenever possible.

Given:
    %3 = matmul(x=%1, y=%2)  # %1 or %2 is const and rank 2 (weight)
    ...
    %5 = add(x=%3, y=%4) # %4 is const. add(x=%4, y=%3) is equivalent
                         # sub is similar.

Result:
    # assuming %2 above is const and rank 2, with shape (Din, Dout)
    %5 = linear(x=%1, weight=%2_T, bias=%4)  # %2_T: new const holding transpose(%2), shape (Dout, Din)
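
MIL's linear op takes its weight as shape (D_out, D_in) and computes x · weightᵀ + bias. Assuming that convention, a const matmul operand y of shape (D_in, D_out) has to be transposed before it can serve as the linear weight. A plain-Python check:

```python
def matmul(x, y):
    # x: (Din,), y: (Din, Dout) -> (Dout,)
    return [sum(x[i] * y[i][j] for i in range(len(x))) for j in range(len(y[0]))]


def linear(x, weight, bias):
    # y = x @ weight^T + bias, weight: (Dout, Din)
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def transpose(m):
    return [list(col) for col in zip(*m)]


Y = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # const %2, shape (Din=2, Dout=3)
bias = [0.5, -0.5, 1.0]                 # const %4
x = [1.0, -2.0]                         # %1

expected = [m + b for m, b in zip(matmul(x, Y), bias)]  # add(matmul(%1, %2), %4)
print(expected)                          # [-6.5, -8.5, -8.0]
print(linear(x, transpose(Y), bias))     # [-6.5, -8.5, -8.0]
```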

optimize_normalization

class coremltools.converters.mil.mil.passes.defs.optimize_normalization.fuse_layernorm_or_instancenorm

A graph optimization pass on PyMIL to detect and fuse several variants of layer_norm or instance_norm. Pattern 1 corresponds to either layer_norm or instance_norm. Patterns 2-4 are instance_norm.

Pattern 1

Identify the pattern:

y = gamma * (x - mean) / sqrt(variance + epsilon) + beta

y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])

x --> reduce_mean --> sub --> square --> reduce_mean --> add(epsilon) --> rsqrt
|             |        ^                                                    |
|             |        |                                                    V
|-----------------------                                              mul (gamma)
|             |                                                           |
|             |                                                   --------|---------
|             |                                                   |                |
|             |                                                   |                V
|             |---------------------------------------------------------------->  mul
|                                                                 |                |
|                                                                 V                |
|--------------------------------------------------------------> mul               |
                                                                  |                V
                                                                  |              sub (beta) --> add --> [...]
                                                                  |                              ^
                                                                  |-------------------------------

This pattern corresponds to either layer_norm or instance_norm.

It is instance_norm if all of the following are true:
  • input is rank 4.

  • axes of reduce_mean is [-2, -1] or [-3, -2] (when [-3, -2], a channel first to channel last transpose would be inserted).

  • gamma and beta are rank 1, after squeeze.

It is layer_norm if all of the following are true:
  • axes is either [-1], [-1, -2], or [-1, -2, -3], and so on.

  • rank of gamma and beta is equal to the length of the axes.
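
The two formulas at the top of Pattern 1 are algebraically the same; the second is what the graph computes. A plain-Python check for a single row normalized over its last axis (the layer_norm case with axes=[-1]), using the population variance that reduce_mean of the squared deviation produces:

```python
import math


def layer_norm(x, gamma, beta, eps=1e-5):
    # y = gamma * (x - mean) / sqrt(variance + eps) + beta
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [g * (v - mean) / math.sqrt(variance + eps) + b for v, g, b in zip(x, gamma, beta)]


def layer_norm_as_graph(x, gamma, beta, eps=1e-5):
    # y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])
    mean = sum(x) / len(x)                               # reduce_mean
    variance = sum((v - mean) ** 2 for v in x) / len(x)  # sub -> square -> reduce_mean
    rsqrt = 1.0 / math.sqrt(variance + eps)              # add(epsilon) -> rsqrt
    scale = [g * rsqrt for g in gamma]                   # mul(gamma)
    return [v * s + (b - mean * s) for v, s, b in zip(x, scale, beta)]


x = [1.0, 2.0, 4.0, 7.0]
gamma = [1.0, 0.5, 2.0, 1.5]
beta = [0.0, 0.1, -0.2, 0.3]
assert all(abs(a - b) < 1e-9 for a, b in zip(layer_norm(x, gamma, beta), layer_norm_as_graph(x, gamma, beta)))
```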

Pattern 2

Identify the pattern:

y = (x - mean) / pow(variance + epsilon, 0.5) * gamma + beta

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
  1. input is rank 4 tensor.

  2. reduce operates on spatial dimensions axes=[-2, -1], or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).

  3. gamma and beta are both shape (C,) after squeeze, where C is number of channels.

|----> sub -----|                            const (0.5)
|       ^       |                                |
|       |       V                                V
x ---> mean  square --> mean1 --> add_eps ---> pow       const_gamma   const_beta
|       |                                        |             |            |
|       V                                        V             V            V
|----> sub1 --------------------------------> real_div --> mul_gamma --> add_beta --> ...

Pattern 3

Detect InstanceNorm pattern in TensorFlow-Addons.

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
  1. input is rank 4 tensor.

  2. reduce operates on spatial dimensions axes=[-2, -1], or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).

  3. gamma and beta are absent. Default values for gamma and beta would be used.

       |-------------------------------------------------|
       |                                                 |
       |                                                 V
x --> mean   square --> mean1 --> add_eps --> rsqrt --> mul2 --> mul_sub
|      |       ^                                |                   |
|      V       |                                |                   |
| --> sub -----|                                |                   |
|                                               V                   V
|--------------------------------------------> mul1 -------------> add --> ...

Pattern 4

Identify the pattern:

y = x * [gamma * rsqrt(variance + eps)] + (beta - mean * [gamma * rsqrt(variance + eps)])

This pattern corresponds to, and should be fused as, instance_norm.

All of the following conditions must be satisfied:
  1. input is rank 4 tensor.

  2. reduce operates on spatial dimensions axes=[-2, -1] or axes=[-3, -2] (a channel first to channel last transpose would be inserted in such cases).

  3. gamma and beta are both shape (C,) after squeeze, where C is number of channels.

|-----------|
|           V
|------> mul_square1 -----> sum1 -----> mul_mean1
|                                           |
|                                           V
x --> sum --> mul_mean ==> mul_square --> sub_variance --> add_eps --> rsqrt
|                |                                                      |
|                |                                                      V
|                |                                                  mul_gamma
|                |                                                      |
|                |                                            |----------------|
|                |                                            |                V
|                |--------------------------------------------+-------------> mul2
|                                                             V                |
|----------------------------------------------------------> mul1              |
                                                              |                V
                                                              |             sub_beta --> add --> [...]
                                                              |                           ^
                                                              |---------------------------|

optimize_repeat_ops

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.cast_optimization

This optimization pass performs the following:

  • Removes redundant cast ops; that is, casts where the source and destination tensors have the same dtype.

  • Either cancels or fuses any two consecutive cast ops, repeatedly.

After this pass, there can’t be any consecutive cast ops present in the program. For examples, see TestCastOptimization. This is a non-algebraic translation which assumes that the upcasting doesn’t change the user’s intent.

For example:

Input graph:
input -----> cast(dtype="fp16") -----> cast(dtype="fp32") ----> square ---> out

Output graph:
input -----> square -----> out

The input graph has a maximum precision of fp16 while the output graph has fp32 precision.
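
A cast to fp16 and back is not an identity, which is why removing the pair changes numerics. The sketch below rounds values through IEEE 754 half precision using the `struct` format `'e'`. Python floats are 64-bit, so it only models the fp16 rounding step, not fp32 storage.

```python
import struct


def cast_fp16_then_fp32(value):
    """Round-trip a Python float through IEEE 754 half precision (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


print(cast_fp16_then_fp32(0.5))     # 0.5 (exactly representable)
print(cast_fp16_then_fp32(0.1))     # 0.0999755859375
print(cast_fp16_then_fp32(4097.0))  # 4096.0 (fp16 spacing is 4 near 4096)
```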

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_paddings

Identify two consecutive pad layers which could be merged into a single pad layer.

This is possible only if one of the following conditions is satisfied:

  • The paddings are “constant” and have the same constant_val.

  • The paddings act along different axes.

Input graph:
input(1, 2, 6, 8) ------> pad([1, 1], mode='reflect') -----> pad([1, 1, 0, 0], mode='reflect') ---> out(1, 2, 8, 10)

Output graph:
input(1, 2, 6, 8) ------> pad([1, 1, 1, 1], mode='reflect') ---> out(1, 2, 8, 10)
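
Merging amounts to aligning the two flat pad lists from the last axis and summing them. The sketch below assumes the list format from the example above ((before, after) pairs for trailing axes, the last pair applying to the last axis) and encodes the two merge conditions. The helpers are hypothetical; in particular, requiring the same mode on both pads, and an equal constant value whenever the mode is "constant", is an interpretation made here, not taken from the pass.

```python
def merge_pads(pad_a, pad_b):
    """Sum two flat pad lists that cover trailing axes (shorter one left-padded with 0)."""
    n = max(len(pad_a), len(pad_b))
    a = [0] * (n - len(pad_a)) + list(pad_a)
    b = [0] * (n - len(pad_b)) + list(pad_b)
    return [x + y for x, y in zip(a, b)]


def touched_axes(pad, n):
    """Indices of (before, after) pairs with nonzero padding, after left-extension to n."""
    full = [0] * (n - len(pad)) + list(pad)
    return {i for i in range(0, n, 2) if full[i] or full[i + 1]}


def can_merge(pad_a, mode_a, pad_b, mode_b, value_a=0.0, value_b=0.0):
    if mode_a != mode_b:
        return False
    if mode_a == "constant":
        return value_a == value_b  # same constant_val
    n = max(len(pad_a), len(pad_b))
    return not (touched_axes(pad_a, n) & touched_axes(pad_b, n))  # different axes


def padded_shape(shape, pad):
    out = list(shape)
    k = len(pad) // 2
    for i in range(k):
        out[len(shape) - k + i] += pad[2 * i] + pad[2 * i + 1]
    return out


print(can_merge([1, 1], "reflect", [1, 1, 0, 0], "reflect"))  # True
print(merge_pads([1, 1], [1, 1, 0, 0]))                       # [1, 1, 1, 1]
print(padded_shape([1, 2, 6, 8], [1, 1, 1, 1]))               # [1, 2, 8, 10]
```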

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_relus

Identify consecutive relu layers which could be merged into a single relu layer.

Input graph:
input ------> relu -----> 1 or more relu layers ---> out

Output graph:
input ------> relu ---> out

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_reshapes

Identify consecutive reshape ops which could be merged into a single reshape.

Input graph:
input -> reshape -> 1 or more reshapes -> output

Output graph:
input -> reshape -> output

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_transposes

Identify consecutive ‘transpose’ layers which could be merged into a single ‘transpose’ layer.

Input graph:
input ------> transpose -----> 1 or more transpose layers ---> out

Output graph:
input ------> transpose ---> out
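
Two consecutive transposes collapse into one whose perm is the composition of the two: output axis i of transpose(transpose(x, p1), p2) comes from input axis p1[p2[i]]. When the composed perm is the identity, the pair cancels entirely. A plain-Python sketch with a brute-force shape check:

```python
from itertools import permutations


def compose_perms(first, second):
    """Single perm equivalent to transpose(transpose(x, first), second)."""
    return [first[i] for i in second]


def transposed_shape(shape, perm):
    return [shape[p] for p in perm]


# Brute-force check on a rank-3 shape with distinct dims.
shape = [2, 3, 5]
for p1 in permutations(range(3)):
    for p2 in permutations(range(3)):
        once = transposed_shape(shape, compose_perms(p1, p2))
        twice = transposed_shape(transposed_shape(shape, p1), p2)
        assert once == twice

print(compose_perms([0, 3, 1, 2], [0, 2, 3, 1]))  # [0, 1, 2, 3]: the pair cancels
```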

class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.reduce_transposes

Reduce transposes when applicable. For example:

# Example 1
    Input graph:
    input -----> transpose(axis=[1,0]) -----> transpose(axis=[1,0]) ---> out

    Output graph:
    input -----> identity -----> out

# Example 2
    Input graph:
    input---->transpose(axis=[0,3,1,2])---->relu---->transpose(axis=[0,2,3,1])--->out

    Output graph:
    input----->relu----->out

# Example 3
    Input graph:
    input(shape=10,2,3,5)--->transpose(axis=[0,2,3,1])----->relu---->pool----->out1
                                                       |
                                                       |
                                                       --->relu----->log---->transpose(axis=[0,3,1,2])---->out2

    Output graph:
    input(shape=10,2,3,5)----->relu---->transpose(axis=[0,2,3,1])---->pool----->out1
                           |
                           |
                           --->relu----->log---->out2

Please see TransposeOptimizationPass for more details.

Notes

This pass is divided into 3 phases:

1st phase: Information gathering.

  • Plug in Identity ops for all output nodes. This allows us to treat all ops uniformly during traversal.

  • Block is traversed in the topological order, starting from the ops connected to the inputs.

  • During the traversal, a value is associated with every var in the block. This value can be either of type _HypotheticalValue or _LazyTransposeHypotheticalValue. The main purpose of type _HypotheticalValue is to indicate that it is not of type _LazyTransposeHypotheticalValue.

  • _LazyTransposeHypotheticalValue represents either one or multiple transpose ops with the same perm value. This information is stored in this class. It also wraps a _HypotheticalValue that was the last hypothetical value which was generated prior to the origin of _LazyTransposeHypotheticalValue.

  • Each op decides which type of hypothetical value to associate with its output vars, based on its op type, attributes, and the types of the hypothetical values of its input vars.

  • Ops are classified into 4 categories: unary-like, axis update, transpose, and materialize (all remaining ops).

  • Transpose ops are the ops from which a _LazyTransposeHypotheticalValue originates.
    • If the input to a transpose op is a _HypotheticalValue, its output is a _LazyTransposeHypotheticalValue, indicating that this transpose op can be cancelled downstream.

    • If the input to a transpose op is a _LazyTransposeHypotheticalValue, the pass checks whether this op cancels it.
      • If the op cancels it, a _HypotheticalValue is generated at the output, and the cancellation is recorded in the dictionary transpose_op_to_cancel_ops.

      • If the op does not cancel it, the current transpose op is categorized as a materialize op, and the dictionary transpose_op_to_materialize_ops is updated accordingly. The output of the op is then mapped to a _HypotheticalValue.

  • Unary-like ops: These simply transfer their input hypothetical value type to the output.

  • Axis update ops: If a transpose can pass through them, they are treated like a unary op and the dictionary transpose_op_to_axis_update_ops is updated. If the op cannot be updated in any manner to allow a transpose to pass through, it is categorized as a materialize op and handled accordingly.

  • Materialize ops: All _LazyTransposeHypotheticalValue input vars, if present, materialize here. The output of this op is always of type _HypotheticalValue. If the input is a _LazyTransposeHypotheticalValue, the dictionary transpose_op_to_materialize_ops is updated.

  • To treat an op like a unary op, add its type to _UNARY_LIKE_OP_TYPES. The intent is to automate this in the future by detecting unary-like ops from their “traits”.

  • To treat an op as an axis update op, add an op-specific class that implements TransformAxisUpdateOps. For examples, see the classes _TransformConcat, _TransformPad, and so on. The dictionary AXIS_UPDATE_OPS is filled in automatically by the decorator _TransposeOptimization.register_axis_update_op.
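
The phase 1 rules for transpose, unary-like, and materialize ops can be sketched on a straight-line chain of ops. This is a toy model, not the pass's code: ops are hypothetical (kind, arg) tuples, None stands in for a _HypotheticalValue, a stored perm stands in for a _LazyTransposeHypotheticalValue, and axis update ops and branching are left out.

```python
def compose(p, q):
    """Perm equivalent to transpose(transpose(x, perm=p), perm=q)."""
    return [p[i] for i in q]


def is_identity(perm):
    return list(perm) == list(range(len(perm)))


def propagate(chain):
    """Toy phase-1 walk over a linear chain of hypothetical (kind, arg) ops.

    kind is "transpose" (arg = perm), "unary" (unary-like op), or "materialize".
    Returns the events the pass would record, in traversal order.
    """
    lazy_perm = None  # None ~ _HypotheticalValue; a perm ~ _LazyTransposeHypotheticalValue
    origin = None     # index of the starting transpose that produced lazy_perm
    events = []
    for i, (kind, arg) in enumerate(chain):
        if kind == "transpose" and lazy_perm is None:
            # A transpose on a plain value starts a lazy transpose.
            lazy_perm, origin = list(arg), i
            events.append(("start", i))
        elif kind == "transpose" and is_identity(compose(lazy_perm, arg)):
            # This transpose undoes the pending one: record the cancellation.
            events.append(("cancel", i, origin))
            lazy_perm = origin = None
        elif kind == "unary":
            continue  # unary-like ops forward their input's hypothetical value
        elif lazy_perm is not None:
            # Any other op, including a non-cancelling transpose, materializes it.
            events.append(("materialize", i, origin))
            lazy_perm = origin = None
    if lazy_perm is not None:
        # Block outputs behave like materialize ops (the plugged-in Identity ops).
        events.append(("materialize", "output", origin))
    return events


# Example 2 from above: transpose -> relu -> transpose cancels.
print(propagate([("transpose", [0, 3, 1, 2]), ("unary", "relu"), ("transpose", [0, 2, 3, 1])]))
# -> [('start', 0), ('cancel', 2, 0)]
```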

2nd phase: Determining which transpose ops to remove from the graph.

All transpose ops that have a corresponding complement op in the dictionary transpose_op_to_cancel_ops are candidates. However, you need to ensure the following:

  • If a transpose op is removed, then all of its cancel ops in transpose_op_to_cancel_ops must also be removed, to ensure correctness of the graph. The same is true in the reverse direction as well; that is, for every cancel op that is removed, all its parent transpose ops upstream must also be removed.

  • Transpose ops should be removed only if the number of cancel ops is greater than the number of transpose ops that would be newly introduced into the block by materialization. Currently in the algorithm, each materialize op or output var (dicts transpose_op_to_materialize_ops/old_output_vars) results in one more transpose op, although this can be further optimized in the future.

To resolve this, note that the nodes in sets (a) and (b) form a bipartite graph, where (a) is the set of starting transpose ops (originators of _LazyTransposeHypotheticalValue) and (b) is the set of transpose cancel ops and materialize ops.

  • In this bipartite graph, all connected components are found. For each connected component, either the entire set of transpose ops in it is removed/materialized, or none of them is touched.

  • Thus, for each connected component, a determination is made based on counting the number of cancel ops and materialize ops.

  • Based on this determination, the final set of transpose ops to be removed is updated.
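
Phase 2 can be sketched as a connected-components pass over this bipartite graph. The inputs are hypothetical dicts standing in for transpose_op_to_cancel_ops and transpose_op_to_materialize_ops, with ops named by strings. As an assumption, the sketch counts the starting transposes plus the cancel ops as transposes removed and each materialize op as one transpose added, and removes a component only when more are removed than added; this is the reading under which Example 3 above (two transposes before, one after) comes out ahead. The block-output check and other details of the real pass are omitted.

```python
from collections import defaultdict


def transposes_to_remove(cancel_ops, materialize_ops):
    """Toy phase 2: decide, per connected component, which transposes to delete.

    cancel_ops / materialize_ops map a starting transpose name to the names of
    its cancel ops / materialize ops. Returns the set of op names to remove.
    """
    # Bipartite adjacency: ("start", name) nodes on one side,
    # ("cancel", name) and ("materialize", name) nodes on the other.
    adj = defaultdict(set)
    for start in set(cancel_ops) | set(materialize_ops):
        s = ("start", start)
        adj.setdefault(s, set())
        for kind, ops in (("cancel", cancel_ops), ("materialize", materialize_ops)):
            for name in ops.get(start, []):
                adj[s].add((kind, name))
                adj[(kind, name)].add(s)

    removed, seen = set(), set()
    for node in list(adj):
        if node in seen:
            continue
        # Iterative DFS collects one connected component.
        component, stack = set(), [node]
        while stack:
            cur = stack.pop()
            if cur in seen:
                continue
            seen.add(cur)
            component.add(cur)
            stack.extend(adj[cur] - seen)
        n_removed = sum(kind in ("start", "cancel") for kind, _ in component)
        n_added = sum(kind == "materialize" for kind, _ in component)
        # All-or-nothing: either the whole component goes, or none of it does.
        if n_removed > n_added:
            removed |= {name for kind, name in component if kind != "materialize"}
    return removed


# Example 3: t0 is cancelled by t3 on one branch and materialized before pool.
print(sorted(transposes_to_remove({"t0": ["t3"]}, {"t0": ["pool"]})))  # ['t0', 't3']
```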

3rd phase: Transforming the graph.

  • The starting transpose ops and the cancel ops are removed.

  • Axis update ops, affected by these transpose ops, are updated.

  • Transposes are materialized; that is, added just before the materialize ops, which are linked to the starting transpose ops. The starting transpose op can be materialized (inserted) multiple times, before each of the materialize ops downstream.

  • Block outputs are handled in a similar fashion as the materialize ops.

  • Type inference on all ops is invoked after all the transformations.

  • All Identity ops that are plugged into the graph to treat outputs as materialized are removed.

Debugging

If the debug flag is set to True, the block before and after the transformation is plotted, with transpose nodes highlighted.

optimize_tensor_operation

class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.concat_to_pixel_shuffle

Identify nested, interleaved concat ops which can be replaced by a single concat and a pixel shuffle layer.

This pattern occurs with the faster up-convolution from the FCRN model (Laina et al., 2016).

# Before the concat_to_pixel_shuffle pass.
input(N, C, H, W) -------------------
                                    |
                                    v
input(N, C, H, W) -----> concat(axis=2, interleave=True) -----> concat(axis=3, interleave=True) ----> output
                                                                            ^
                                                                            |
input(N, C, H, W) -----> concat(axis=2, interleave=True) --------------------
            |                       ^
            |                       |
input(N, C, H, W) -------------------

# After the concat_to_pixel_shuffle pass.
input(N, C, H, W) ---------------
                                |
                                v
input(N, C, H, W) -----> concat(axis=1, interleave=True) -----> pixel_shuffle(upscale_factor=2) ----> output
                                ^
                                |
input(N, C, H, W) --------------|
                                |
                                |
input(N, C, H, W) ---------------
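
The rewrite is exact because both graphs place every input element at the same output position. The sketch below checks this in plain Python on small dict-backed tensors. It assumes that concat with interleave=True puts element k of input i at position k*n + i along the concat axis, that pixel_shuffle uses PyTorch's PixelShuffle indexing, and that the inputs, labeled a, b, c, d from top to bottom in the diagram, feed the first concat as (a, b) and the second as (c, d). Under those assumptions the channel concat must take the inputs in the order a, c, b, d; that order is derived here, not taken from the pass's source.

```python
from itertools import product


def tensor(shape, tag):
    """Dict-backed tensor whose entries record (tag, index), so provenance is visible."""
    return tuple(shape), {idx: (tag, idx) for idx in product(*map(range, shape))}


def concat_interleave(tensors, axis):
    """concat(interleave=True): element k of input i lands at k * n + i along `axis`."""
    n = len(tensors)
    shape = tensors[0][0]
    out_shape = tuple(d * n if a == axis else d for a, d in enumerate(shape))
    out = {}
    for idx in product(*map(range, out_shape)):
        k, i = divmod(idx[axis], n)
        out[idx] = tensors[i][1][idx[:axis] + (k,) + idx[axis + 1:]]
    return out_shape, out


def pixel_shuffle(t, r):
    """PyTorch-style PixelShuffle: out[n, c, h*r + i, w*r + j] = in[n, c*r*r + i*r + j, h, w]."""
    (N, C, H, W), data = t
    out_shape = (N, C // (r * r), H * r, W * r)
    out = {}
    for n, c, y, x in product(*map(range, out_shape)):
        h, i = divmod(y, r)
        w, j = divmod(x, r)
        out[(n, c, y, x)] = data[(n, c * r * r + i * r + j, h, w)]
    return out_shape, out


a, b, c, d = (tensor((1, 2, 2, 3), tag) for tag in "abcd")
before = concat_interleave([concat_interleave([a, b], 2), concat_interleave([c, d], 2)], 3)
after = pixel_shuffle(concat_interleave([a, c, b, d], 1), 2)
print(before == after)  # True
```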
class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.detect_concat_interleave

Detect the pattern concat-->reshape--->transpose--->reshape, where concat is along the channel axis (axis=-3), and replace it with a single concat op with interleave=True.

This pattern occurs, for example, in the ShuffleNet model in torchvision.

Given:
    %3 = concat(%1.a, %1.b, ..., axis=-3, interleave=False) # shape = (B, n*C, H, W)
    %4 = reshape(%3) # shape = (B, n, C, H, W)
    %5 = transpose(%4, perm=[0, 2, 1, 3, 4]) # shape = (B, C, n, H, W)
    %6 = reshape(%5) # shape = (B, C*n, H, W)

Result:
    %6 = concat(%1.a, %1.b, ..., axis=-3, interleave=True)
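
Why the two forms agree can be seen by following the channel axis alone. The sketch below sets B = H = W = 1 (a simplification that only drops axes which ride along unchanged) and performs the reshape and transpose steps on plain Python lists.

```python
def concat_reshape_transpose_reshape(chunks):
    """concat -> reshape(n, C) -> transpose(perm=[1, 0]) -> reshape(C*n), channels only."""
    n, C = len(chunks), len(chunks[0])
    flat = [v for chunk in chunks for v in chunk]                      # concat: (n*C,)
    grid = [flat[k * C:(k + 1) * C] for k in range(n)]                 # reshape: (n, C)
    transposed = [[grid[k][ch] for k in range(n)] for ch in range(C)]  # transpose: (C, n)
    return [v for row in transposed for v in row]                      # reshape: (C*n,)


def concat_interleave(chunks):
    """concat(interleave=True): channel ch of every input, in input order, then ch + 1."""
    return [chunk[ch] for ch in range(len(chunks[0])) for chunk in chunks]


chunks = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
print(concat_reshape_transpose_reshape(chunks))  # ['a0', 'b0', 'a1', 'b1', 'a2', 'b2']
print(concat_interleave(chunks) == concat_reshape_transpose_reshape(chunks))  # True
```
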
class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.fuse_onehot_matmul_to_gather

Detect whether a one_hot op (axis=-1, on_value=1, off_value=0) is followed by a matmul op (no bias). If so, the two ops can be replaced by a single gather op.

Input:
    %2 = one_hot(%1, on_value=1, off_value=0, axis=-1)
    %3 = const() # rank 2
    %4 = matmul(%2, %3)

Output:
    %4 = gather(%3, %1, axis=0)
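
Multiplying by a one-hot row selects exactly one row of the rank-2 constant, which is why a gather along axis 0 with the original indices gives the same result. A plain-Python check with made-up values:

```python
def one_hot(indices, depth):
    """one_hot(axis=-1, on_value=1, off_value=0) for a 1-D list of indices."""
    return [[1 if j == i else 0 for j in range(depth)] for i in indices]


def matmul(a, b):
    """Naive matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gather_rows(table, indices):
    """gather(table, indices, axis=0)."""
    return [list(table[i]) for i in indices]


W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # the rank-2 const, shape (3, 2)
idx = [2, 0, 2]
print(matmul(one_hot(idx, 3), W) == gather_rows(W, idx))  # True
```
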
class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.replace_stack_reshape

A stack followed by a reshape layer can be replaced by a concat if the reshape simply removes the new axis and doubles the size of one of the axes next to it.

If the new axis is reshaped to the “right” (that is, the axis just after it is doubled), then we can use a concat. If it is reshaped to the “left” (the axis just before it is doubled), then the concat needs to set the interleaved flag.

Examples:

Given:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %3 = stack((%1, %2), axis=2) # shape = (1, 5, 2, 3, 4)
    %4 = reshape(%3, shape=[1, 10, 3, 4])

Result:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %4 = concat((%1, %2), axis=1, interleave=True) # shape = (1, 10, 3, 4)

Given:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %3 = stack((%1, %2), axis=1) # shape = (1, 2, 5, 3, 4)
    %4 = reshape(%3, shape=[1, 10, 3, 4])

Result:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %4 = concat((%1, %2), axis=1) # shape = (1, 10, 3, 4)
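
The two cases differ only in the order in which a row-major reshape reads the stacked elements. The sketch below reduces them to one-dimensional inputs (an illustrative simplification): stacking on a new leading axis and flattening yields a plain concat, while stacking on a new trailing axis and flattening yields an interleaved concat.

```python
def stack_1d(vectors, axis):
    """stack for 1-D inputs of length L: axis=0 -> shape (k, L), axis=1 -> shape (L, k)."""
    if axis == 0:
        return [list(v) for v in vectors]
    return [list(col) for col in zip(*vectors)]


def flatten(matrix):
    """Row-major reshape of a 2-D nested list to 1-D."""
    return [v for row in matrix for v in row]


x, y = [1, 2, 3], [10, 20, 30]
print(flatten(stack_1d([x, y], axis=0)))  # [1, 2, 3, 10, 20, 30], a plain concat
print(flatten(stack_1d([x, y], axis=1)))  # [1, 10, 2, 20, 3, 30], an interleaved concat
```
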
class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.use_reflection_padding

Identify a reflection padding layer composed of slices and concats.

Input graph:

        ------------------------------------------------------------------------------------- |
        |                                                                                     v
input(1, 2, 6, 8) ------> slice_by_index(begin=[0, 0, 0, 1], end=[0, 0, 0, 2]) -----> concat(axis=3) ---> out(1, 2, 6, 10)
        |                                                                                     ^
        ----------------> slice_by_index(begin=[0, 0, 0, -2], end=[0, 0, 0, -1]) -------------|

Output graph:

input(1, 2, 6, 8) ------> pad(mode=reflect, size=[0, 0, 1, 1]) -----> out(1, 2, 6, 10)
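
Along the width axis, the two slices pick x[1] and x[W-2], which are exactly the elements that a reflect pad of one on each side adds. A one-row, plain-Python check, assuming the concat order is slice, input, slice (the helper name reflect_pad is illustrative):

```python
def reflect_pad(row, left, right):
    """Reflect padding of a 1-D list; the edge element itself is not repeated."""
    return [row[i] for i in range(left, 0, -1)] + row + [row[-2 - i] for i in range(right)]


row = list(range(8))                   # one row along the width axis (W = 8)
sliced = row[1:2] + row + row[-2:-1]   # slice [1:2], the input, slice [-2:-1]
print(sliced)                          # [1, 0, 1, 2, 3, 4, 5, 6, 7, 6]
print(sliced == reflect_pad(row, 1, 1))  # True
```
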
class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.expand_high_rank_reshape_and_transpose

Detect the pattern reshape_1-->transpose-->reshape_2, where reshape_1 has an output tensor with rank >= 6, and reshape_2 produces a tensor with rank <= 5.

In general, this pattern can be expanded into a sequence of rank-4 reshape and transpose ops, which are supported by the Core ML runtime.

Given:
    %1 = reshape(%x, shape=(d1, d2, d3, d4, ..., dn))
    %2 = transpose(%1, perm=(p1, p2, ..., pn))
    %3 = reshape(%2, shape=(o1, o2, o3, o4, o5))

Result:
    %t1 = reshape(%x, shape=(y11, y12, y13, y14))
    %h1 = transpose(%t1, perm=[0, 2, 1, 3])
    %t2 = reshape(%h1, shape=(y21, y22, y23, y24))
    %h2 = transpose(%t2, perm=[0, 2, 1, 3])
    ....
    %hn = transpose(%tn, perm=[0, 2, 1, 3])
    %3 = reshape(%hn, shape=(o1, o2, o3, o4, o5))
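
Each transpose(perm=[0, 2, 1, 3]) step can swap two adjacent axes: viewing a tensor as (product of leading dims, a, b, product of trailing dims) and transposing the middle pair never exceeds rank 4. The sketch below, on flat row-major Python lists, builds an arbitrary permutation out of such swaps (a bubble sort over axes) and checks the result against a direct transpose. It illustrates why the expansion is possible; it is not necessarily how the pass chooses its reshapes, and it may use more steps than the pass does.

```python
from itertools import product
from math import prod


def transpose(data, shape, perm):
    """Transpose a flat row-major list: output axis i is input axis perm[i]."""
    strides = [prod(shape[i + 1:]) for i in range(len(shape))]
    out_shape = tuple(shape[p] for p in perm)
    out = [
        data[sum(idx[i] * strides[p] for i, p in enumerate(perm))]
        for idx in product(*map(range, out_shape))
    ]
    return out, out_shape


def swap_adjacent_rank4(data, shape, k):
    """Swap axes k and k+1 using only a rank-4 view and transpose(perm=[0, 2, 1, 3])."""
    view = (prod(shape[:k]), shape[k], shape[k + 1], prod(shape[k + 2:]))
    out, _ = transpose(data, view, (0, 2, 1, 3))  # reshape is free on a flat list
    return out, shape[:k] + (shape[k + 1], shape[k]) + shape[k + 2:]


def transpose_via_rank4(data, shape, perm):
    """Reach `perm` by bubble-sorting axes with rank-4 adjacent swaps."""
    shape = tuple(shape)
    order = list(range(len(shape)))  # order[i] = original axis now at position i
    target = {axis: i for i, axis in enumerate(perm)}
    for _ in range(len(order)):
        for k in range(len(order) - 1):
            if target[order[k]] > target[order[k + 1]]:
                data, shape = swap_adjacent_rank4(data, shape, k)
                order[k], order[k + 1] = order[k + 1], order[k]
    return data, shape


shape, perm = (2, 3, 2, 2, 3, 2), (5, 0, 3, 1, 4, 2)  # rank 6, illustrative values
data = list(range(prod(shape)))
print(transpose_via_rank4(data, shape, perm) == transpose(data, shape, perm))  # True
```
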

preprocess

class coremltools.converters.mil.mil.passes.defs.preprocess.image_input_preprocess

Plug in a transpose to convert image inputs from NHWC format to NCHW format.

The pass follows these steps:

  1. Check whether there are any inputs that the user specifies as ImageType.

  2. Check the channel dimension of all inputs that are ImageType.

    1. channel_first == True: The input is not modified, since channel_first is the intended behaviour for feeding images for optimal performance.

    2. channel_first == False: The input is converted into a “channel_first” input, and a transpose is plugged in after it to maintain the remaining graph’s dimensionality.
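
In terms of shapes, the channel_first == False case amounts to two permutations. The sketch assumes an illustrative 224x224 RGB input: the model input becomes NCHW, and a plugged-in transpose with perm=[0, 2, 3, 1] should hand the rest of the graph the NHWC layout it was built for.

```python
def permute_shape(shape, perm):
    """Shape produced by transpose(x, perm): output axis i has size shape[perm[i]]."""
    return tuple(shape[p] for p in perm)


nhwc = (1, 224, 224, 3)                       # original channel-last image input
nchw = permute_shape(nhwc, (0, 3, 1, 2))      # new channel-first model input
restored = permute_shape(nchw, (0, 2, 3, 1))  # output of the plugged-in transpose
print(nchw, restored)  # (1, 3, 224, 224) (1, 224, 224, 3)
```
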

class coremltools.converters.mil.mil.passes.defs.preprocess.sanitize_input_output_names

Sanitize the names of model input and output vars so that they follow the format described in the NameSanitizer class; that is, [a-zA-Z_][a-zA-Z0-9_]*.
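
A rule of this kind can be sketched with Python's re module. The helper below is hypothetical and is not NameSanitizer itself: it replaces each disallowed character with an underscore and prefixes an underscore when the result would start with a digit, and it makes no attempt to keep names unique.

```python
import re

VALID_NAME = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")


def sanitize_name(name):
    """Hypothetical sanitizer producing names that match [a-zA-Z_][a-zA-Z0-9_]*."""
    cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", name)  # disallowed characters -> "_"
    if not cleaned or cleaned[0].isdigit():        # names may not start with a digit
        cleaned = "_" + cleaned
    return cleaned


for raw in ["input.1", "0_out", "image:0", "logits"]:
    print(raw, "->", sanitize_name(raw))
```
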

class coremltools.converters.mil.mil.passes.defs.preprocess.update_output_dtypes

Update the dtypes of output vars of the main block to match the dtypes provided in prog.main_output_types, which in turn is populated by the outputs argument provided by the user in the coremltools.convert() API. This graph pass assumes that the list in prog.main_output_types (if not None) is in the same order as the output vars.
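
Because output vars and requested dtypes are matched by position, the core decision is a zip over the two lists. A minimal sketch with hypothetical names and dtype strings; None in the requested list is assumed to mean "leave this output alone", and the real pass inserts cast ops rather than returning a plan.

```python
def plan_output_casts(output_dtypes, requested_dtypes):
    """Hypothetical helper: return (position, current, requested) for outputs needing a cast."""
    if requested_dtypes is None:  # nothing requested: outputs keep their dtypes
        return []
    if len(requested_dtypes) != len(output_dtypes):
        raise ValueError("expected one requested dtype per output var, in the same order")
    return [
        (i, current, wanted)
        for i, (current, wanted) in enumerate(zip(output_dtypes, requested_dtypes))
        if wanted is not None and wanted != current
    ]


print(plan_output_casts(["fp32", "fp32", "int32"], ["fp16", None, "int32"]))  # [(0, 'fp32', 'fp16')]
```
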

quantization

class coremltools.converters.mil.mil.passes.defs.quantization.add_fp16_cast(op_selector=None)

For each input of dtype float32, inject a cast op to change it to float16 dtype.

For each output of dtype float16, inject a cast op to change it back to float32.

This pass is the registered interface for FP16ComputePrecision, which makes it consistent with other passes’ interfaces.

Support options:

  • skip_ops_by_type: Skip op types specified by comma-separated string; for example, "mul,const".