pub struct ScatterOperation<'c> { /* private fields */ }
A scatter operation. Scatter a tensor into a destination tensor at specified indices.
The scatter operation inserts a source tensor into a dest tensor at the given indices.
In its most general form, the tensor of indices specifies all the coordinates
of every element to insert (i.e. COO format, without the payload).
The indices are expected to be confined to coordinate values that fit the
range of the dest tensor; otherwise the behavior is undefined.
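Because out-of-range coordinates are undefined behavior, a caller may want to validate them before constructing the op. The following is a minimal sketch in plain Rust, independent of this crate; `coords_in_bounds` and its parameters are hypothetical names introduced only for illustration.

```rust
/// Hypothetical helper, not part of this crate or of MLIR: returns true when
/// every coordinate tuple addresses a valid position along the scattered dims.
///
/// `coords` holds one tuple per scattered slice; tuple entry `k` indexes
/// dimension `scatter_dims[k]` of a dest tensor with shape `dest_shape`.
fn coords_in_bounds(dest_shape: &[usize], scatter_dims: &[usize], coords: &[Vec<i64>]) -> bool {
    coords.iter().all(|tuple| {
        // Each tuple must supply exactly one coordinate per scattered dim.
        tuple.len() == scatter_dims.len()
            && tuple.iter().zip(scatter_dims).all(|(&c, &d)| {
                // The dim must exist, and the coordinate must lie in [0, size).
                d < dest_shape.len() && c >= 0 && (c as u64) < dest_shape[d] as u64
            })
    })
}
```

For a 4x4x4 dest with scatter_dims [0, 1, 2], the tuple [0, 3, 1] passes and [0, 4, 1] is rejected.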
The leading dimensions of the index tensor must match those of the dest
tensor. The trailing dimensions of the dest tensor must match those of the
source tensor by omitting the dimensions specified in scatter_dims
(rank-reducing semantics) or setting them to 1 (rank-preserving semantics)
(see examples).
This convention allows an idiomatic specification and lowering of
“scattering multiple N-D slices into the dest tensor”.
The result type must match the type of the dest tensor.
Note: in the examples below, we separate the indexing part of the tensor type from the rest with whitespace for readability.
Example:
// For each 1x2 triple of coordinates in %indices, insert the
// element (i.e. 0-D subset) at the coordinates triple in %dest.
//
%out = tensor.scatter %source into %dest[%indices]
scatter_dims([0, 1, 2]) unique :
(tensor<1x2x 1x1x1xf32>, tensor<4x4x4xf32>, tensor<1x2x 3xindex>)
-> tensor<4x4x4xf32>
// Note: source type may be further rank-reduced to tensor<1x2x f32>.
A slice variant is provided to allow specifying insertion of whole tensor
slices into the dest tensor.
Example:
// For each 3 singleton of coordinates in %indices, insert the 2-D
// slice into %dest[*, %indices[...]:%indices[...] + 1, *] with the
// indices corresponding to the scatter_dims attribute specified by
// %indices.
//
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1]) unique :
(tensor<3x 4x1x6xf32>, tensor<4x5x6xf32>, tensor<3x 1xindex>)
-> tensor<4x5x6xf32>
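The semantics of the slice example can be sketched as a reference loop over flat row-major buffers. This is an illustration under stated assumptions, not how MLIR lowers the op: `scatter_dim1` is a hypothetical helper, it handles only a 3-D dest with scatter_dims([1]) in rank-preserving form, and it panics on out-of-range coordinates where the op itself leaves behavior undefined.

```rust
/// Hypothetical reference implementation for illustration only.
///
/// dest:    shape [d0, d1, d2]
/// source:  shape [n, d0, 1, d2]  (rank-preserving form, scatter_dims = [1])
/// indices: shape [n, 1], one coordinate along dimension 1 per slice
fn scatter_dim1(
    dest: &[f32],
    dest_shape: [usize; 3],
    source: &[f32],
    indices: &[usize],
) -> Vec<f32> {
    let [d0, d1, d2] = dest_shape;
    assert_eq!(dest.len(), d0 * d1 * d2);
    assert_eq!(source.len(), indices.len() * d0 * d2);

    // Tensor semantics: the result is a new value, so start from a copy of dest.
    let mut out = dest.to_vec();
    for (n, &j) in indices.iter().enumerate() {
        assert!(j < d1, "coordinate out of range for dimension 1");
        for a in 0..d0 {
            for c in 0..d2 {
                // source[n, a, 0, c] is written to out[a, j, c].
                out[(a * d1 + j) * d2 + c] = source[(n * d0 + a) * d2 + c];
            }
        }
    }
    out
}
```

With the shapes from the example, `dest_shape` would be `[4, 5, 6]`, `source` would hold 3 * 4 * 6 elements, and `indices` would hold 3 coordinates.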
The dimensions specified in the scatter_dims attribute are ones for which the
source tensor has size 1.
I.e. if the dest type is axbxcxd and the coordinates are [1, 3], then
the source type suffix is ax1xcx1.
Scatter also allows rank-reducing semantics where the shape ax1xcx1 can be
further simplified to axc.
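The shape rule above can be written down as a small function. This is a sketch in plain Rust; `source_suffix` is a hypothetical helper, not an API of this crate, and it computes only the trailing (non-index) part of the source shape.

```rust
/// Hypothetical helper: derives the trailing part of the source shape from
/// the dest shape and `scatter_dims`.
fn source_suffix(dest_shape: &[usize], scatter_dims: &[usize], rank_reducing: bool) -> Vec<usize> {
    dest_shape
        .iter()
        .enumerate()
        .filter_map(|(dim, &size)| {
            if !scatter_dims.contains(&dim) {
                Some(size) // not scattered: the full slice is inserted
            } else if rank_reducing {
                None // scattered dim dropped (rank-reducing semantics)
            } else {
                Some(1) // scattered dim kept with size 1 (rank-preserving semantics)
            }
        })
        .collect()
}
```

For a 4x5x6 dest with scatter_dims [1], this yields 4x1x6 in rank-preserving form and 4x6 in rank-reducing form, matching the source suffix in the slice example.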
The element type of the indices tensor can be any integer type.
In the absence of target-specific or problem-specific information, the default
type one should use is index.
This operation does not support unranked tensors.
A unique unit attribute must be specified to indicate that the
coordinates are statically guaranteed to be unique at runtime. If coordinates
are not truly unique at runtime, the behavior is undefined.
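Since the op does not check uniqueness itself, a debug-time check on concrete coordinate data can catch violations early. A minimal sketch, assuming the coordinates are available on the host as a list of tuples; `coords_are_unique` is a hypothetical helper.

```rust
use std::collections::HashSet;

/// Hypothetical debug-time check: true when no coordinate tuple repeats.
fn coords_are_unique(coords: &[Vec<i64>]) -> bool {
    let mut seen = HashSet::new();
    // `insert` returns false when the tuple was already present.
    coords.iter().all(|tuple| seen.insert(tuple.as_slice()))
}
```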
Only full slices are meant to be supported by this op. If one desires partial slices (e.g. strided windows), one should compose this op with other tensor ops (e.g. tensor.insert_slice). This is to avoid a slippery slope of complexity that would make the op unusable in practice.
At the tensor level, the index tensor is specified in an AoS form (i.e. the coordinate tuple is the most minor dimension). It is the responsibility of further lowerings and bufferization to implement various concrete layouts.
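To make the AoS convention concrete, the sketch below reads an index tensor of shape [n, k] stored as a flat row-major buffer, where the k entries of each coordinate tuple are adjacent, and regroups it into an SoA layout of the kind a later lowering might choose. `aos_to_soa` is a hypothetical helper for illustration only.

```rust
/// Hypothetical helper: converts an AoS index buffer of shape [n, k]
/// (tuple entries adjacent in memory) into SoA form, i.e. k arrays holding
/// n coordinates each.
fn aos_to_soa(flat: &[i64], k: usize) -> Vec<Vec<i64>> {
    assert!(k > 0 && flat.len() % k == 0);
    (0..k)
        // For each tuple position, take every k-th element starting there.
        .map(|dim| flat.iter().skip(dim).step_by(k).copied().collect())
        .collect()
}
```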
Note: As currently specified, the operation must lower to an abstraction that performs copies to the output tensor. This is because the buffer type system is currently not rich enough to allow multiple non-contiguous views in the same type. This is visible more clearly in a notional buffer version of the op:
// memref<?x 4xf32> is a contiguous buffer of ?x4 elements, scatter into
// random dest slices must copy to the contiguous dest.
//
some_side_effecting_op_writing_into %source, ...: memref<3x 4xf32>
memref.scatter %source into %dest[%indices] scatter_dims([1]) unique :
(memref<3x 4xf32>, memref<?x 4xf32>, memref<?x 1xindex>)
// Nested buffer support in the producing op would allow writing directly
// into the dest buffer.
%v = some_nested_buffer_view_op %dest[%indices] scatter_dims([1]) unique :
memref<? x memref<4xf32>>
some_side_effecting_op_writing_into %v, ...: memref<? x memref<4xf32>>
Implementations
impl<'c> ScatterOperation<'c>
pub fn as_operation(&self) -> &Operation<'c>
Returns a generic operation.
pub fn builder( context: &'c Context, location: Location<'c> ) -> ScatterOperationBuilder<'c, Unset, Unset, Unset, Unset, Unset>
Creates a builder.