Source code for coremltools.converters.mil.mil.ops.defs.iOS18.transformers

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import numpy as np

from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.input_type import InputSpec, TensorInputType
from coremltools.converters.mil.mil.operation import Operation
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil.ops.defs._utils import broadcast_shapes
from coremltools.converters.mil.mil.ops.defs.iOS18 import _IOS18_TARGET
from coremltools.converters.mil.mil.types.symbolic import any_symbolic, is_symbolic


@register_op(opset_version=_IOS18_TARGET)
class scaled_dot_product_attention(Operation):
    """
    Source: `PyTorch scaled dot product attention
    <https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html>`_.

    Computes the scaled dot product attention on query, key, and value tensors,
    using an optional attention mask if passed.

    In PyTorch, this is equivalent to::

        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf')) if attn_mask.dtype==torch.bool else attn_mask
        attn_weight = torch.softmax((Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))) + attn_mask, dim=-1)
        return attn_weight @ V

    Shape key:

    - ``B`` = Batch size
    - ``S`` = Source sequence length
    - ``L`` = Target sequence length
    - ``E`` = Query/Key embedding dimension
    - ``EV`` = Value embedding dimension

    Numerical values can differ due to floating-point fusion/accumulation between backends.

    Note: ``dropout_p`` and ``is_causal`` are currently not supported.

    The mask can be either bool or a float type matching query, key, and value.
    A bool mask indicates whether each element should take part in the attention;
    a float mask is added to the attention score.
    The mask shape must be broadcastable to ``[B, \*?, L, S]``.

    Parameters
    ----------
    query: tensor<[B, \*?, L, E], T> (Required)
    key: tensor<[B, \*?, S, E], T> (Required)
    value: tensor<[B, \*?, S, EV], T> (Required)
    attn_mask: tensor<[\*?, S], M> (Optional)

    Returns
    -------
    tensor<[B, \*?, L, EV], T>

    Attributes
    ----------
    T: fp16, fp32
    M: bool, fp16, fp32
    """

    input_spec = InputSpec(
        query=TensorInputType(type_domain="T"),
        key=TensorInputType(type_domain="T"),
        value=TensorInputType(type_domain="T"),
        attn_mask=TensorInputType(optional=True, type_domain="M"),
    )

    type_domains = {
        "T": (types.fp16, types.fp32),
        "M": (types.bool, types.fp16, types.fp32),
    }

    def _validate_inputs(self):
        query_rank = self.query.rank
        key_rank = self.key.rank
        value_rank = self.value.rank
        if query_rank != key_rank or query_rank != value_rank:
            raise ValueError(
                f"query, key, value must have the same rank, got\n"
                f"* query rank = {query_rank}\n"
                f"* key rank = {key_rank}\n"
                f"* value rank = {value_rank}"
            )
        if query_rank < 3:
            raise ValueError(
                f"query, key, value must have at least rank 3 "
                f"for batch, sequence length, embedding, got rank {query_rank}"
            )

        query_shape = self.query.shape
        key_shape = self.key.shape
        value_shape = self.value.shape

        B_query = query_shape[:-2]
        E_query = query_shape[-1]
        B_key = key_shape[:-2]
        S_key = key_shape[-2]
        E_key = key_shape[-1]
        B_value = value_shape[:-2]
        S_value = value_shape[-2]

        # Only compare batch dimensions that are fully concrete.
        batch_dims = [B_query, B_key, B_value]
        batch_dims = [batch_dim for batch_dim in batch_dims if not any_symbolic(batch_dim)]
        if len(set(batch_dims)) > 1:
            raise ValueError(
                "query, key, value must have the same batch dimensions, got\n"
                f"* query batch = {B_query}\n"
                f"* key batch = {B_key}\n"
                f"* value batch = {B_value}"
            )

        if not is_symbolic(E_query) and not is_symbolic(E_key) and E_query != E_key:
            raise ValueError(
                "query and key must have the same embedding dimension, got\n"
                f"* query embedding = {E_query}\n"
                f"* key embedding = {E_key}"
            )

        if not is_symbolic(S_key) and not is_symbolic(S_value) and S_key != S_value:
            raise ValueError(
                "key and value must have the same sequence length, got\n"
                f"* key sequence = {S_key}\n"
                f"* value sequence = {S_value}"
            )

        if self.attn_mask is not None:
            mask_shape = self.attn_mask.shape
            S_mask = mask_shape[-1]
            if not is_symbolic(S_mask) and not is_symbolic(S_key) and S_mask != S_key:
                raise ValueError(
                    "key and mask must have the same sequence length, got\n"
                    f"* key sequence = {S_key}\n"
                    f"* mask sequence = {S_mask}"
                )
            # If the shapes are inconsistent, `broadcast_shapes` raises an exception.
            broadcast_shapes(query_shape[:-1], mask_shape[:-1])

    def type_inference(self):
        self._validate_inputs()
        shape = list(self.query.shape[:-1]) + [self.value.shape[-1]]
        return types.tensor(self.query.dtype, shape)

    def value_inference(self):
        query = self.query.val
        key = self.key.val
        value = self.value.val
        if query is None or key is None or value is None:
            return None

        float_mask = None
        if self.attn_mask is not None and self.attn_mask.val is not None:
            mask = self.attn_mask.val
            if mask.dtype == bool:
                # Convert a boolean mask into an additive float mask:
                # masked-out positions become -inf before the softmax.
                float_mask = np.zeros(mask.shape)
                float_mask[np.where(np.logical_not(mask))] = -np.inf
            else:
                float_mask = mask

        similarity = np.matmul(query, key.swapaxes(-2, -1)) / np.sqrt(query.shape[-1])
        if float_mask is not None:
            similarity += float_mask
        attention_weight = self.numpy_softmax_last_dim(similarity)
        attention = np.matmul(attention_weight, value)
        return attention

    @staticmethod
    def numpy_softmax_last_dim(x: np.ndarray) -> np.ndarray:
        # Numerically stable softmax along the last axis.
        exps = np.exp(x - np.max(x, axis=-1)[..., None])
        softmax = exps / np.sum(exps, axis=-1)[..., None]
        return softmax
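

# The block below is not part of the original module: it is a minimal usage
# sketch showing how this op might be invoked through the public MIL Builder
# (`coremltools.converters.mil.Builder`) and converted with an iOS 18
# deployment target. The shapes (1, 4, 8) and the program/input names are
# arbitrary example assumptions, not values taken from the source above.
#
# import numpy as np
#
# import coremltools as ct
# from coremltools.converters.mil import Builder as mb
#
# @mb.program(
#     input_specs=[
#         mb.TensorSpec(shape=(1, 4, 8)),  # query: [B, L, E]
#         mb.TensorSpec(shape=(1, 4, 8)),  # key:   [B, S, E]
#         mb.TensorSpec(shape=(1, 4, 8)),  # value: [B, S, EV]
#     ],
#     opset_version=ct.target.iOS18,
# )
# def prog(query, key, value):
#     # attn_mask is optional; omitting it computes unmasked attention.
#     return mb.scaled_dot_product_attention(query=query, key=key, value=value)
#
# mlmodel = ct.convert(prog, minimum_deployment_target=ct.target.iOS18)
# out = mlmodel.predict(
#     {
#         "query": np.random.rand(1, 4, 8).astype(np.float32),
#         "key": np.random.rand(1, 4, 8).astype(np.float32),
#         "value": np.random.rand(1, 4, 8).astype(np.float32),
#     }
# )
#
# The result should match `value_inference` above (a NumPy reference of
# softmax(Q @ K^T / sqrt(E)) @ V) up to floating-point accumulation differences
# between backends, as noted in the docstring.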