pub struct ScatterOperation<'c> { /* private fields */ }

A scatter operation. Scatter a tensor into a destination tensor at specified indices.

The scatter operation inserts a source tensor into a dest tensor at the given indices.

In its most general form, the tensor of indices specifies all the coordinates of every element to insert (i.e. COO format, without the payload). The indices are expected to be confined to coordinate values that fit the range of the dest tensor; otherwise, the behavior is undefined.

The leading dimensions of the index tensor must match those of the source tensor. The trailing dimensions of the source tensor must match the dimensions of the dest tensor, with the dimensions specified in scatter_dims either omitted (rank-reducing semantics) or set to 1 (rank-preserving semantics) (see examples). This convention allows an idiomatic specification and lowering of “scattering multiple N-D slices into the dest tensor”. The result type must match the type of the dest tensor.

Note: in the examples below, the indexing part of each tensor type is separated from the rest by whitespace for readability.

Example:

    // For each 1x2 triple of coordinates in %indices, insert the
    // element (i.e. 0-D subset) at the coordinates triple in %dest.
    //
    %out = tensor.scatter %source into %dest[%indices]
        scatter_dims([0, 1, 2]) unique :
      (tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
        -> tensor<4x4x4xf32>

    // Note: source type may be further rank-reduced to tensor<1x2x f32>.
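
To make the element-wise semantics concrete, here is a minimal reference model in plain Rust. It is a sketch of what the example computes, not the MLIR lowering: it assumes a row-major dest stored as a flat vector, and the helper name `scatter_elements` is hypothetical.

```rust
/// Reference model of the element-wise scatter above (hypothetical helper).
/// `dest` is a row-major tensor with the given `shape`; `coords[i]` is the
/// coordinate triple at which `source[i]` is inserted.
fn scatter_elements(
    mut dest: Vec<f32>,
    shape: [usize; 3],
    coords: &[[usize; 3]],
    source: &[f32],
) -> Vec<f32> {
    assert_eq!(coords.len(), source.len());
    assert_eq!(dest.len(), shape[0] * shape[1] * shape[2]);
    for (c, &value) in coords.iter().zip(source) {
        // The op leaves out-of-range coordinates undefined; this model panics.
        assert!(c.iter().zip(&shape).all(|(i, n)| i < n), "coordinate out of range");
        let flat = (c[0] * shape[1] + c[1]) * shape[2] + c[2];
        dest[flat] = value;
    }
    dest
}
```

Calling it with a zeroed 4x4x4 dest, two coordinate triples, and two source values mirrors the `tensor<1x2x 3xindex>` index tensor of the example, with the leading `1x2` batch flattened to two entries.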

A slice variant is provided to allow specifying insertion of whole tensor slices into the dest tensor.

Example:

    // For each 3 singleton of coordinates in %indices, insert the 2-D
    // slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the
    // indices corresponding to the scatter_dims attribute specified by
    // %indices.
    //
    %out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
      (tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
        -> tensor<4x5x6xf32>
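
The slice variant can be modeled the same way. The sketch below assumes a row-major 3-D dest, scatter_dims equal to [1], and a rank-reduced source of shape `[indices.len(), d0, d2]`; the helper name `scatter_slices_dim1` is hypothetical and the code only illustrates the semantics.

```rust
/// Reference model of the slice variant above (hypothetical helper): for each
/// index `j = indices[b]`, copy the 2-D slice `source[b]` into `dest[.., j, ..]`.
/// `dest` is row-major with shape `[d0, d1, d2]`; `source` is row-major with
/// shape `[indices.len(), d0, d2]` (the rank-reduced form).
fn scatter_slices_dim1(
    mut dest: Vec<f32>,
    [d0, d1, d2]: [usize; 3],
    indices: &[usize],
    source: &[f32],
) -> Vec<f32> {
    assert_eq!(dest.len(), d0 * d1 * d2);
    assert_eq!(source.len(), indices.len() * d0 * d2);
    for (b, &j) in indices.iter().enumerate() {
        // Out-of-range indices are undefined for the op; this model panics.
        assert!(j < d1, "index out of range");
        for i in 0..d0 {
            for k in 0..d2 {
                dest[(i * d1 + j) * d2 + k] = source[(b * d0 + i) * d2 + k];
            }
        }
    }
    dest
}
```

With a 4x5x6 dest and three indices, each index selects one position along dimension 1 and receives a full 4x6 slice, matching the shapes in the example.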

The dimensions specified in the scatter_dims attribute are ones for which the source tensor has size 1. For example, if the dest type is axbxcxd and scatter_dims is [1, 3], then the source type suffix is ax1xcx1. Scatter also allows rank-reducing semantics, where the shape ax1xcx1 can be further simplified to axc.
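
The shape rule can be written down as a small function. This is a sketch derived from the examples on this page, with `expected_source_shape` a hypothetical helper; the MLIR verifier remains the authority on which types are accepted.

```rust
/// Computes the source shape implied by the rule above (hypothetical helper).
/// `batch` holds the leading dimensions shared with the index tensor.
fn expected_source_shape(
    batch: &[usize],
    dest: &[usize],
    scatter_dims: &[usize],
    rank_reduced: bool,
) -> Vec<usize> {
    let mut shape = batch.to_vec();
    for (dim, &size) in dest.iter().enumerate() {
        if !scatter_dims.contains(&dim) {
            shape.push(size); // full slice along this dimension
        } else if !rank_reduced {
            shape.push(1); // rank-preserving: keep a unit dimension
        } // rank-reducing: drop the dimension entirely
    }
    shape
}
```

For the slice example (batch `[3]`, dest `[4, 5, 6]`, scatter_dims `[1]`), the rank-preserving call yields `[3, 4, 1, 6]` and the rank-reducing call yields `[3, 4, 6]`.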

The element type of the indices tensor can be any integer type. In the absence of target-specific or problem-specific information, the default type to use is index.

This operation does not support unranked tensors.

A unique unit attribute must be specified to indicate that the coordinates are statically guaranteed to be unique at runtime. If coordinates are not truly unique at runtime, the behavior is undefined.
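
When coordinates are produced on the host, it can be worth checking this guarantee (and the in-range requirement) before building the op. The following is a minimal sketch using only the standard library; `coords_are_valid_and_unique` is a hypothetical helper and not part of this crate.

```rust
use std::collections::HashSet;

/// Returns true when every coordinate tuple is in range for `dest_sizes`
/// (the dest sizes along scatter_dims) and no tuple occurs twice.
fn coords_are_valid_and_unique(coords: &[Vec<usize>], dest_sizes: &[usize]) -> bool {
    let mut seen = HashSet::new();
    coords.iter().all(|c| {
        c.len() == dest_sizes.len()
            && c.iter().zip(dest_sizes).all(|(i, n)| i < n)
            && seen.insert(c.as_slice()) // `insert` returns false on a duplicate
    })
}
```

Such a check only covers coordinates known ahead of time; for coordinates computed inside the program, the guarantee has to come from how they are constructed.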

Only full slices are meant to be supported by this op. If one desires partial slices (e.g. strided windows), one should compose this op with other tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of complexity that would make the op unusable in practice.

At the tensor level, the index tensor is specified in an AoS form (i.e. the coordinate tuple is the most minor dimension). It is the responsibility of further lowerings and bufferization to implement various concrete layouts.
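
To illustrate the difference between the AoS form and one alternative layout, the sketch below converts AoS coordinates into an SoA arrangement on the host. The helper `aos_to_soa` is hypothetical, and SoA is just one example of a layout a lowering might choose.

```rust
/// Converts AoS coordinates (one tuple per entry, as the op specifies them)
/// into an SoA layout (one vector per coordinate component).
fn aos_to_soa<const N: usize>(aos: &[[usize; N]]) -> [Vec<usize>; N] {
    let mut soa: [Vec<usize>; N] = std::array::from_fn(|_| Vec::with_capacity(aos.len()));
    for tuple in aos {
        for (component, &value) in soa.iter_mut().zip(tuple) {
            component.push(value);
        }
    }
    soa
}
```

Two coordinate triples `[0, 1, 2]` and `[3, 4, 5]` become three component vectors `[0, 3]`, `[1, 4]`, and `[2, 5]`.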

Note: As currently specified, the operation must lower to an abstraction that performs copies to the output tensor. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type. This is visible more clearly in a notional buffer version of the op:

    // memref<?x 4xf32> is a contiguous buffer of ?x4 elements, scatter into
    // random dest slices must copy to the contiguous dest.
    //
    some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
    memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
      (memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)

    // Nested buffer support in the producing op would allow writing directly
    // into the dest buffer.
    %v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
      memref<? x memref<4xf32>>
    some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>

Implementations

impl<'c> ScatterOperation<'c>

pub fn name() -> &'static str

Returns a name.

pub fn as_operation(&self) -> &Operation<'c>

Returns a generic operation.

pub fn builder( context: &'c Context, location: Location<'c> ) -> ScatterOperationBuilder<'c, Unset, Unset, Unset, Unset, Unset>

Creates a builder.

pub fn result(&self) -> Result<OperationResult<'c, '_>, Error>

pub fn source(&self) -> Result<Value<'c, '_>, Error>

pub fn dest(&self) -> Result<Value<'c, '_>, Error>

pub fn indices(&self) -> Result<Value<'c, '_>, Error>

pub fn scatter_dims(&self) -> Result<Attribute<'c>, Error>

pub fn set_scatter_dims(&mut self, value: Attribute<'c>)

pub fn unique(&self) -> Result<Attribute<'c>, Error>

pub fn set_unique(&mut self, value: Attribute<'c>)

pub fn remove_unique(&mut self) -> Result<(), Error>

Trait Implementations

impl<'c> From<ScatterOperation<'c>> for Operation<'c>

fn from(operation: ScatterOperation<'c>) -> Self

Converts to this type from the input type.

impl<'c> TryFrom<Operation<'c>> for ScatterOperation<'c>

type Error = Error

The type returned in the event of a conversion error.

fn try_from(operation: Operation<'c>) -> Result<Self, Self::Error>

Performs the conversion.

Auto Trait Implementations

impl<'c> RefUnwindSafe for ScatterOperation<'c>

impl<'c> !Send for ScatterOperation<'c>

impl<'c> !Sync for ScatterOperation<'c>

impl<'c> Unpin for ScatterOperation<'c>

impl<'c> UnwindSafe for ScatterOperation<'c>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.