Source code for

#  Copyright (c) 2020, Apple Inc. All rights reserved.
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from import Operation, types
from import DefaultInputs, InputSpec, TensorInputType
from import SYMBOL, VALUE, precondition
from import register_op
from import compute_gather
from import (
    gather_along_axis as _gather_along_axis_iOS15,
from import _IOS16_TARGET

@register_op(opset_version=_IOS16_TARGET)
class gather(Operation):
    r"""
    The iOS16 version.
    This section documents only the differences between this version and the
    iOS 15 :py:class:`~.iOS15.scatter_gather.gather`.

    This version supports ``batch_dims``, similar to
    `tf.gather <https://www.tensorflow.org/api_docs/python/tf/gather>`_.
    Input parameter ``indices`` now supports ``int16`` and ``uint16``.

    Parameters
    ----------
    x: tensor<\*D, T> (Required)
    indices: tensor<\*N, I> (Required)
        * Indices values may be negative. More precisely,
          ``-D[axis]<= v < D[axis]`` for ``v`` in ``indices``.
    axis: const i32 (Optional. Default=``0``)
        * Negative axis is supported.
    batch_dims: const i32 (Optional. Default=``0``)
        * The number of batch dimensions.

    Returns
    -------
    tensor<\*K, T>
        * Where ``K = D[:axis] + N[batch_dims:] + D[axis+1:]``.

    Attributes
    ----------
    T: fp16, fp32, i32
    I: uint16, int16, int32

    References
    ----------
    See `tf.gather <https://www.tensorflow.org/api_docs/python/tf/gather>`_.
    """

    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        indices=TensorInputType(type_domain="I"),
        axis=TensorInputType(const=True, optional=True, type_domain=types.int32),
        batch_dims=TensorInputType(const=True, optional=True, type_domain=types.int32),
    )

    type_domains = {
        "T": (types.fp16, types.fp32, types.int32),
        "I": (types.int32, types.uint16, types.int16),
    }

    def default_inputs(self):
        """Return the defaults for the optional inputs: ``axis=0``, ``batch_dims=0``."""
        return DefaultInputs(
            axis=0,
            batch_dims=0,
        )

    @precondition(allow=VALUE | SYMBOL)
    def value_inference(self):
        """
        Compute the gathered value via ``compute_gather``.

        Returns ``None`` when ``indices`` has no concrete value.
        """
        indices = self.indices.val
        if indices is None:
            # only allow x to be symbolic. indices cannot.
            return None
        return compute_gather(
            params=self.x.sym_val,
            indices=indices,
            axis=self.axis.val,
            batch_dims=self.batch_dims.val,
        )

    def type_inference(self):
        """
        Validate ``axis`` / ``batch_dims`` and derive the output type.

        Returns
        -------
        ``x.dtype`` when the output rank is 0 (scalar output); otherwise a
        tensor type of ``x.dtype`` with shape
        ``x.shape[:axis] + indices.shape[batch_dims:] + x.shape[axis+1:]``.

        Raises
        ------
        IndexError
            If ``axis`` is outside ``[-x.rank, x.rank)``.
        ValueError
            If ``batch_dims >= x.rank`` or ``batch_dims > indices.rank``.
        """
        x_rank = self.x.rank
        indices_rank = self.indices.rank
        axis = self.axis.val
        batch_dims = self.batch_dims.val

        # validate parameters
        # Each message has three placeholders; the node name is the third
        # argument. Omitting it makes str.format itself raise IndexError and
        # masks the intended exception.
        if axis < -x_rank or axis >= x_rank:
            raise IndexError(
                "Axis value {} is out of bounds for {} node {}".format(
                    axis, self.op_type, self.name
                )
            )
        if batch_dims >= x_rank:
            raise ValueError(
                "batch_dims {} must be less than x.rank {} for node {}".format(
                    batch_dims, x_rank, self.name
                )
            )
        if batch_dims > indices_rank:
            raise ValueError(
                "batch_dims {} must be less or equal to than indices.rank {} for node {}".format(
                    batch_dims, indices_rank, self.name
                )
            )

        output_rank = x_rank - 1 + indices_rank - batch_dims
        if output_rank == 0:
            # output scalar
            return self.x.dtype

        # compute output shape
        # Normalize a negative axis so the slices below index from the front.
        if axis < 0:
            axis += x_rank
        out_shape = (
            self.x.shape[:axis]
            + self.indices.shape[batch_dims:]
            + self.x.shape[axis + 1 :]
        )
        return types.tensor(self.x.dtype, out_shape)
@register_op(opset_version=_IOS16_TARGET)
class gather_along_axis(_gather_along_axis_iOS15):
    r"""
    The iOS16 version.
    The only difference between this version and the iOS 15
    :py:class:`~.iOS15.scatter_gather.gather_along_axis` is that input
    parameter ``indices`` now supports ``int16`` and ``uint16``.

    Parameters
    ----------
    x: tensor<\*D, T> (Required)
    indices: tensor<\*K, I> (Required)
    axis: const i32 (Optional):
        * Default to ``0``.

    Returns
    -------
    tensor<\*K, T>:
        * Output tensor has the same shape as ``indices``.

    Attributes
    ----------
    T: fp16, fp32, i32
    I: uint16, int16, int32
    """

    # Re-declares the input spec so that ``indices`` is bound to the "I"
    # type domain defined below (int32, uint16, int16).
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        indices=TensorInputType(type_domain="I"),
        axis=TensorInputType(const=True, optional=True, type_domain=types.int32),
    )

    type_domains = {
        "T": (types.fp16, types.fp32, types.int32),
        "I": (types.int32, types.uint16, types.int16),
    }
@register_op(opset_version=_IOS16_TARGET)
class gather_nd(Operation):
    r"""
    The iOS16 version.
    This section documents only the differences between this version and the
    iOS 15 :py:class:`~.iOS15.scatter_gather.gather_nd`.

    This version supports ``batch_dims``.
    Input parameter ``indices`` now supports ``int16`` and ``uint16``.

    Parameters
    ----------
    x: tensor<\*D, T> (Required)
    indices: tensor<\*K, I> (Required)
    batch_dims: const i32 (Optional. Default=``0``)
        * The number of batch dimensions.

    Returns
    -------
    tensor<\*V, T>
        * ``V = K[:-1] + D[batch_dims + K[-1]:]``, where ``D = x.shape``
          and ``K = indices.shape``.

    Attributes
    ----------
    T: fp16, fp32, i32
    I: uint16, int16, int32

    References
    ----------
    See `tf.gather_nd <https://www.tensorflow.org/api_docs/python/tf/gather_nd>`_.
    """

    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        indices=TensorInputType(type_domain="I"),
        batch_dims=TensorInputType(const=True, optional=True, type_domain=types.int32),
    )

    type_domains = {
        "T": (types.fp16, types.fp32, types.int32),
        "I": (types.int32, types.uint16, types.int16),
    }

    def default_inputs(self):
        """Return the default for the optional input: ``batch_dims=0``."""
        return DefaultInputs(
            batch_dims=0,
        )

    def type_inference(self):
        """
        Validate the index depth against the input rank and derive the output type.

        Returns
        -------
        A tensor type of ``x.dtype`` with shape
        ``indices.shape[:-1] + x.shape[batch_dims + indices.shape[-1]:]``.

        Raises
        ------
        ValueError
            If ``indices.shape[-1] + batch_dims`` exceeds ``x.rank``.
        """
        batch_dims = self.batch_dims.val
        indices_depth = self.indices.shape[-1]
        x_rank = self.x.rank

        if indices_depth > x_rank - batch_dims:
            # The first placeholder is the node name.
            msg = "For node {}, indices.shape[-1] ({}) + batch_dims ({}) must be smaller or equal to the input rank {}".format(
                self.name, indices_depth, batch_dims, x_rank
            )
            raise ValueError(msg)

        # Leading index dims are kept; the batch dims and the dims addressed
        # by the last axis of ``indices`` are dropped from ``x.shape``.
        out_shape = self.indices.shape[:-1] + self.x.shape[batch_dims + indices_depth :]
        return types.tensor(self.x.dtype, out_shape)