pub struct GatherOperation<'c> { /* private fields */ }
A gather operation. Gather a subset of a tensor at specified indices.
The gather operation extracts a subset of the elements from a source tensor at the given indices.
In its most general form, the tensor of indices specifies all the coordinates
of every element to extract (i.e. COO format, without the payload).
The indices are expected to be confined to coordinate values that fit the
range of the source tensor; otherwise the behavior is undefined.
The leading dimensions of the index tensor give the result tensor its leading
dimensions. The trailing dimensions of the result tensor are obtained from
the source tensor by omitting the dimensions specified in gather_dims
(rank-reducing semantics) or setting them to 1 (rank-preserving semantics)
(see examples).
The trailing dimension of the index tensor contains the coordinates and is
expected to have its size equal to the number of dimensions being gathered.
This convention allows an idiomatic specification and lowering of “gathering
multiple N-D slices from the source tensor”.
Note: in the examples below, we separate out the indexing part of the tensor type by a whitespace for readability purposes.
Example:
// For each 1x2 triple of coordinates in %indices, extract the
// element (i.e. 0-D subset) at the coordinates triple in %source.
//
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
(tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>
// Note: result type may be further rank-reduced to tensor<1x2x f32>.
A slice variant is provided to allow specifying whole slices of the source tensor.
Example:
// For each 6x7 singleton of coordinates in %indices, extract the 2-D
// slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
// corresponding to the `gather_dims` attribute specified by %indices.
//
%out = tensor.gather %source[%indices] gather_dims([1]) :
(tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>
// Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
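The semantics of both variants can be illustrated with a small reference implementation. This is only a sketch on flat row-major buffers, not how the op is lowered: `gather_slices` and its parameters are hypothetical names, the leading dimensions of the index tensor are flattened into a plain list of coordinate tuples, and out-of-range coordinates panic here whereas the op itself leaves them undefined.

```rust
/// Reference gather over row-major buffers (hypothetical helper).
/// `indices` is a flat list of coordinate tuples, each `gather_dims.len()`
/// long. Returns the gathered elements in row-major order of the result.
pub fn gather_slices(
    source: &[f32],
    source_shape: &[usize],
    indices: &[usize],
    gather_dims: &[usize],
) -> Vec<f32> {
    assert_eq!(source.len(), source_shape.iter().product::<usize>());
    assert!(!gather_dims.is_empty());

    // Row-major strides of the source.
    let rank = source_shape.len();
    let mut strides = vec![1usize; rank];
    for d in (0..rank.saturating_sub(1)).rev() {
        strides[d] = strides[d + 1] * source_shape[d + 1];
    }

    // Dimensions that are not gathered are copied in full.
    let kept: Vec<usize> = (0..rank).filter(|d| !gather_dims.contains(d)).collect();
    let slice_len: usize = kept.iter().map(|&d| source_shape[d]).product();

    let mut out = Vec::with_capacity(indices.len() / gather_dims.len() * slice_len);
    for coords in indices.chunks_exact(gather_dims.len()) {
        // Offset contributed by the gathered coordinates.
        let mut base = 0;
        for (&d, &c) in gather_dims.iter().zip(coords) {
            assert!(c < source_shape[d], "coordinate out of range");
            base += c * strides[d];
        }
        // Walk the kept dimensions in row-major order.
        for mut linear in 0..slice_len {
            let mut offset = base;
            for &d in kept.iter().rev() {
                offset += (linear % source_shape[d]) * strides[d];
                linear /= source_shape[d];
            }
            out.push(source[offset]);
        }
    }
    out
}
```

When `gather_dims` covers every source dimension, each tuple selects a single element; otherwise each tuple selects a full slice along the remaining dimensions.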
The dimensions specified in the gather_dims attribute are ones for which the
result tensor has size 1.
For example, if the source type is axbxcxd and gather_dims is [1, 3], then
the shape suffix is ax1xcx1.
Gather also allows rank-reducing semantics where the shape ax1xcx1 can be
further simplified to axc.
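The shape rules above can be sketched as a small helper. This is an illustration under stated assumptions, not part of melior or MLIR: `gather_result_shape` is a hypothetical name, all shapes are static, and duplicate entries in `gather_dims` are not checked.

```rust
/// Computes the result shape of a gather from static shapes
/// (hypothetical helper). Returns `None` if the shapes are inconsistent.
pub fn gather_result_shape(
    source: &[usize],
    indices: &[usize],
    gather_dims: &[usize],
    rank_reducing: bool,
) -> Option<Vec<usize>> {
    // The trailing dimension of the index tensor holds the coordinates, so
    // its size must equal the number of gathered dimensions.
    let (&coord_len, leading) = indices.split_last()?;
    if coord_len != gather_dims.len() || gather_dims.iter().any(|&d| d >= source.len()) {
        return None;
    }
    // Leading dimensions of the index tensor become the leading result dims.
    let mut shape = leading.to_vec();
    for (dim, &size) in source.iter().enumerate() {
        if gather_dims.contains(&dim) {
            // Gathered dimension: size 1, or dropped when rank-reducing.
            if !rank_reducing {
                shape.push(1);
            }
        } else {
            shape.push(size);
        }
    }
    Some(shape)
}
```

With the shapes from the two examples above, a 4x4x4 source with 1x2x3 indices and gather_dims [0, 1, 2] gives 1x2x1x1x1 (or 1x2 when rank-reducing), and a 3x4x5 source with 6x7x1 indices and gather_dims [1] gives 6x7x3x1x5 (or 6x7x3x5).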
The elemental type of the indices tensor can be any integer type.
In the absence of target-specific or problem-specific information the default
type one should use is index.
This operation does not support unranked tensors.
An optional unique unit attribute may be specified to indicate that the
coordinates in indices are statically guaranteed to be unique at runtime.
Incorrectly setting the unique attribute when the coordinates are not truly
unique is undefined behavior.
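Because a wrong unique claim is undefined behavior, it can be worth validating known index data before setting the attribute. A minimal sketch, assuming the indices are available as a flat list of coordinate tuples; `coordinates_are_unique` is a hypothetical helper and not part of melior:

```rust
use std::collections::HashSet;

/// Returns true if every coordinate tuple in `indices` is distinct.
/// `indices` is flat, holding tuples of length `tuple_len`.
pub fn coordinates_are_unique(indices: &[usize], tuple_len: usize) -> bool {
    assert!(tuple_len > 0 && indices.len() % tuple_len == 0);
    let mut seen = HashSet::new();
    // `insert` returns false the first time a tuple repeats.
    indices.chunks_exact(tuple_len).all(|tuple| seen.insert(tuple))
}
```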
Only full slices are meant to be supported by this op. If one desires partial slices (e.g. strided windows), one should compose this op with other tensor ops (e.g. tensor.extract_slice). This is to avoid a slippery slope of complexity that would make the op unusable in practice.
At the tensor level, the index tensor is specified in an AoS form (i.e. the coordinate tuple is the most minor dimension). It is the responsibility of further lowerings and bufferization to implement various concrete layouts.
Note: As currently specified, the operation must lower to an abstraction that performs copies to the output tensor. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type. This is visible more clearly in a notional buffer version of the op:
// memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
// gather from random source slices must copy to the contiguous output.
%out = memref.gather %source[%indices] gather_dims([1]) :
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>
// Nested buffer support would allow gather to directly index into the
// source buffer (i.e. represent a jagged view into the source).
%out = memref.gather %source[%indices] gather_dims([1]) :
(memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>
Implementations
impl<'c> GatherOperation<'c>
pub fn as_operation(&self) -> &Operation<'c>
Returns a generic operation.
pub fn builder( context: &'c Context, location: Location<'c> ) -> GatherOperationBuilder<'c, Unset, Unset, Unset, Unset>
Creates a builder.