# Source code for coremltools.converters.mil.mil.passes.defs.cleanup.remove_symbolic_reshape

#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause


from coremltools import _logger as logger
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Program
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.types.symbolic import any_variadic, is_symbolic, num_symbolic


@register_pass(namespace="common")
class remove_symbolic_reshape(AbstractGraphPass):
    """
    Convert symbolic shape in ``reshape`` to integers.

    Note: This does not perform any optimization. For each ``reshape`` whose
    ``shape`` argument is a compile-time value containing symbols, the
    ``shape`` is replaced by the op's inferred output shape, with the single
    remaining symbolic dimension (if any) written as -1. Because -1 can stand
    for only one dimension, reshapes whose inferred output shape still has
    more than one symbolic dimension are left unchanged.

    Reshapes are also left unchanged when their ``shape`` is fully known
    (``val`` is set), only known at runtime (``sym_val`` is None), or shared
    with other consumers.

    .. code-block::

        # Before remove_symbolic_reshape pass.
        main(%x: (s0, 4, fp32)) {
          block0() {
            %reshape_0_shape_0: (3,i32)^ = const(val=(s0, s1, 2))
            %reshape_0: (s0, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0)
          } -> (%reshape_0)
        }

        # After remove_symbolic_reshape pass.
        main(%x: (s0, 4, fp32)) {
          block0() {
            %reshape_0_shape_0x: (3,i32)* = const(val=[-1, 2, 2])
            %reshape_0: (-1, 2, 2, fp32) = reshape(x=%x, shape=%reshape_0_shape_0x)
          } -> (%reshape_0)
        }

    TODO (rdar://59165842): Use expand_dims, squeeze etc to use 0 instead of
    dynamic reshape with -1.
    """

    def apply(self, prog: Program):
        """Run the pass on every function of ``prog`` and log the rewrite count."""
        for f in prog.functions.values():
            num_changes = self._remove_symbolic_reshape_block(f)
            msg = "remove_symbolic_reshape: changed {} reshapes."
            logger.info(msg.format(num_changes))

    @staticmethod
    def _is_removable_with(producer, consumer, block) -> bool:
        """
        Return True if ``producer`` can be removed from ``block`` together with
        ``consumer``.

        That requires all of the following:
        - ``producer`` exists and lives in ``block`` itself, not an enclosing block.
        - None of its outputs is a block output.
        - None of its outputs is consumed by any op other than ``consumer``.
        """
        if producer is None or producer.enclosing_block is not block:
            return False
        for out in producer.outputs:
            if out in block.outputs:
                return False
            if any(child is not consumer for child in out.child_ops):
                return False
        return True

    @block_context_manager
    def _remove_symbolic_reshape_block(self, block):
        """
        Rewrite eligible symbolic ``reshape`` ops in ``block`` and its nested blocks.

        Returns:
            int: number of ``reshape`` ops rewritten, nested blocks included.

        Raises:
            ValueError: if a candidate reshape's inferred output shape is variadic.
        """
        num_changes = 0
        # Iterate over a snapshot because ops are inserted/removed below.
        for op in list(block.operations):
            for b in op.blocks:
                num_changes += self._remove_symbolic_reshape_block(b)
            if op.op_type != "reshape":
                continue
            if op.shape.val is not None:
                # shape does not contain symbol.
                continue
            if op.shape.sym_val is None:
                # shape is runtime determined.
                continue
            if len(op.shape.child_ops) > 1:
                # shape var is shared with other consumers; leave it untouched.
                continue
            # Use output shape as `shape`
            shape = op.outputs[0].shape
            if any_variadic(shape):
                msg = (
                    "Cannot reshape to variadic from a compile time "
                    + "shape argument. Variadic shape can only be achieved "
                    + "via runtime shape argument. op: {}"
                )
                raise ValueError(msg.format(op))
            num_symbols = num_symbolic(shape)
            if num_symbols > 1:
                # -1 can only stand in for a single unknown dimension.
                continue
            # Convert the one symbol to -1
            integer_shape = [-1 if is_symbolic(i) else i for i in shape]
            shape_const = mb.const(
                val=integer_shape,
                name=op.shape.name + "x",
                before_op=op,
            )
            reshaped = mb.reshape(x=op.x, shape=shape_const, name=op.name, before_op=op)
            op.enclosing_block.replace_uses_of_var_after_op(
                anchor_op=op, old_var=op.outputs[0], new_var=reshaped
            )
            # Remove the old reshape, plus its shape producer when that is safe.
            # The producer may live in an enclosing block (reshape inside a
            # nested block), be absent, or feed a block output. Removing it from
            # `block` would then raise, so it is left for dead-code cleanup.
            ops_to_remove = [op]
            shape_op = op.shape.op
            if self._is_removable_with(shape_op, op, block):
                ops_to_remove.append(shape_op)
            block.remove_ops(ops_to_remove)
            num_changes += 1
        return num_changes