pub struct GatherOperation<'c> { /* private fields */ }

A gather operation. Gather a subset of a tensor at specified indices.

The gather operation extracts a subset of the elements from a source tensor at the given indices.

In its most general form, the tensor of indices specifies all the coordinates of every element to extract (i.e. COO format, without the payload). Each index must lie within the bounds of the corresponding source dimension; otherwise the behavior is undefined.

The leading dimensions of the index tensor give the result tensor its leading dimensions. The trailing dimensions of the result tensor are obtained from the source tensor by omitting the dimensions specified in gather_dims (rank-reducing semantics) or setting them to 1 (rank-preserving semantics) (see examples). The trailing dimension of the index tensor contains the coordinates and is expected to have its size equal to the number of dimensions being gathered. This convention allows an idiomatic specification and lowering of “gathering multiple N-D slices from the source tensor”.

Note: in the examples below, the indexing part of each tensor type is set off by a space for readability.

Example:

    // For each 1x2 triple of coordinates in %indices, extract the
    // element (i.e. 0-D subset) at the coordinates triple in %source.
    //
    %out = tensor.gather %source[%indices] gather_dims([0, 1, 2]) :
      (tensor<4x4x4xf32>, tensor<1x2x 3xindex>) -> tensor<1x2x 1x1x1xf32>

    // Note: result type may be further rank-reduced to tensor<1x2x f32>.
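
To make the full-coordinate case concrete, here is a minimal reference sketch in plain Rust. It is not how the op is lowered; it assumes a row-major 3-D source stored in a flat slice, and the name `gather_elements` is hypothetical. Where the op leaves out-of-range coordinates undefined, the sketch panics instead.

```rust
/// Reference sketch of full-coordinate gather on a row-major 3-D tensor.
/// `shape` is the source shape; `coords` holds one [i, j, k] triple per
/// element to extract (the trailing dimension of the index tensor).
/// Assumption: the source is a flat, row-major slice.
fn gather_elements(source: &[f32], shape: [usize; 3], coords: &[[usize; 3]]) -> Vec<f32> {
    assert_eq!(
        source.len(),
        shape[0] * shape[1] * shape[2],
        "source does not match shape"
    );
    coords
        .iter()
        .map(|&[i, j, k]| {
            // The op leaves out-of-range coordinates undefined; this sketch panics.
            assert!(
                i < shape[0] && j < shape[1] && k < shape[2],
                "coordinate out of range"
            );
            // Row-major linear offset of element (i, j, k).
            source[(i * shape[1] + j) * shape[2] + k]
        })
        .collect()
}
```

For a 4x4x4 source whose elements equal their linear offsets, calling `gather_elements` with the triples `[0, 0, 1]` and `[3, 2, 1]` returns `[1.0, 57.0]`.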

A slice variant is provided to allow specifying whole slices of the source tensor.

Example:

    // For each 6x7 singleton of coordinates in %indices, extract the 2-D
    // slice %source[*, %indices[...]:%indices[...] + 1, *] with the indices
    // corresponding to the `gather_dims` attribute specified by %indices.
    //
    %out = tensor.gather %source[%indices] gather_dims([1]) :
      (tensor<3x4x5xf32>, tensor<6x7x 1xindex>) -> tensor<6x7x 3x1x5xf32>

    // Note: result type may be further rank-reduced to tensor<6x7x 3x5xf32>.
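
The slice variant can be sketched the same way. The function below is an illustration under stated assumptions (row-major 3-D source in a flat slice, `gather_dims([1])` only, hypothetical name `gather_dim1_slices`): for each index it copies the full slice `source[*, j, *]` into a contiguous output.

```rust
/// Sketch of the slice variant for `gather_dims([1])` on a row-major tensor
/// of shape [d0, d1, d2]: for each index `j`, copy the slice source[*, j, *]
/// (shape d0 x 1 x d2) into a contiguous output buffer.
fn gather_dim1_slices(source: &[f32], shape: [usize; 3], indices: &[usize]) -> Vec<f32> {
    let [d0, d1, d2] = shape;
    assert_eq!(source.len(), d0 * d1 * d2, "source does not match shape");
    let mut out = Vec::with_capacity(indices.len() * d0 * d2);
    for &j in indices {
        // The op leaves out-of-range indices undefined; this sketch panics.
        assert!(j < d1, "index out of range");
        for i in 0..d0 {
            // Elements (i, j, 0..d2) are contiguous in row-major order.
            let start = (i * d1 + j) * d2;
            out.extend_from_slice(&source[start..start + d2]);
        }
    }
    out
}
```

With a 3x4x5 source and the 6x7 index tensor flattened to 42 indices, the output holds 42 * 3 * 5 = 630 elements, matching `tensor<6x7x 3x1x5xf32>`.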

The dimensions specified in the gather_dims attribute are the ones for which the result tensor has size 1. For example, if the source type is axbxcxd and gather_dims is [1, 3], then the shape suffix is ax1xcx1. Gather also allows rank-reducing semantics, where the shape ax1xcx1 can be further simplified to axc.
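
The shape rule can be written down as a small helper. This is a sketch of the rule as described here, not the verifier used by MLIR; the function name `gather_result_shape` and the `rank_reduce` flag are hypothetical, and only the fully rank-preserving and fully rank-reduced forms are modeled.

```rust
/// Sketch of the result-shape rule: leading dimensions come from the index
/// tensor (all but its last dimension); trailing dimensions come from the
/// source, with gathered dimensions set to 1 (rank-preserving) or dropped
/// (rank-reducing).
fn gather_result_shape(
    source: &[usize],
    indices: &[usize],
    gather_dims: &[usize],
    rank_reduce: bool,
) -> Result<Vec<usize>, String> {
    // The last dimension of the index tensor holds the coordinate tuple.
    let (&coord_size, batch) = indices
        .split_last()
        .ok_or_else(|| "index tensor must have rank >= 1".to_string())?;
    if coord_size != gather_dims.len() {
        return Err(format!(
            "trailing index dimension is {} but {} dims are gathered",
            coord_size,
            gather_dims.len()
        ));
    }
    if let Some(&d) = gather_dims.iter().find(|&&d| d >= source.len()) {
        return Err(format!("gather dim {} exceeds source rank {}", d, source.len()));
    }
    let mut shape = batch.to_vec();
    for (dim, &size) in source.iter().enumerate() {
        if gather_dims.contains(&dim) {
            if !rank_reduce {
                shape.push(1);
            }
        } else {
            shape.push(size);
        }
    }
    Ok(shape)
}
```

For the two examples above, `gather_result_shape(&[4, 4, 4], &[1, 2, 3], &[0, 1, 2], false)` yields `[1, 2, 1, 1, 1]`, and `gather_result_shape(&[3, 4, 5], &[6, 7, 1], &[1], true)` yields `[6, 7, 3, 5]`.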

The element type of the indices tensor can be any integer type. In the absence of target-specific or problem-specific information, the default type to use is index.

This operation does not support unranked tensors.

An optional unique unit attribute may be specified to indicate that the coordinates in indices are statically guaranteed to be unique at runtime. Incorrectly setting the unique attribute when the coordinates are not truly unique is undefined behavior.
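
Because a wrong `unique` annotation is undefined behavior, a caller that builds indices on the host may want to check them before setting it. The helper below is a hypothetical host-side check, not part of the op or of this crate; it assumes the index tensor is available as a flat slice of `i64` in AoS order.

```rust
use std::collections::HashSet;

/// Returns true when no coordinate tuple appears more than once.
/// `coords` is the flattened index tensor in AoS order and `tuple_len` is
/// the size of its trailing dimension. Assumption: indices are host-visible.
fn coordinates_are_unique(coords: &[i64], tuple_len: usize) -> bool {
    assert!(
        tuple_len > 0 && coords.len() % tuple_len == 0,
        "malformed index tensor"
    );
    let mut seen = HashSet::new();
    // `insert` returns false on the first repeated tuple, which stops `all`.
    coords.chunks_exact(tuple_len).all(|tuple| seen.insert(tuple))
}
```

For instance, the pairs `(0, 1), (2, 3), (0, 1)` are rejected, while `(0, 1), (1, 0)` pass because the tuples differ as a whole.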

Only full slices are meant to be supported by this op. If partial slices (e.g. strided windows) are needed, compose this op with other tensor ops (e.g. tensor.extract_slice). This avoids a slippery slope of complexity that would make the op unusable in practice.

At the tensor level, the index tensor is specified in AoS form (i.e. the coordinate tuple is the most minor dimension). It is the responsibility of further lowerings and bufferization to implement various concrete layouts.
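
As an illustration of what a later layout change might involve, the sketch below converts AoS indices into SoA form (one vector per gathered dimension). It is only an example of the two layouts; the name `aos_to_soa` is hypothetical and no lowering is claimed to work this way.

```rust
/// Converts an index tensor stored as AoS (coordinate tuple most minor)
/// into SoA form: one vector of coordinates per gathered dimension.
fn aos_to_soa(coords: &[usize], tuple_len: usize) -> Vec<Vec<usize>> {
    assert!(
        tuple_len > 0 && coords.len() % tuple_len == 0,
        "malformed index tensor"
    );
    let mut soa = vec![Vec::new(); tuple_len];
    for tuple in coords.chunks_exact(tuple_len) {
        // Scatter each component of the tuple into its own column.
        for (column, &c) in soa.iter_mut().zip(tuple) {
            column.push(c);
        }
    }
    soa
}
```

The AoS triples `(0, 1, 2), (3, 4, 5)` become the three columns `[0, 3]`, `[1, 4]`, and `[2, 5]`.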

Note: As currently specified, the operation must lower to an abstraction that performs copies to the output tensor. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type. This is visible more clearly in a notional buffer version of the op:

    // memref<?x4x1xf32> is a contiguous buffer of ?x4x1 elements.
    // gather from random source slices must copy to the contiguous output.
    %out = memref.gather %source[%indices] gather_dims([1]) :
      (memref<4x4xf32>, memref<?x 1xindex>) -> memref<?x 4x1xf32>

    // Nested buffer support would allow gather to directly index into the
    // source buffer (i.e. represent a jagged view into the source).
    %out = memref.gather %source[%indices] gather_dims([1]) :
      (memref<4x4xf32>, memref<?x 1xindex>) -> memref<? x memref<4x1xf32>>

Implementations

impl<'c> GatherOperation<'c>

pub fn name() -> &'static str

Returns the name of the operation.

pub fn as_operation(&self) -> &Operation<'c>

Returns the underlying generic operation.

pub fn builder(context: &'c Context, location: Location<'c>) -> GatherOperationBuilder<'c, Unset, Unset, Unset, Unset>

Creates a builder.

pub fn result(&self) -> Result<OperationResult<'c, '_>, Error>

pub fn source(&self) -> Result<Value<'c, '_>, Error>

pub fn indices(&self) -> Result<Value<'c, '_>, Error>

pub fn gather_dims(&self) -> Result<Attribute<'c>, Error>

pub fn set_gather_dims(&mut self, value: Attribute<'c>)

pub fn unique(&self) -> Result<Attribute<'c>, Error>

pub fn set_unique(&mut self, value: Attribute<'c>)

pub fn remove_unique(&mut self) -> Result<(), Error>

Trait Implementations

impl<'c> From<GatherOperation<'c>> for Operation<'c>

fn from(operation: GatherOperation<'c>) -> Self

Converts to this type from the input type.

impl<'c> TryFrom<Operation<'c>> for GatherOperation<'c>

type Error = Error

The type returned in the event of a conversion error.

fn try_from(operation: Operation<'c>) -> Result<Self, Self::Error>

Performs the conversion.

Auto Trait Implementations

impl<'c> RefUnwindSafe for GatherOperation<'c>

impl<'c> !Send for GatherOperation<'c>

impl<'c> !Sync for GatherOperation<'c>

impl<'c> Unpin for GatherOperation<'c>

impl<'c> UnwindSafe for GatherOperation<'c>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.