MIL Graph Passes
Graph Passes supported by the Model Intermediate Language (MIL):
cleanup
- class coremltools.converters.mil.mil.passes.defs.cleanup.const_deduplication[source]
Remove duplicated large constants (tensor with 100+ elements)
For example:
    Input graph (where weight and bias are large constants):
        weight_q = const(weight)
        weight_k = const(weight)
        bias_q = const(bias)
        bias_k = const(bias)
        q_embedding = linear(x=q, weight=weight_q, bias=bias_q)
        k_embedding = linear(x=k, weight=weight_k, bias=bias_k)

    Output graph:
        weight_q = const(weight)
        bias_q = const(bias)
        q_embedding = linear(x=q, weight=weight_q, bias=bias_q)
        k_embedding = linear(x=k, weight=weight_q, bias=bias_q)
Concretely, this graph pass consists of two stages:
(1) Deduplication of const ops: we consider a const as duplicated if there exists a previous const that has the same dtype and value.
(2) Deduplication of constexpr_* ops: we consider a constexpr_* as duplicated if there exists a previous constexpr_* that has the same op_type and input attributes.
Support options:
const_threshold: Skip deduplicating const ops that have a smaller number of elements than a threshold. Defaults to 100, i.e., constants with size < 100 will not be deduplicated.
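The first stage can be sketched in plain Python. This is only an illustration of the rule stated above (same dtype and value, at or above the size threshold), not the pass implementation; `dedup_consts` and the `(name, dtype, values)` triples are hypothetical stand-ins for MIL const ops.

```python
from typing import Dict, List, Tuple

CONST_THRESHOLD = 100  # mirrors the documented default of const_threshold


def dedup_consts(
    consts: List[Tuple[str, str, tuple]], threshold: int = CONST_THRESHOLD
) -> Dict[str, str]:
    """Map each const name to the first earlier const with the same dtype and value.

    `consts` is a list of (name, dtype, values) triples in program order.
    Consts with fewer than `threshold` elements are left untouched.
    """
    first_seen: Dict[Tuple[str, tuple], str] = {}
    replacement: Dict[str, str] = {}
    for name, dtype, values in consts:
        if len(values) < threshold:
            replacement[name] = name  # too small: skip deduplication
            continue
        replacement[name] = first_seen.setdefault((dtype, values), name)
    return replacement


weight = tuple(float(i) for i in range(128))  # large const (>= 100 elements)
bias = tuple(float(i) for i in range(4))      # small const (< 100 elements)
mapping = dedup_consts(
    [
        ("weight_q", "fp32", weight),
        ("weight_k", "fp32", weight),
        ("bias_q", "fp32", bias),
        ("bias_k", "fp32", bias),
    ]
)
print(mapping)
```

In this toy run only the large `weight_k` is redirected to `weight_q`; the four-element biases stay separate because they fall below the threshold.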
- class coremltools.converters.mil.mil.passes.defs.cleanup.const_elimination[source]
Replace non-const ops that have const Var outputs: each such output is replaced with a const op. Example:
    Given:
        %2, %3 = non_const_op(...)  # %2 is const, %3 isn't const
        %4 = other_op(%2, %3)

    Result:
        _, %3 = non_const_op(...)  # _ is the ignored output
        %2_const = const()         # %2_const name is for illustration only
        %4 = other_op(%2_const, %3)
Support options:
skip_const_by_size: Skip folding const ops that have a larger number of elements than a threshold.
- class coremltools.converters.mil.mil.passes.defs.cleanup.dead_code_elimination[source]
Eliminate unused ops in program. Ops whose outputs do not contribute to final outputs will be deleted.
    # Before dead_code_elimination pass.
    main(%x: (2, 4, fp32)) {
        block0() {
            %const_2: (4, 2, fp32)* = const(val=[...])
            %const_3: (4, fp32)* = const(val=[...])
            %tx_0: (bool)* = const(val=False)
            %ty_0: (bool)* = const(val=False)
            %matmul_0: (2, 2, fp32) = matmul(x=%x, y=%const_2, transpose_x=%tx_0, transpose_y=%ty_0)
            %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
        } -> (%linear_0)
    }

    # After dead_code_elimination pass.
    main(%x: (2, 4, fp32)) {
        block0() {
            %const_2: (4, 2, fp32)* = const(val=[...])
            %const_3: (4, fp32)* = const(val=[...])
            %linear_0: (2, 4, fp32) = linear(x=%x, weight=%const_2, bias=%const_3)
        } -> (%linear_0)
    }
In the example above, %matmul_0 is an op that is not used in the computation. This op and its input ops (%tx_0 and %ty_0) are eliminated in this pass.
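The idea can be sketched as a reachability walk backwards from the program outputs. The snippet below is a simplified model, assuming each op has a single output named after the op; `dead_code_elimination` here is a hypothetical helper, not the coremltools class.

```python
from typing import Dict, List, Set


def dead_code_elimination(ops: Dict[str, List[str]], outputs: List[str]) -> List[str]:
    """Return the names of the ops to keep, in their original order.

    `ops` maps each op's output name to the names of its inputs.
    Names that are not keys of `ops` are treated as program inputs.
    """
    live: Set[str] = set()
    stack = list(outputs)
    while stack:
        name = stack.pop()
        if name in live or name not in ops:
            continue  # already visited, or a program input
        live.add(name)
        stack.extend(ops[name])  # everything feeding a live op is live too
    return [name for name in ops if name in live]


ops = {
    "const_2": [],
    "const_3": [],
    "tx_0": [],
    "ty_0": [],
    "matmul_0": ["x", "const_2", "tx_0", "ty_0"],
    "linear_0": ["x", "const_2", "const_3"],
}
kept = dead_code_elimination(ops, outputs=["linear_0"])
print(kept)
```

With `linear_0` as the only output, `matmul_0`, `tx_0`, and `ty_0` are unreachable and drop out, matching the example above.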
- class coremltools.converters.mil.mil.passes.defs.cleanup.dedup_op_and_var_names[source]
For each function, this pass renames ops and variables with the same name as any preceding ops/variables across all scopes in the given function, where the precedence is implementation-specific. Note that an op name and variable names are tracked separately, so an op may have the same name as a variable.
The pass preserves input and output names. It raises a ValueError if it cannot deduplicate without changing the input/output var names.
    def prog(x):
        x = mb.cast(x=x, dtype="fp16", name="castop")
        x = mb.cast(x=x, dtype="fp32", name="castop")
        x = mb.square(x=x, name="square_last")
        return x

    # Before dedup pass, the op names are ["castop", "castop", "square_last"].
    # After dedup pass, the op names are ["castop", "castop_1", "square_last"].
- class coremltools.converters.mil.mil.passes.defs.cleanup.expand_dynamic_linear[source]
Translate to linear when the operand is a descendant of const, since such an operand may be folded into const or fused into constexpr later by graph passes. In op translation, we prefer linear whenever possible because it requires const or constexpr weight and bias.
If such const folding or constexpr fusion did not happen, this pass would clean up the too-ambitious linear ops by replacing them with matmul ops.
- class coremltools.converters.mil.mil.passes.defs.cleanup.fuse_reduce_mean[source]
Detect the reduce_sum -> mul/real_div pattern that can be mapped to reduce_mean. That is, the operation reduce_sum/count == reduce_mean.
Input graph
                              const (scalar)
                                     |
    input ----> reduce_sum ----> mul/real_div -----------> output
Output graph
    input --------> reduce_mean ---------> output
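A quick numeric check of the identity behind this fusion, using plain Python lists in place of tensors (the values are arbitrary and only for illustration):

```python
import math
import statistics

values = [1.0, 2.0, 3.0, 6.0]
count = len(values)

via_real_div = sum(values) / count     # reduce_sum -> real_div(count)
via_mul = sum(values) * (1.0 / count)  # reduce_sum -> mul(1 / count)
mean = statistics.fmean(values)        # reduce_mean

print(via_real_div, via_mul, mean)
assert math.isclose(via_real_div, mean)
assert math.isclose(via_mul, mean)
```

Both the `real_div` form (dividing by the element count) and the `mul` form (multiplying by its reciprocal) agree with the mean, which is why either scalar const can be matched.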
- class coremltools.converters.mil.mil.passes.defs.cleanup.loop_invariant_elimination[source]
When a block does not modify a block input var, eliminate that block input var and use the corresponding var in the outer scope. Example:
    # Before loop_invariant_elimination pass.
    # Notice that %b.x is constant across while loop iterations.
    main(%a: (1, 2, fp32),
         %b: (1, 2, fp32)) {
        block0() {
            %loop:0: (1, 2, fp32), %loop:1: (1, 2, fp32) = while_loop(loop_vars=(%a, %b))
                loop_cond(%a.x, %b.x) {
                    %cond_var: (bool) = some_op(x=%a.x, y=%b.x)
                } -> (%cond_var)
                loop_body(%a.x, %b.x) {
                    %add_0: (1, 2, fp32) = add(x=%a.x, y=%b.x)
                } -> (%add_0, %b.x)
        } -> (%loop:0, %loop:1)
    }

    # After loop_invariant_elimination pass.
    main(%a: (1, 2, fp32),
         %b: (1, 2, fp32)) {
        block0() {
            %loop:1: (1, 2, fp32) = identity(x=%b)
            %loop:0: (1, 2, fp32) = while_loop(loop_vars=(%a))
                loop_cond(%a.x) {
                    %cond_var: (bool) = some_op(x=%a.x, y=%b)
                } -> (%cond_var)
                loop_body(%a.x) {
                    %add_0: (1, 2, fp32) = add(x=%a.x, y=%b)
                } -> (%add_0)
        } -> (%loop:0, %loop:1)
    }
where we eliminate the loop invariant %b.x from while_loop, which returns 1 instead of 2 outputs. We also preserve the return var names with identity.
- class coremltools.converters.mil.mil.passes.defs.cleanup.noop_elimination[source]
Remove ops that have no effect.
    Given:
        %1 (1, 96, 128, 64, fp32) = ...
        %2 (1, 96, 128, 64, fp32) = reshape(%1)
        ...
        %3 (1, 96, 128, 64, fp32) = add(%2, constant)
        ...

    Result:
        %1 (1, 96, 128, 64, fp32) = ...
        %3 (1, 96, 128, 64, fp32) = add(%1, constant)
        ...
- class coremltools.converters.mil.mil.passes.defs.cleanup.remove_redundant_ops[source]
If there are multiple ops with “identical” inputs, then they are redundant and all but one of them can be removed. This pass checks and removes such ops.
Since all inputs to ops in MIL are named, two ops with the same op_type can be compared by comparing their correspondingly named inputs. Inputs are treated as identical if one of the following is true:
  - The input is a constant var, in which case its value should have the same dtype and numerical value.
  - The input is a non-constant var, in which case it should be the same var object.
This pass iterates over the ops, takes each op's first output var, and then builds a candidate op list from the child ops of this var. This candidate op list contains ops of the same op_type, arranged in topological order. Within each candidate list, the second, third, and subsequent ops are compared pairwise with the first op, and those identical to it are removed. For example:
    Input:
        %0 = op0(...)
        %1 = op1(...)
        %2 = const(val=4.5)
        %3 = const(val=4.5)
        %4 = op2(%1, %0, %2)
        %5 = op3(%1, %0, %3)

    Output:
        %0 = op0(...)
        %1 = op1(...)
        %2 = const(val=4.5)
        %3 = const(val=4.5) # this will get removed later by dead code elimination pass
        %4 = op2(%1, %0, %2)
In the example above, op3 is removed and all uses of %5 are replaced by %4. For more examples, see “TestRemoveRedundantOpsPass”.
- class coremltools.converters.mil.mil.passes.defs.cleanup.remove_symbolic_reshape[source]
Convert symbolic shape in reshape to integers.
Note: This does not perform any optimization, but simply replaces symbols with positive integers if solved from volumetric constraint, or -1. Therefore, this pass fails if more than one symbol needs to be resolved to -1.
    # Before remove_symbolic_reshape pass.
    main(%x: (s0, 4, fp32)) {
        block0() {
            %reshape_0_shape_0: (3,i32)^ = const(val=(s0, s1, 2))
            %reshape_0: (s0, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0)
        } -> (%reshape_0)
    }

    # After remove_symbolic_reshape pass.
    main(%x: (s0, 4, fp32)) {
        block0() {
            %reshape_0_shape_0x: (3,i32)* = const(val=[-1, 2, 2])
            %reshape_0: (-1, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0x)
        } -> (%reshape_0)
    }
TODO (rdar://59165842): Use expand_dims, squeeze etc to use 0 instead of dynamic reshape with -1.
- class coremltools.converters.mil.mil.passes.defs.cleanup.topological_reorder[source]
Topologically re-orders the list of operations in a program by placing each operation closer to its first use, or at the end if it’s not consumed by any other operation.
Currently, this pass re-orders only Transpose and Cast operations.
    # Example: input program
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x2 = mb.cast(x=x1_t, dtype="fp32")
        x3 = mb.log(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
    } -> x2, x4, x7, x8

    # After moving `cast` ops becomes
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x3 = mb.log(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x2 = mb.cast(x=x1_t, dtype="fp32")
    } -> x2, x4, x7, x8

    # After moving `transpose` ops becomes
    main(x: (2, 4, fp32)) {
        x = mb.cast(x=x, dtype="fp16")
        x1 = mb.square(x=x)
        x3 = mb.log(x=x)
        x5 = mb.relu(x=x)
        x6 = mb.cast(x=x5, dtype="fp32")
        x7 = mb.relu(x=x6)
        x8 = mb.relu(x=x)
        x3_t = mb.transpose(x=x3, perm=[1, 0])
        x4 = mb.cast(x=x3_t, dtype="fp32")
        x1_t = mb.transpose(x=x1, perm=[1, 0])
        x2 = mb.cast(x=x1_t, dtype="fp32")
    } -> x2, x4, x7, x8
optimize_activation
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_exact[source]
Identify the pattern that corresponds to the exact version of gelu, and replace it with a single gelu layer with mode=EXACT. The pattern is y = 0.5 * x * (1 + erf(x / sqrt(2))), which can be represented by one of the following:
    (1)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul (0.5) ---> mul ---> [...]
          |                                                                  ^
          |                                                                  |
          |-------------------------------------------------------------------

    (2)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ---> mul (0.5) ---> [...]
          |                                                   ^
          |                                                   |
          |----------------------------------------------------

    (3)
        [...] ----> div (1.414) ---> erf ---> add (1) -----> mul ------> [...]
          |                                                   ^
          |                                                   |
          |---------------> mul(0.5) --------------------------

    All of them are converted to:
        [...] ----> gelu (mode=EXACT) ---> [...]
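As a sanity check, the three op orderings can be written out in scalar Python and compared against the closed form. The function names are hypothetical, and each one follows one of the three diagrams; `math.erf` stands in for the MIL erf op.

```python
import math

SQRT_2 = math.sqrt(2.0)  # the "1.414" divisor in the diagrams


def gelu_exact(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / SQRT_2))


def pattern_1(x: float) -> float:
    # div -> erf -> add(1) -> mul(0.5) -> mul(x)
    return ((math.erf(x / SQRT_2) + 1.0) * 0.5) * x


def pattern_2(x: float) -> float:
    # div -> erf -> add(1) -> mul(x) -> mul(0.5)
    return ((math.erf(x / SQRT_2) + 1.0) * x) * 0.5


def pattern_3(x: float) -> float:
    # div -> erf -> add(1) -> mul(x * 0.5)
    return (math.erf(x / SQRT_2) + 1.0) * (x * 0.5)


samples = [-3.0, -0.5, 0.0, 0.7, 2.5]
for x in samples:
    reference = gelu_exact(x)
    for pattern in (pattern_1, pattern_2, pattern_3):
        assert math.isclose(pattern(x), reference, rel_tol=1e-12, abs_tol=1e-12)
print([round(gelu_exact(x), 6) for x in samples])
```

The three patterns differ only in where the factors 0.5 and x are multiplied in, so they agree up to floating-point rounding.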
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_gelu_tanh_approximation[source]
Identify the pattern that corresponds to the tanh approximate version of gelu, and replace it with a single gelu layer with mode=TANH_APPROXIMATION.
The implementation of this pass uses the generic graph pattern matching and transform algorithm implemented in coremltools.converters.mil.experimental.passes.generic_pass_infrastructure and documented in coremltools/converters/mil/experimental/passes/readme.md.
Graph for get_gelu_pattern1()
y = x * (0.5 * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1))
    [...] -----> pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) ----> mul (0.5) -----> mul ---> [...]
      |                                            ^                                                                          ^
      |                                            |                                                                          |
      |------------------------------------------------------------------------------------------------------------------------
Graph for get_gelu_pattern2()
y = (0.5 * x) * (tanh((0.044715 * x^3 + x) * sqrt(2/pi)) + 1)
                  ---------------------------------------------------------------------------------------------------------
                  ^                                                                                                       |
                  |                                                                                                       V
    [...] -----> mul(0.5)    pow (3) ----> mul (.044715) ---> add -----> mul (sqrt(2/pi)) ---> tanh ----> add (1) -----> mul ---> [...]
      |                       ^                                ^
      |                       |                                |
      |---------------------------------------------------------
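The two patterns compute the same value and differ only in where the factor 0.5 is applied. A scalar sketch follows; the function names mirror the pattern getters above, but the bodies are plain-Python illustrations rather than the MIL programs themselves.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)
CUBIC_COEFF = 0.044715  # the "mul (.044715)" constant in the graphs


def gelu_pattern1(x: float) -> float:
    # 0.5 is multiplied into the tanh branch, then the result is multiplied by x
    inner = (CUBIC_COEFF * x ** 3 + x) * SQRT_2_OVER_PI
    return x * (0.5 * (math.tanh(inner) + 1.0))


def gelu_pattern2(x: float) -> float:
    # 0.5 is multiplied into x first, then combined with the tanh branch
    inner = (CUBIC_COEFF * x ** 3 + x) * SQRT_2_OVER_PI
    return (0.5 * x) * (math.tanh(inner) + 1.0)


samples = [-4.0, -1.0, 0.0, 0.3, 2.0]
for x in samples:
    assert math.isclose(gelu_pattern1(x), gelu_pattern2(x), rel_tol=1e-12, abs_tol=1e-12)
print([round(gelu_pattern1(x), 6) for x in samples])
```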
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_leaky_relu[source]
Detect the mul -> max pattern that can be mapped to leaky_relu.
In code form - Input
    %2 = const(value = alpha) # where 0 <= alpha <= 1
    %3 = mul(%1, %2) # alpha * x
    %4 = max(%3, %1) # max(alpha * x, x)
In code form - Output
    %4 = leaky_relu(x=%1, alpha=%2)
In graphical form - Input graph
            const (val = alpha)
                 |
    input ----> mul ---------------> maximum -----------> output
      |                                 |
      |----------------------------------
In graphical form - Output graph
    input --------> leaky_relu ---------> output
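The constraint 0 <= alpha <= 1 matters: only in that range does max(alpha * x, x) select x for non-negative inputs and alpha * x for negative ones. A scalar check, with hypothetical helper names:

```python
def leaky_relu(x: float, alpha: float) -> float:
    return x if x >= 0 else alpha * x


def mul_then_max(x: float, alpha: float) -> float:
    # the mul -> max pattern from the input graph
    return max(alpha * x, x)


samples = [-3.0, -0.5, 0.0, 0.5, 3.0]

# Inside the supported range the two forms agree.
agree_in_range = all(
    mul_then_max(x, alpha) == leaky_relu(x, alpha)
    for alpha in (0.0, 0.1, 1.0)
    for x in samples
)

# Outside the range (alpha > 1) they diverge, so the fusion would be wrong.
diverges_out_of_range = mul_then_max(1.0, 2.0) != leaky_relu(1.0, 2.0)

print(agree_in_range, diverges_out_of_range)
```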
- class coremltools.converters.mil.mil.passes.defs.optimize_activation.fuse_prelu[source]
Detect the following patterns that can be mapped to a prelu op. Essentially, the prelu op can be broken down into the following ops:
y = a * relu(-1 * x) + relu(x)
Pattern 1
                   | ------------> relu --------------------|
                   |                                        V
    x (BCHW) ------|                                       add -----> y (BCHW)
                   |                                        ^
                   --------> mul -------> relu -----> mul---|
                              ^                        ^
                              |                        |
                        Const(val=-1)            Const(name=a, shape=(C,1,1) or (1,C,1,1))
This will be mapped to:
    x (BCHW) ------> prelu(alpha=a, shape=(C,)) ---------> y (BCHW)
Pattern 2
                                    | ------------> relu --------------------|
                                    |                                        V
    x (BCHW) -->transpose(BHWC)---->|                                       add -----> y (BHWC)
                                    |                                        ^
                                    --------> mul -------> relu -----> mul---|
                                               ^                        ^
                                               |                        |
                                         Const(val=-1)    Const(shape=(C,) or (1,C) or (1,1,C) or (1,1,1,C))
This will be mapped to:
    x (BCHW) ------> prelu ---------> transpose ------> y (BHWC)
optimize_conv
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.add_conv_transpose_output_shape[source]
The conv_transpose input output_shape is an optional input. Since we can infer the output shape from type_inference, we add the output_shape input whenever it is known to be constant at compile time. For example:
    Given:
        %1: (1, 5, 39, fp32) = conv_transpose(...) # no output_shape input.

    Result:
        %2: (3, i32) = const(val=[1,5,39])
        %3: (1, 5, 39, fp32) = conv_transpose(..., output_shape=%2)
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.compose_conv1d[source]
In TensorFlow, tf.keras.layers.Conv1D is a composite op:
    expand a dummy dim -> Conv2D -> squeeze the dummy dim
In PyTorch, this is also true for some backends (mkldnn and xpu).
This decomposition wrecks the coremltools conv1d graph passes, so we should recompose the fragments back to MIL conv, which natively supports conv1d:
    Pattern 1:
        Given:
            %2 = expand_dims(%1, axes=-2) or expand_dims(%1, axes=2), %1.rank = 3
            %3 = conv(%2)
            %4 = squeeze(%3, axes=-2) or squeeze(%3, axes=2)
            ...

        Result:
            %4 = conv(%1)
            ...

    Pattern 2 (TensorFlow channel_last):
        Given:
            %2 = expand_dims(%1, axes=-3) or expand_dims(%1, axes=1), %1.rank = 3
            %3 = transpose(%2, perm=(0, 3, 1, 2))
            %4 = conv(%3)
            %5 = transpose(%4, perm=(0, 2, 3, 1))
            %6 = squeeze(%5, axes=-3) or squeeze(%5, axes=1)
            ...

        Result:
            %3 = transpose(%1, perm=(0, 2, 1))
            %4 = conv(%3)
            %6 = transpose(%4, perm=(0, 2, 1))
            ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_batchnorm[source]
Fuse the following batch_norm layer into conv and conv_transpose. That is, convert conv + batch_norm to conv, by modifying the weight and bias in the conv layer.
    Given:
        %2 = conv(%1)
        ...
        %3 = batch_norm(%2)
        ...

    Result:
        %3 = conv(%1)
        ...
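How the weight and bias change can be sketched with the standard inference-time batch norm formula. This is an assumption-laden illustration: the text above only says that the weight and bias are modified, so the exact formula, the `eps` value, and the `conv1x1` stand-in (a per-output-channel dot product) are choices made here for the example.

```python
import math
from typing import List, Tuple

EPS = 1e-5  # assumed epsilon; the real op takes it as a parameter


def conv1x1(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    # Stand-in for conv: one dot product per output channel.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, variance):
    return [
        g * (yi - m) / math.sqrt(v + EPS) + bt
        for yi, g, bt, m, v in zip(y, gamma, beta, mean, variance)
    ]


def fold_batch_norm(weight, bias, gamma, beta, mean, variance) -> Tuple[list, list]:
    new_weight, new_bias = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, variance):
        scale = g / math.sqrt(v + EPS)
        new_weight.append([w * scale for w in row])  # scale each output channel
        new_bias.append((b - m) * scale + bt)        # shift absorbs mean and beta
    return new_weight, new_bias


x = [2.0, 1.0]
weight = [[1.0, -2.0], [0.5, 3.0]]
bias = [0.1, -0.2]
gamma, beta = [1.5, 0.5], [0.0, 1.0]
mean, variance = [0.3, -0.1], [4.0, 0.25]

unfused = batch_norm(conv1x1(x, weight, bias), gamma, beta, mean, variance)
fused = conv1x1(x, *fold_batch_norm(weight, bias, gamma, beta, mean, variance))
print(unfused, fused)
```

Because batch norm at inference time is a per-channel affine map, composing it with the (also affine) convolution gives another convolution with rescaled weights and a shifted bias.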
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_bias[source]
Fold add/sub into the bias of conv and conv_transpose. That is, convert conv + add/sub to conv, when add/sub is adding a constant.
Two patterns are supported:
    Pattern 1:
    Given:
        %2 = conv(%1)
        ...
        %3 = add(%2, constant) # where constant has shape (1,C,1)/(C,1) for 1d conv, (1,C,1,1)/(C,1,1) for 2d conv etc
        ...

    Result:
        %3 = conv(%1)
        ...

    Pattern 2:
    Given:
        %2 = conv(%1)
        %3 = transpose(%2)
        ...
        %4 = add(%3, constant) # where constant has a broadcastable shape
        ...

    Result:
        %2 = conv(%1)
        %4 = transpose(%2)
        ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_conv_scale[source]
Fold mul/div into conv/conv_transpose by updating the weight/bias of the convolution layers.
The scale const can be a single number (scalar) or a vector with a broadcastable shape. For example, if the output of the conv/deconv layer is (B, Cout, H, W), const of shape (Cout, 1, 1) and (1, Cout, 1, 1) are allowed.
    Given:
        %2 = conv(%1)
        ...
        %3 = mul(%2, constant) # where constant is the scale constant
        ...

    Result:
        %3 = conv(%1)
        ...
- class coremltools.converters.mil.mil.passes.defs.optimize_conv.fuse_pad_conv[source]
When we observe pad -> transpose -> conv, we move the pad to be next to conv. This allows us to meld pad + conv if possible.
    Given:
        %1 = pad(%0, ...)
        %2 = transpose(%1, ...)
        %3 = conv(%2, ...)
        ...

    Result:
        %1.a = transpose(%0, ...)
        %2.a = pad(%1.a, ...)
        %3 = conv(%2.a)
        ...
optimize_elementwise_binary
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.divide_to_multiply[source]
Convert divide into multiply if the divisor is const.
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.select_optimization[source]
For select(cond, a, b), there are 2 cases where we can replace it with a single simpler op:
If cond is a const scalar (or a const tensor but all elements are the same, which is equivalent to a scalar), then we replace select(cond, a, b) with simply a or b.
    Input graph:

        const(scalar cond) -|
                            |
        a ------------------|-> select -> output
                            |
        b ------------------|

    Output graph:

        if cond:
            a -> output
        else:
            b -> output
If cond is a more complicated const, and a is an inf const, then we replace a with select(cond, a, 0), then return a + b.
    Input graph:

        const(cond) -|
                     |
        const(±inf) -|-> select -> output
                     |
        b -----------|

    Output graph:

        select(cond, ±inf, 0) -|
                               |-> add -> output
        b ---------------------|
Note that select(cond, ±inf, 0) will further get eliminated by const_elimination, so in the end the op in the graph is simply add.
This replacement is based on floating-point arithmetic:
    inf + b = inf
    -inf + b = -inf
    0 + b = b
PS: if a is not an inf const but b is, then we would swap a and b.
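The second case can be checked elementwise in plain Python. `select` below is a list-based stand-in for the MIL op, and the values are arbitrary. The identities used hold for finite b; inf + (-inf) would give NaN, so this sketch assumes b contains no infinities of the opposite sign.

```python
import math
from typing import List


def select(cond: List[bool], a: List[float], b: List[float]) -> List[float]:
    return [ai if c else bi for c, ai, bi in zip(cond, a, b)]


cond = [True, False, True]
b = [1.5, -2.0, 3.0]
inf_const = [math.inf] * 3
zeros = [0.0] * 3

original = select(cond, inf_const, b)

# select(cond, inf, 0) depends only on consts, so it can be folded ahead of time.
folded_mask = select(cond, inf_const, zeros)
rewritten = [m + bi for m, bi in zip(folded_mask, b)]

print(original, rewritten)
```

Once the mask is folded into a const, the only op left at runtime is the add.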
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.fuse_elementwise_to_batchnorm[source]
Fold mul + add into a batchnorm if the const feeding into the mul/add is of shape (1,C,1,1) or (C,1,1) and the input to mul is of rank 4.
    Given:
                 [Const]   [Const]
                    |         |
                    V         V
        [...] --> [Mul] --> [Add] --> [...]

    That is,

        %2 = op1(%1)
        %3 = mul(%2, constant)
        %4 = add(%3, constant)
        %5 = op2(%4)
        ...

    Result:

        [...] --> [BatchNorm] --> [...]

    That is,
        %2 = op1(%1)
        %4 = batchnorm(%2)
        %5 = op2(%4)
        ...
- class coremltools.converters.mil.mil.passes.defs.optimize_elementwise_binary.rank0_expand_dims_swap[source]
Identify the pattern of a rank-0 binary elementwise operation followed by an expand_dims op. In the MIL backend, the output of the elementwise op becomes rank 1. Hence, an expand_dims op should be added after both of the rank-0 tensors, and the final expand_dims should be removed. If the output var of the binary elementwise op is consumed by more than one op, a squeeze op is inserted.
Input
    [...](rank-0) --> sub --> expand_dims (axes=[0]) --> [...]
                       ^   |
                       |   |--> op2
                       |   |
                       |   |--> op3
                       |
                 [scalar const]
Output
    [...](rank-0) --> expand_dims (axes=[0]) --> sub --> [...]
                                                  ^   |
                                                  |   |--> squeeze ---> op2
                                                  |   |
                                                  |   |--> op3
                                                  |
                                            expand_dims (axes=[0])
                                                  ^
                                                  |
                                                  |
                                            [scalar const]
optimize_linear
- class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_linear_bias[source]
Convert linear + add/sub to a single linear by updating the weight and bias of the linear layer.
    Example 1:
        Original:
            %4 = linear(x=%1, weight=%2, bias=%3) # %2 is a rank-2 const tensor (weight)
                                                  # %3 is a rank-1 const tensor (bias)
            ...
            %6 = add(x=%4, y=%5) # %5 is a const tensor with same shape as %3

        Result:
            %8 = linear(x=%1, weight=%2, bias=%7) # where %7 is a new const tensor with value
                                                  # %7 = %3 + %5

    Example 2:
        Original:
            %4 = linear(x=%1, weight=%2, bias=%3) # %2 is a rank-2 const tensor (weight)
                                                  # %3 is a rank-1 const tensor (bias)
            ...
            %6 = sub(x=%5, y=%4) # %5 is a const tensor with a shape broadcastable with %3,
                                 # i.e. if %3 has shape (Dout), %5 could be (1, Dout).

        Result:
            %9 = linear(x=%1, weight=%7, bias=%8) # where %7 is a new const tensor with value %7 = -%2
                                                  # %8 = %5 - %3
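Both examples can be checked numerically. In the sketch below, `linear` is a list-based stand-in for the MIL op (y = W x + b) and all values are illustrative. Example 2 shows why the weight must be negated when the linear output is the subtrahend: c - (W x + b) = (-W) x + (c - b).

```python
import math
from typing import List


def linear(x: List[float], weight: List[List[float]], bias: List[float]) -> List[float]:
    # y[i] = sum_j weight[i][j] * x[j] + bias[i]
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weight, bias)]


x = [1.0, 2.0]
weight = [[1.0, 0.5], [-2.0, 4.0]]
bias = [0.25, -1.0]
const = [10.0, 20.0]

# Example 1: add(linear(x), const) -> linear with bias + const
added = [y + c for y, c in zip(linear(x, weight, bias), const)]
fused_add = linear(x, weight, [b + c for b, c in zip(bias, const)])

# Example 2: sub(const, linear(x)) -> linear with -weight and const - bias
subtracted = [c - y for y, c in zip(linear(x, weight, bias), const)]
negated_weight = [[-w for w in row] for row in weight]
fused_sub = linear(x, negated_weight, [c - b for b, c in zip(bias, const)])

print(added, fused_add)
print(subtracted, fused_sub)
```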
- class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_matmul_weight_bias[source]
Convert matmul + add/sub to linear whenever possible.
    Given:
        %3 = matmul(x=%1, y=%2)  # %1 or %2 is const and rank 2 (weight)
        ...
        %5 = add(x=%3, y=%4)     # %4 is const. add(x=%4, y=%3) is equivalent
                                 # sub is similar.

    Result:
        # assuming %2 above is const and rank 2
        %5 = linear(x=%1, weight=%2, bias=%4)
- class coremltools.converters.mil.mil.passes.defs.optimize_linear.fuse_transpose_matmul[source]
Fuse transpose + matmul to matmul if possible, since matmul has args transpose_x and transpose_y to transpose the last 2 dims.
    Positive example:
        Input graph:
            transpose(x=x, perm=(1, 0)) -|
                                         |-> matmul(x=transposed_x, y=transposed_y)
            transpose(x=y, perm=(1, 0)) -|

        Output graph:
            matmul(x=x, y=y, transpose_x=True, transpose_y=True)

    Negative example:
        Input graph:
            transpose(x=x, perm=(1, 0, 2)) -|
                                            |-> matmul(x=transposed_x, y=transposed_y)
            transpose(x=y, perm=(1, 0, 2)) -|

        Output graph:
            Same as the input graph; nothing changes.
optimize_normalization
- class coremltools.converters.mil.mil.passes.defs.optimize_normalization.fuse_layernorm_or_instancenorm[source]
A graph optimization pass on PyMIL to detect and fuse several variants of layer_norm or instance_norm. Pattern 1 corresponds to either layer_norm or instance_norm. Patterns 2-4 are instance_norm. Pattern 5 is layer_norm. You can find these patterns in the methods for this class in the source code. To quickly view the source code, click the [source] button at the end of the class definition.
optimize_quantization
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.merge_affine_dequantize_with_consecutive_ops[source]
This graph pass does const folding on a chain of supported ops that starts with a constexpr_affine_dequantize op. More types of ops are supported when quantization is tensor-wise, and only a subset is supported for channel-wise. For example:
    Input graph:
        data -> constexpr_affine_dequantize -> transpose -> expand_dims -> out

    Output graph:
        new_data -> constexpr_affine_dequantize -> out
where new_data is computed by data -> transpose -> expand_dims.
Note that the graph pass only supports const folding of a single linked list pattern. For example, the following pattern will not be changed:
          |-> constexpr_affine_dequantize -> transpose -> out
    data -|
          |-> constexpr_affine_dequantize -> reshape -> out_2
since the quantized data is used by multiple constexpr ops.
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.int_op_canonicalization[source]
For general quantized operators, in Core ML, we represent them as dequantize -> the floating-point version of this operator -> quantize, because mathematically it is the floating-point tensor rather than its quantized integer representation that gets operated upon.
For some quantized operators that do not involve floating-point arithmetic, however, it is unnecessary to prepend dequantize and append quantize. Examples are:
  - reshape
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.nullify_redundant_quantization_zero_point[source]
In Core ML quantization, the performance is better when zero point = 0, so we try to make zero point = 0 if possible:
  - zero point = -128
      - this must be an int8 quantization
      - equivalent to uint8 quantization with 0 zero point
  - zero point = 128
      - this must be a uint8 quantization
      - equivalent to int8 quantization with 0 zero point
Since zero point = 0 is equivalent to zero point = None in Core ML semantics, we further canonicalize to zero point = None to:
  - make further graph passes easier
  - avoid serializing trivial 0
The zero point = 0 case can be canonicalized trivially:
    Input op:
        quantize/dequantize(zero_point=0)

    Output op:
        quantize/dequantize(zero_point=None)
To guarantee the conservation of output regardless of the zero-point shift in the zero point = ±128 cases, we would only transform:
  - const dequantize, where we fuse the zero-point shift into the const
        Input op:
            dequantize(input=const, zero_point=±128)

        Output op:
            dequantize(input=const∓128, zero_point=None)
  - quantize -> dequantize, where we nullify both simultaneously
        Input graph:
            input -> quantize(zero_point=±128) -> dequantize(zero_point=±128) -> output

        Output graph:
            input -> quantize(zero_point=None) -> dequantize(zero_point=None) -> output
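The equivalences above follow from the affine dequantization formula (q - zero_point) * scale. A minimal check over every representable value, using plain Python integers in place of int8/uint8 tensors (`dequantize` here is a scalar stand-in for the op, and the scale is arbitrary):

```python
def dequantize(q: int, scale: float, zero_point: int = 0) -> float:
    return (q - zero_point) * scale


scale = 0.05

# int8 data with zero_point=-128 == the same data shifted by +128, read as uint8 with zero_point=0
int8_case_ok = all(
    dequantize(q, scale, -128) == dequantize(q + 128, scale, 0) and 0 <= q + 128 <= 255
    for q in range(-128, 128)
)

# uint8 data with zero_point=128 == the same data shifted by -128, read as int8 with zero_point=0
uint8_case_ok = all(
    dequantize(q, scale, 128) == dequantize(q - 128, scale, 0) and -128 <= q - 128 <= 127
    for q in range(0, 256)
)

print(int8_case_ok, uint8_case_ok)
```

The range checks confirm that the shifted integers always fit in the other 8-bit type, which is why the shift can be folded into a const without loss.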
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.dequantize_quantize_pair_elimination[source]
When a dequantize is followed by an identical quantize (same scale, zero point, axis), they cancel out and can be eliminated:
    Input graph:
        input -> dequantize -> quantize -> output

    Output graph:
        input -> output
When the pattern has branches (dequantize has multiple children), we cannot eliminate the whole pair, but can still shorten the path. More specifically:
    Input graph:
        op1 -> dequantize -> quantize -> op2
                    |
                    |-> some_other_op

    Output graph:
        op1 -> dequantize -> some_other_op
         |
         |-> op2
PS: On the other hand, the reversed pattern, i.e., quantize -> dequantize, is not redundant, since that is the pattern which naturally occurs when a quantized op is converted. In current activation quantization conversion, a quantized op becomes
    dequantize -> regular op -> quantize
so if we have a sequence of quantized ops, we will get
    dequantize -> regular op1 -> quantize -> dequantize -> regular op2 -> quantize
The quantize -> dequantize pair in the middle is not redundant, even if they have identical scales and zero points and axes, since removing them will lead to loss of information about the quantization parameters of the output var of op1.
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.distributive_quantized_binary_op_scale_normalization[source]
In the backend, for better performance, a quantized op can have 1 input scale fused within the quantized op kernel. For binary ops, there are 2 inputs, but only 1 can get fused. For example, for quantized add:
    MIL graph (consists of MIL ops):

        dequantize(x, s_x, zp_x) -|
        x_fp = (x - zp_x) * s_x   |
                                  |->  add(x_fp, y_fp)   -> quantize(z_fp, s_z, zp_z)
        dequantize(y, s_y, zp_y) -|   z_fp = x_fp + y_fp      z = z_fp / s_z + zp_z
        y_fp = (y - zp_y) * s_y

    Backend graph (consists of backend instructions, usually including + - * / and fused *+):

        x_shift = x - zp_x -------------------------|
                                                    |-> z_fp = s_x * x_shift + y_fp -> z = z_fp / s_z + zp_z
        y_shift = y - zp_y -> y_fp = s_y * y_shift -|
Where x and y are the inputs, z is the output, s and zp are the corresponding scale and zero point.
The reason why fusing one scale leads to better performance is that, instead of 2 instructions x_fp = s_x * x_shift and z_fp = x_fp + y_fp, a single z_fp = x_shift * s_x + y_fp instruction achieves the same result.
In this pass, we normalize s_y to 1, so the y_fp = s_y * y_shift instruction can get skipped as well, leading to even better performance. This pass only applies to distributive binary ops such as add and sub.
Appendix: Mathematical and Computer-Scientific Details
Mathematically, for a binary operator .op.
    z_fp = (x - zp_x) * s_x .op. (y - zp_y) * s_y
         = s_y * [(x - zp_x) * s_x/s_y .op. (y - zp_y) * 1]
The corresponding pseudo code is
    # before
    z_fp = (x - zp_x) * s_x .op. (y - zp_y) * s_y
    z = z_fp / s_z + zp_z

    # after
    z_fp_modified = (x - zp_x) * s_x/s_y .op. (y - zp_y) * 1.0
    z = z_fp_modified / (s_z/s_y) + zp_z
Concretely, as a MIL graph pass
    Input graph:
        dequantize(scale=s_x) -|
                               |-> op -> quantize(scale=s_z)
        dequantize(scale=s_y) -|

    Output graph:
        dequantize(scale=s_x/s_y) -|
                                   |-> op -> quantize(scale=s_z/s_y)
        dequantize(scale=1.0)     -|
PS: we only support scalar s_y for now. If s_y is not scalar but s_x is, we would swap x and y. Support for the both-vector case is to be explored, due to the broadcasting complication.
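The rescaling can be verified numerically for add. The sketch below follows the quantize convention z = z_fp / s_z + zp_z and leaves out the final rounding and clamping to the integer range, so it compares the real-valued results only. Function names and values are illustrative.

```python
import math


def quantized_add_before(x, y, s_x, zp_x, s_y, zp_y, s_z, zp_z):
    z_fp = (x - zp_x) * s_x + (y - zp_y) * s_y
    return z_fp / s_z + zp_z


def quantized_add_after(x, y, s_x, zp_x, s_y, zp_y, s_z, zp_z):
    # s_y is normalized to 1.0; s_x and s_z are divided by s_y to compensate.
    z_fp_modified = (x - zp_x) * (s_x / s_y) + (y - zp_y) * 1.0
    return z_fp_modified / (s_z / s_y) + zp_z


params = dict(s_x=0.02, zp_x=3, s_y=0.05, zp_y=-7, s_z=0.1, zp_z=12)
pairs = [(0, 0), (100, -50), (-128, 127), (17, 42)]

for x, y in pairs:
    before = quantized_add_before(x, y, **params)
    after = quantized_add_after(x, y, **params)
    assert math.isclose(before, after, rel_tol=1e-9, abs_tol=1e-9)
print("scale normalization preserves the pre-rounding output")
```

The same factoring works for sub, since subtraction also distributes over the common factor s_y.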
- class coremltools.converters.mil.mil.passes.defs.optimize_quantization.dequantize_to_constexpr[source]
A dequantize op with constant input is equivalent to constexpr_affine_dequantize. This is one of the canonicalization passes that transforms all such dequantize ops to respective constexpr_affine_dequantize ops.
    Input graph:
        dequantize(input=const) -> downstream op

    Output graph:
        constexpr_affine_dequantize -> downstream op
This pass is performed because constant tensors propagated through a dequantize op would be serialized in a bloated/decompressed fashion, whereas with constexpr_affine_dequantize, constant weights/tensors remain compressed at serialization.
optimize_repeat_ops
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.cast_optimization[source]
This optimization pass performs the following:
  - Removes redundant cast ops; that is, cast ops where the source and destination tensors have the same dtype.
  - Fuses two consecutive cast ops if applicable, repeatedly.
This is a non-algebraic translation which assumes that the upcasting doesn’t change the user’s intent.
Example for redundant cast op removal:
    Input graph:
        input(fp16) -> cast(dtype="fp16") -> relu -> out

    Output graph:
        input -> relu -> out
The input and output tensors for the cast op are both of type fp16. Hence, it can be removed.
Example for two cast ops fusion:
    Input graph:
        input(int8) -> cast(dtype="fp16") -> cast(dtype="fp32") -> out

    Output graph:
        input(int8) -> cast(dtype="fp32") -> out
The data range and resolution of the above graph are limited by the int8 input, so the fusion is allowed.
Negative example for two cast ops fusion:
    Input graph:
        input(fp32) -> cast(dtype="bool") -> cast(dtype="fp16") -> out

    Output graph:
        Same as input graph.
The above two cast ops cannot be merged, since after the first cast, the resolution of the numerical output is downcast to binary (0, 1). If we fuse them, the output would be in the range and resolution of fp16 instead.
Another negative example for two cast ops fusion:
    Input graph:
        input(int32) -> cast(dtype="int8") -> cast(dtype="uint8") -> out

    Output graph:
        Same as input graph.
The above two cast ops cannot be merged, since in the original graph, by going through two casts, the output numerical range is capped to [0, 127].
However, if the two cast ops are reduced to 1 cast(dtype="uint8"), the output would be in the range of [0, 255]. The fusion would cause numerical issues for the numbers in [128, 255], which is prohibited.
In general, two cast ops can be merged if the output data range and resolution are not affected.
For more examples, please see the unit tests that start with the prefix TestCastOptimization in test_passes.py.
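The fusable int8 case and the non-fusable bool case can be reproduced with the standard library, which can round-trip values through IEEE half and single precision via `struct`. The `cast_*` helpers below are hypothetical stand-ins for the MIL cast op.

```python
import struct


def cast_fp16(value: float) -> float:
    # Round-trip through IEEE 754 half precision.
    return struct.unpack("<e", struct.pack("<e", value))[0]


def cast_fp32(value: float) -> float:
    # Round-trip through IEEE 754 single precision.
    return struct.unpack("<f", struct.pack("<f", value))[0]


def cast_bool(value: float) -> bool:
    return bool(value)


# int8 -> fp16 -> fp32 gives the same result as int8 -> fp32 for every int8 value,
# because fp16 represents all integers in [-128, 127] exactly.
int8_fusion_is_safe = all(
    cast_fp32(cast_fp16(float(q))) == cast_fp32(float(q)) for q in range(-128, 128)
)

# fp32 -> bool -> fp16 collapses the value to 0 or 1; a single cast to fp16 would not.
two_casts = cast_fp16(float(cast_bool(2.5)))
one_cast = cast_fp16(2.5)

print(int8_fusion_is_safe, two_casts, one_cast)
```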
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_paddings[source]
Identify two consecutive pad layers which could be merged into a single pad layer.
This is possible only if one of the following conditions is satisfied:
  - The paddings are “constant” and have the same constant_val.
  - The paddings act along different axes.
    Input graph:
        input(1, 2, 6, 8) ------> pad([1, 1], mode='reflect') -----> pad([1, 1, 0, 0], mode='reflect') ---> out(1, 2, 8, 10)

    Output graph:
        input(1, 2, 6, 8) ------> pad([1, 1, 1, 1], mode='reflect') ---> out(1, 2, 8, 10)
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_relus[source]
Identify consecutive relu layers which could be merged into a single relu layer.
    Input graph:
        input ------> relu -----> 1 or more relu layers ---> out

    Output graph:
        input ------> relu ---> out
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_reshapes[source]
Identify consecutive reshape ops which could be merged into a single reshape.
    Input graph:
        input -> reshape -> 1 or more reshapes -> output

    Output graph:
        input -> reshape -> output
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.merge_consecutive_transposes[source]
Identify consecutive ‘transpose’ layers which could be merged into a single ‘transpose’ layer.
    Input graph:
        input ------> transpose -----> 1 or more transpose layers ---> out

    Output graph:
        input ------> transpose ---> out
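Merging works because permutations compose. A small sketch, assuming the usual convention that output axis i of a transpose takes input axis perm[i]; `compose_perms` and `transpose_shape` are hypothetical helpers that operate on shapes only.

```python
from typing import List, Sequence, Tuple


def transpose_shape(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    # Output axis i takes its size from input axis perm[i].
    return tuple(shape[p] for p in perm)


def compose_perms(first: Sequence[int], second: Sequence[int]) -> List[int]:
    # Single perm equivalent to applying `first` and then `second`.
    return [first[p] for p in second]


shape = (2, 3, 5, 7)  # distinct sizes, so each shape identifies the axis order
first, second = [0, 2, 3, 1], [3, 1, 0, 2]

step_by_step = transpose_shape(transpose_shape(shape, first), second)
merged_perm = compose_perms(first, second)
merged = transpose_shape(shape, merged_perm)

# A pair of inverse perms composes to the identity, so both transposes cancel.
cancelled = compose_perms([0, 3, 1, 2], [0, 2, 3, 1])

print(merged_perm, step_by_step, merged, cancelled)
```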
- class coremltools.converters.mil.mil.passes.defs.optimize_repeat_ops.reduce_transposes[source]
Reduce transposes where applicable. For example:
# Example 1
Input graph: input -----> transpose(axis=[1,0]) -----> transpose(axis=[1,0]) ---> out
Output graph: input -----> identity -----> out
# Example 2
Input graph: input---->transpose(axis=[0,3,1,2])---->relu---->transpose(axis=[0,2,3,1])--->out
Output graph: input----->relu----->out
# Example 3
Input graph:
input(shape=10,2,3,5)--->transpose(axis=[0,2,3,1])----->relu---->pool----->out1
                                                     |
                                                     |
                                                     --->relu----->log---->transpose(axis=[0,3,1,2])---->out2
Output graph:
input(shape=10,2,3,5)----->relu---->transpose(axis=[0,2,3,1])---->pool----->out1
                       |
                       |
                       --->relu----->log---->out2
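Examples 1 and 2 can be verified numerically. The sketch below uses NumPy in place of MIL ops (an assumption for illustration): a transpose followed by its inverse is the identity, and an elementwise op such as relu commutes with a transpose, so the surrounding pair cancels.

```python
import numpy as np


def relu(t):
    return np.maximum(t, 0)


rng = np.random.default_rng(0)

# Example 1: two [1, 0] transposes cancel out.
a = rng.standard_normal((4, 6))
ex1 = np.transpose(np.transpose(a, [1, 0]), [1, 0])
print(np.array_equal(ex1, a))  # True

# Example 2: [0, 3, 1, 2] and [0, 2, 3, 1] are inverses, and relu is
# elementwise, so the whole chain reduces to a single relu.
b = rng.standard_normal((10, 2, 3, 5))
ex2 = np.transpose(relu(np.transpose(b, [0, 3, 1, 2])), [0, 2, 3, 1])
print(np.array_equal(ex2, relu(b)))  # True
```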
Please see TransposeOptimizationPass for more details.
Notes
This pass is divided into 3 phases:
1st phase: Information gathering.
Plug in Identity ops for all output nodes. This allows us to treat all ops uniformly during traversal.
The block is traversed in topological order, starting from the ops connected to the inputs.
During the traversal, a value is associated with every var in the block. This value is of type _HypotheticalValue or _LazyTransposeHypotheticalValue. The main purpose of _HypotheticalValue is to indicate that a value is not a _LazyTransposeHypotheticalValue.
_LazyTransposeHypotheticalValue represents one or more transpose ops with the same perm value; this information is stored in the class. It also wraps the last _HypotheticalValue that was generated before the _LazyTransposeHypotheticalValue originated.
Each op decides which type of hypothetical value to associate with its output vars, based on its op type, attributes, and the types of the hypothetical values of its input vars.
Ops are classified into 4 categories: unary like, axis update, transpose, and materialize (for all the rest).
- Transpose ops are the ops from which a _LazyTransposeHypotheticalValue originates.
  - If the input to it is a _HypotheticalValue, its output will be a _LazyTransposeHypotheticalValue, indicating that this transpose op is available to get cancelled downstream.
  - If the input to it is a _LazyTransposeHypotheticalValue, then it is checked whether this op cancels it or not.
    - If the op cancels it, a _HypotheticalValue is generated at the output and the information about this transpose cancellation is recorded in the dictionary transpose_op_to_cancel_ops.
    - If the op does not cancel, the current transpose op is categorized as a materialize op. Therefore, the information in the dictionary transpose_op_to_materialize_ops is updated accordingly. The output of the op is now mapped to a _HypotheticalValue.
- Unary like ops: These simply transfer their input hypothetical value type to the output.
- Axis update ops: If a transpose can pass through them, they are treated like a unary op and the dictionary transpose_op_to_axis_update_ops is updated. If the op cannot be updated in any manner to allow a transpose to pass through, it is categorized as a materialize op and handled accordingly.
- Materialize ops: All _LazyTransposeHypotheticalValue input vars, if present, materialize here. The output of this op is always of type _HypotheticalValue. If the input is a _LazyTransposeHypotheticalValue, the dictionary transpose_op_to_materialize_ops is updated.
To treat an op like a unary op, add its type to _UNARY_LIKE_OP_TYPES. In future changes we want to make this process automatic by detecting an op as unary like by its “traits”.
To treat an op like an axis update op, add a class specific to the op implementing the class TransformAxisUpdateOps. For examples, see classes _TransformConcat, _TransformPad, and so on. The dictionary AXIS_UPDATE_OPS is automatically filled in by the decorator _TransposeOptimization.register_axis_update_op.
2nd phase: Determining which transpose ops to remove from the graph.
All transpose ops that have a corresponding complement op in the dict transpose_op_to_cancel_ops are candidates. However, you need to ensure the following:
- If a transpose op is removed, then all of its cancel ops in transpose_op_to_cancel_ops must also be removed, to ensure correctness of the graph. The same is true in the reverse direction as well; that is, for every cancel op that is removed, all its parent transpose ops upstream must also be removed.
- transpose ops should be removed only if the number of cancel ops is greater than the number of transpose ops that would get freshly introduced to the block as a result of materialization ops. Currently in the algorithm, each materialization op/output var (dicts transpose_op_to_materialize_ops/old_output_vars) results in one more transpose op, although this can be further optimized in the future.
To resolve this, we recognize that nodes consisting of sets (a) and (b) form a bipartite graph, where (a) == starting transpose ops (originators of _LazyTransposeHypotheticalValue) and (b) == set of transpose cancel ops and materialize ops.
- In this bipartite graph, we find all the connected components. For each connected component, either the entire set of transpose ops in it is removed/materialized, or none of them are touched.
- Thus for each set, a determination is made based on counting the number of cancel ops and materialize ops.
- Based on this determination, the final set of transpose ops to be removed is updated.
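The per-component decision described above can be sketched as follows. This is an illustrative simplification, not the coremltools implementation: ops are represented by plain strings, the two input dicts mirror transpose_op_to_cancel_ops and transpose_op_to_materialize_ops, block output vars are ignored, and the function name select_transposes_to_remove is hypothetical.

```python
from collections import defaultdict


def select_transposes_to_remove(cancel_ops, materialize_ops):
    """Pick starting transposes whose component has more cancels than materializations."""
    # Build the bipartite graph: starting transposes on one side,
    # cancel/materialize ops on the other.
    graph = defaultdict(set)
    for kind, mapping in (("cancel", cancel_ops), ("materialize", materialize_ops)):
        for transpose, ops in mapping.items():
            graph[("transpose", transpose)]  # register the node even if `ops` is empty
            for op in ops:
                graph[("transpose", transpose)].add((kind, op))
                graph[(kind, op)].add(("transpose", transpose))

    visited, removed = set(), set()
    for start in list(graph):
        if start[0] != "transpose" or start in visited:
            continue
        # Depth-first search to collect one connected component.
        component, stack = set(), [start]
        while stack:
            node = stack.pop()
            if node not in component:
                component.add(node)
                stack.extend(graph[node] - component)
        visited |= component
        n_cancel = sum(kind == "cancel" for kind, _ in component)
        n_materialize = sum(kind == "materialize" for kind, _ in component)
        # All-or-nothing: remove every transpose in the component, or none.
        if n_cancel > n_materialize:
            removed |= {name for kind, name in component if kind == "transpose"}
    return removed


cancel = {"t1": ["c1", "c2"], "t2": ["c2"], "t3": ["c3"]}
materialize = {"t2": ["m1"], "t3": ["m2", "m3"]}
print(sorted(select_transposes_to_remove(cancel, materialize)))  # ['t1', 't2']
```

Here t1 and t2 share the cancel op c2, so they form one component with two cancel ops and one materialize op and are removed together; t3 would introduce more transposes than it cancels, so it is left untouched.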
3rd phase: Transforming the graph.
- transpose starting ops and the cancel ops are removed.
- Axis update ops, affected by these transpose ops, are updated.
- Transposes are materialized; that is, added just before the materialize ops, which are linked to the starting transpose ops. The starting transpose op can be materialized (inserted) multiple times, before each of the materialize ops downstream.
- Block outputs are handled in a similar fashion as the materialize ops.
- Type inference on all ops is invoked after all the transformations.
- All Identity ops that are plugged into the graph to treat outputs as materialized are removed.
Debugging
If the debug flag is set to True, the block before and after the transformation is plotted, with transpose nodes highlighted.
optimize_state
- class coremltools.converters.mil.mil.passes.defs.optimize_state.canonicalize_inplace_pattern[source]
As a functional-graph framework, Core ML represents in-place operation as
read_state -> functional operation -> write_state
Due to the non-uniqueness of topological order, in the list representation of ops, write_state can be anywhere after the functional op. We prefer the canonical order, i.e. having write_state immediately follow the functional op.
In practice
1. In PyMIL, we do not use the write_state op. Instead, we use coreml_update_state, which is the composition of write_state -> read_state
2. The read_state op does not matter in the pattern match and transform
So we will match
functional operation -> coreml_update_state
then reorder the coreml_update_state. For example
Given:
    mul = mul(state, x)
    add = add(mul, y)
    update = coreml_update_state(state, mul)
Return:
    mul = mul(state, x)
    update = coreml_update_state(state, mul)
    add = add(mul, y)
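The reordering can be sketched on a toy list representation of ops. The Op tuple and canonicalize_inplace function below are hypothetical and exist only for illustration; the real pass operates on MIL blocks. The sketch assumes the second input of coreml_update_state is the value being written.

```python
from typing import List, NamedTuple, Tuple


class Op(NamedTuple):
    output: str
    op_type: str
    inputs: Tuple[str, ...]


def canonicalize_inplace(ops: List[Op]) -> List[Op]:
    """Move each coreml_update_state right after the op that produces its value."""
    updates = [op for op in ops if op.op_type == "coreml_update_state"]
    result = [op for op in ops if op.op_type != "coreml_update_state"]
    for update in updates:
        producer = update.inputs[1]  # assumed layout: (state, value)
        # Index of the producing op, or -1 if the value is a block input.
        idx = next((i for i, op in enumerate(result) if op.output == producer), -1)
        result.insert(idx + 1, update)
    return result


given = [
    Op("mul", "mul", ("state", "x")),
    Op("add", "add", ("mul", "y")),
    Op("update", "coreml_update_state", ("state", "mul")),
]
print([op.output for op in canonicalize_inplace(given)])  # ['mul', 'update', 'add']
```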
- class coremltools.converters.mil.mil.passes.defs.optimize_state.prefer_state_in_downstream[source]
As a functional-graph framework, Core ML represents in-place operation as
read_state -> functional operation -> write_state
When the output of the in-place operation is used downstream, there are 2 possible patterns. One reuses state memory:
read_state -> functional operation -> write_state -> read_state -> ...
The other wastes memory by keeping the functional output:
                                    |-> write_state
read_state -> functional operation -|
                                    |-> ...
We prefer the pattern that reuses state.
In practice
1. In PyMIL, we do not use the write_state op. Instead, we use coreml_update_state, which is the composition of write_state -> read_state
2. With the canonical in-place pattern (guaranteed by graph pass canonicalize_inplace_pattern), simply replacing the usage of the functional output with the coreml_update_state output is enough
For example
Given:
    mul = mul(state, x)
    update = coreml_update_state(state, mul)
    add = add(mul, y)
Return:
    mul = mul(state, x)
    update = coreml_update_state(state, mul)
    add = add(update, y)
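The replacement step can be sketched on a toy list of ops. The Op tuple and prefer_state_downstream function are hypothetical names for illustration only, and the sketch assumes the ops are already in the canonical in-place order, with the written value as the second input of coreml_update_state.

```python
from typing import List, NamedTuple, Tuple


class Op(NamedTuple):
    output: str
    op_type: str
    inputs: Tuple[str, ...]


def prefer_state_downstream(ops: List[Op]) -> List[Op]:
    """Redirect later uses of a functional output to the matching update output."""
    rename = {}  # functional output name -> coreml_update_state output name
    result = []
    for op in ops:
        new_inputs = tuple(rename.get(name, name) for name in op.inputs)
        result.append(Op(op.output, op.op_type, new_inputs))
        if op.op_type == "coreml_update_state":
            # Only ops that come after the update are redirected.
            rename[op.inputs[1]] = op.output
    return result


given = [
    Op("mul", "mul", ("state", "x")),
    Op("update", "coreml_update_state", ("state", "mul")),
    Op("add", "add", ("mul", "y")),
]
rewritten = prefer_state_downstream(given)
print(rewritten[2].inputs)  # ('update', 'y')
```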
optimize_tensor_operation
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.concat_to_pixel_shuffle[source]
Identify nested, interleaved concat ops which can be replaced by a single concat and a pixel shuffle layer.
This pattern occurs with the faster up-convolution from the FCRN model (Laina et al., 2016).
# Before the concat_to_pixel_shuffle pass.
input(N, C, H, W) -------------------
                                    |
                                    v
input(N, C, H, W) -----> concat(axis=2, interleave=True) -----> concat(axis=3, interleave=True) ----> output
                                                                        ^
                                                                        |
input(N, C, H, W) -----> concat(axis=2, interleave=True) ---------------
                                    ^
                                    |
input(N, C, H, W) -------------------

# After the concat_to_pixel_shuffle pass.
input(N, C, H, W) ---------------
                                |
                                v
input(N, C, H, W) -----> concat(axis=1, interleave=True) -----> pixel_shuffle(upscale_factor=2) ----> output
                                ^
                                |
input(N, C, H, W) --------------|
                                |
                                |
input(N, C, H, W) ---------------
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.detect_concat_interleave[source]
Detect the pattern concat-->reshape--->transpose--->reshape, where concat is along the channel axis (axis=-3), and map this pattern to the concat with interleave op.
This pattern occurs, for example, in the shufflenet model in torchvision.
Given:
    %3 = concat(%1.a, %1.b, ..., axis=-3, interleave=False) # shape = (B, n*C, H, W)
    %4 = reshape(%3) # shape = (B, n, C, H, W)
    %5 = transpose(%4, perm=[0, 2, 1, 3, 4]) # shape = (B, C, n, H, W)
    %6 = reshape(%5) # shape = (B, C*n, H, W)
Result:
    %6 = concat(%1.a, %1.b, ..., axis=-3, interleave=True)
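The equivalence can be checked with NumPy. NumPy has no interleaved concat, so this sketch builds the expected interleaved result by strided assignment; that construction is an assumption about what interleave=True means (channel c of input k lands at output channel c*n + k).

```python
import numpy as np

B, C, H, W = 2, 3, 4, 5
rng = np.random.default_rng(0)
inputs = [rng.standard_normal((B, C, H, W)) for _ in range(2)]
n = len(inputs)

# The detected pattern: concat -> reshape -> transpose -> reshape.
x = np.concatenate(inputs, axis=-3)    # (B, n*C, H, W)
x = x.reshape(B, n, C, H, W)
x = x.transpose(0, 2, 1, 3, 4)         # (B, C, n, H, W)
pattern_out = x.reshape(B, C * n, H, W)

# Expected interleaved concat: channels alternate between the inputs.
interleaved = np.empty((B, C * n, H, W))
for k, tensor in enumerate(inputs):
    interleaved[:, k::n] = tensor

print(np.array_equal(pattern_out, interleaved))  # True
```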
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.expand_high_rank_reshape_and_transpose[source]
Detect the pattern reshape_1-->transpose-->reshape_2, where reshape_1 has an output tensor with rank >= 6, and reshape_2 produces a tensor with rank <= 5.
In general, we can expand this pattern into a sequence of rank 4 reshape and transpose ops, which are supported by the Core ML runtime.
Given:
    %1 = reshape(%x, shape=(d1, d2, d3, d4, ..., dn))
    %2 = transpose(%1, perm=(p1, p2, ..., pn))
    %3 = reshape(%2, shape=(o1, o2, o3, o4, o5))
Result:
    %t1 = reshape(%x, shape=(y11, y12, y13, y14))
    %h1 = transpose(%t1, perm=[0, 2, 1, 3])
    %t2 = reshape(%h1, shape=(y21, y22, y23, y24))
    %h2 = transpose(%t2, perm=[0, 2, 1, 3])
    ....
    %hn = transpose(%tn, perm=[0, 2, 1, 3])
    %3 = reshape(%hn, shape=(o1, o2, o3, o4, o5))
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.fuse_onehot_matmul_to_gather[source]
Detect if onehot (axis=-1, on_value=1, off_value=0) is followed by a matmul op (no bias). If so, they can be replaced by a gather op.
Input:
    %2 = one_hot(%1, on_value=1, off_value=0, axis=-1)
    %3 = const() # rank 2
    %4 = matmul(%2, %3)
Output:
    %4 = gather(%3, %1, axis=0)
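The identity behind this fusion is that multiplying a one-hot matrix by a weight matrix selects rows of the weight at the original (pre-one_hot) indices. A minimal NumPy check, using np.take as a stand-in for the MIL gather op:

```python
import numpy as np

indices = np.array([2, 0, 3])
weight = np.arange(4 * 5, dtype=np.float32).reshape(4, 5)  # rank-2 const

# one_hot(axis=-1, on_value=1, off_value=0) followed by matmul.
one_hot = np.eye(4, dtype=np.float32)[indices]             # shape (3, 4)
via_matmul = one_hot @ weight                              # shape (3, 5)

# Equivalent gather along axis 0 using the original indices.
via_gather = np.take(weight, indices, axis=0)

print(np.array_equal(via_matmul, via_gather))  # True
```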
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.replace_stack_reshape[source]
A stack followed by a reshape layer can be replaced by a concat if the reshape simply removes the new axis and doubles the size of one of the axes next to it.
If the new axis is reshaped to the “right” (that is, the axis just after it is doubled), then we can use a concat. If it is reshaped to the “left” (the axis just before it is doubled), then the concat needs to set the interleave flag.
Examples:
Given:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %3 = stack((%1, %2), axis=2) # shape = (1, 5, 2, 3, 4)
    %4 = reshape(%3, shape=[1, 10, 3, 4])
Result:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %4 = concat((%1, %2), axis=1, interleave=True) # shape = (1, 10, 3, 4)
Given:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %3 = stack((%1, %2), axis=1) # shape = (1, 2, 5, 3, 4)
    %4 = reshape(%3, shape=[1, 10, 3, 4])
Result:
    %1 = tensor(1, 5, 3, 4)
    %2 = tensor(1, 5, 3, 4)
    %4 = concat((%1, %2), axis=1) # shape = (1, 10, 3, 4)
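Both examples can be reproduced with NumPy. Since NumPy has no interleaved concat, the interleaved result is built by strided assignment, which is an assumption about the semantics of interleave=True made for this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((1, 5, 3, 4))
b = rng.standard_normal((1, 5, 3, 4))

# New axis reshaped to the "right": stack(axis=1) + reshape == plain concat.
right = np.stack((a, b), axis=1).reshape(1, 10, 3, 4)
plain_concat = np.concatenate((a, b), axis=1)
print(np.array_equal(right, plain_concat))  # True

# New axis reshaped to the "left": stack(axis=2) + reshape == interleaved concat.
left = np.stack((a, b), axis=2).reshape(1, 10, 3, 4)
interleaved = np.empty((1, 10, 3, 4))
interleaved[:, 0::2] = a
interleaved[:, 1::2] = b
print(np.array_equal(left, interleaved))    # True
```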
- class coremltools.converters.mil.mil.passes.defs.optimize_tensor_operation.use_reflection_padding[source]
Identify a reflection padding layer composed out of slices and concats.
Input graph:
        ---------------------------------------------------------------------------------|
        |                                                                                v
input(1, 2, 6, 8) ------> slice_by_index(begin=[0, 0, 0, 1], end=[0, 0, 0, 2]) -----> concat(axis=3) ---> out(1, 2, 6, 10)
        |                                                                                ^
        ----------------> slice_by_index(begin=[0, 0, 0, -2], end=[0, 0, 0, -1]) --------|
Output graph:
input(1, 2, 6, 8) -----> pad(mode=reflect, size=[0, 0, 1, 1]) -----> out(1, 2, 6, 10)
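The slice-and-concat pattern is the same computation as a width-1 reflection pad on the last axis. A minimal NumPy check, with np.pad standing in for the MIL pad op and plain slicing standing in for slice_by_index:

```python
import numpy as np

x = np.arange(1 * 2 * 6 * 8, dtype=np.float32).reshape(1, 2, 6, 8)

# Slices taken one element in from each border of the last axis.
left = x[:, :, :, 1:2]
right = x[:, :, :, -2:-1]
manual = np.concatenate([left, x, right], axis=3)

# Single reflection pad of size 1 on both sides of the last axis.
padded = np.pad(x, ((0, 0), (0, 0), (0, 0), (1, 1)), mode="reflect")

print(manual.shape)                    # (1, 2, 6, 10)
print(np.array_equal(manual, padded))  # True
```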
preprocess
- class coremltools.converters.mil.mil.passes.defs.preprocess.image_input_preprocess[source]
Plug in a transpose op to convert image inputs in NHWC format to NCHW format.
Follow these steps:
- Check whether there are any inputs that the users specify as ImageType.
- Check the channel’s dimension for all inputs that are ImageType.
  - channel_first == True: We do not modify this input, since channel_first is the intended behaviour for feeding images for optimal performance.
  - channel_first == False: We convert the input into a “channel_first” input, and plug in a transpose for the input to maintain the remaining graph’s dimensionality.
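The layout change for the channel_first == False case can be pictured with NumPy. This is only an illustration of the two permutations involved, not the pass itself: the model input becomes NCHW, and a transpose back to NHWC keeps the rest of the graph seeing the layout it expects.

```python
import numpy as np

# An image batch as the original graph expects it: (N, H, W, C).
nhwc = np.arange(1 * 4 * 6 * 3).reshape(1, 4, 6, 3)

# The new channel-first model input: (N, C, H, W).
nchw = nhwc.transpose(0, 3, 1, 2)

# The plugged-in transpose restores NHWC for the remaining graph.
restored = nchw.transpose(0, 2, 3, 1)

print(nchw.shape)                      # (1, 3, 4, 6)
print(np.array_equal(restored, nhwc))  # True
```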
- class coremltools.converters.mil.mil.passes.defs.preprocess.sanitize_input_output_names[source]
Sanitize the names of model input and output vars to make sure that they are of the format as described in the NameSanitizer class; that is, of the format [a-zA-Z_][a-zA-Z0-9_]*.
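As a rough illustration of what conforming to that pattern involves, here is a regex-based sketch. The sanitize_name function is hypothetical and its replacement rules are assumptions; the NameSanitizer class may choose different replacement characters or prefixes.

```python
import re

VALID_NAME = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")


def sanitize_name(name: str) -> str:
    """Hypothetical sanitizer: force `name` to match [a-zA-Z_][a-zA-Z0-9_]*."""
    # Replace every disallowed character with an underscore.
    cleaned = re.sub(r"[^a-zA-Z0-9_]", "_", name)
    # The first character must be a letter or underscore.
    if not re.match(r"[a-zA-Z_]", cleaned):
        cleaned = "_" + cleaned
    return cleaned


for raw in ["input:0", "1x", "logits/softmax", "ok_name"]:
    cleaned = sanitize_name(raw)
    print(raw, "->", cleaned, bool(VALID_NAME.fullmatch(cleaned)))
```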
- class coremltools.converters.mil.mil.passes.defs.preprocess.update_output_dtypes[source]
Update the dtypes of output vars of each function block to match the dtypes provided in function.output_types. The output types for the main function are populated by the outputs argument provided by the user in the coremltools.convert() API. This graph pass assumes that the list of outputs in function.output_types (if not None) is in the same order as the output vars.
quantization
- class coremltools.converters.mil.mil.passes.defs.quantization.add_fp16_cast(op_selector=None)[source]
For each input of dtype float32, inject a cast op to change it to float16 dtype.
For each output of dtype float16, inject a cast op to change it back to float32.
This pass is the registered interface for FP16ComputePrecision, which makes it consistent with other passes’ interfaces.
Support options:
- skip_ops_by_type: Skip op types specified by comma-separated string; for example, "mul,const".
symbol_transform
- class coremltools.converters.mil.mil.passes.defs.symbol_transform.materialize_symbolic_shape_program[source]
If we realize that only a few fixed shapes are used in a symbolic-shape model, we may prefer materialization into a fixed-shape (multifunction) model, which has the potential to be further optimized.
Supported options:
- function_name_to_materialization_map: Dict[str, Dict[str, Tuple[int]]]
  A dictionary specifying the names of the new functions to be created and, for each new function, the new fixed shapes for its inputs. If a new function has the same name as an old function, then the old function will be overridden.
- source_function_name: str
  The name of the source symbolic-shape function to be materialized; default = main.
Example:
Suppose we have a symbolic shape model with 2 symbols is0 and is1
    symbolic_shape_mlmodel: ct.models.MLModel
    symbolic_shape_prog = symbolic_shape_mlmodel._mil_program
We may invoke this graph pass to materialize some fixed shapes (e.g. is0 = 2, is1 = 5 and is0 = 4, is1 = 7), then run every other optimization pass
    pass_pipeline: PassPipeline = ct.PassPipeline.DEFAULT
    pass_pipeline.insert_pass(0, "common::materialize_symbolic_shape_program")
    pass_pipeline.set_options(
        "common::materialize_symbolic_shape_program",
        {
            "function_name_to_materialization_map": {
                # As an example, let us assume the input is x (is0, is1, 1024)
                "materialization_2_5": {"x": (2, 5, 1024)},
                "materialization_4_7": {"x": (4, 7, 1024)},
            }
        },
    )
    PassPipelineManager.apply_pipeline(symbolic_shape_prog, pass_pipeline)
We will arrive at
    main[CoreML8](%x: (is0, is1, 1024, fp16)(Tensor)) {
        block0() {
            ...
        } -> (%y)
    }
    materialization_2_5[CoreML8](%x: (2, 5, 1024, fp16)(Tensor)) {
        block5() {
            ...
        } -> (%y)
    }
    materialization_4_7[CoreML8](%x: (4, 7, 1024, fp16)(Tensor)) {
        block6() {
            ...
        } -> (%y)
    }