pub struct BinaryOperation<'c> { /* private fields */ }

A binary operation. Binary set operation utilized within linalg.generic.

Defines a computation within a linalg.generic operation that takes two operands and executes one of the regions depending on whether both operands or either operand is nonzero (i.e. stored explicitly in the sparse storage format).

Three regions are defined for the operation and must appear in this order:

  • overlap (elements present in both sparse tensors)
  • left (elements only present in the left sparse tensor)
  • right (elements only present in the right sparse tensor)

Each region contains a single block describing the computation and result. Every non-empty block must end with a sparse_tensor.yield, and the yielded type must match the output type. The overlap region’s block has two arguments, while the left and right regions’ blocks each have one argument.

A region may also be declared empty (e.g. left={}), indicating that the region does not contribute to the output. For example, setting both left={} and right={} is equivalent to the intersection of the two inputs, as only the overlap region will contribute values to the output.

As a convenience, there is also a special token identity which can be used in place of the left or right region. This token indicates that the return value is the input value (i.e. func(%x) => return %x). As a practical example, setting left=identity and right=identity would be equivalent to a union operation where non-overlapping values in the inputs are copied to the output unchanged.
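The region rules above can be modelled in plain Rust. This is only a sketch under simplifying assumptions: sparse vectors are maps from index to f64, all element types are f64 (the real operation allows distinct input and output types), and the names SparseVec, OneSided and sparse_binary are hypothetical, not part of melior or MLIR.

```rust
use std::collections::BTreeMap;

/// Sparse vector model: only explicitly stored entries appear as keys.
pub type SparseVec = BTreeMap<usize, f64>;

/// Behaviour of the `left` or `right` region (illustrative names, not melior API).
pub enum OneSided {
    /// Models `left={}` / `right={}`: no output entry is produced.
    Empty,
    /// Models `left=identity` / `right=identity`: the value is copied through.
    Identity,
    /// Models a region with one block argument that yields a value.
    Apply(fn(f64) -> f64),
}

impl OneSided {
    /// Returns the yielded value, or `None` when the region is empty.
    fn eval(&self, value: f64) -> Option<f64> {
        match self {
            OneSided::Empty => None,
            OneSided::Identity => Some(value),
            OneSided::Apply(f) => Some(f(value)),
        }
    }
}

/// Combines two sparse vectors region by region.
/// `overlap` runs where both inputs store an entry (`None` models `overlap={}`);
/// `left` and `right` run where only one input stores an entry.
pub fn sparse_binary(
    a: &SparseVec,
    b: &SparseVec,
    overlap: Option<fn(f64, f64) -> f64>,
    left: &OneSided,
    right: &OneSided,
) -> SparseVec {
    let mut out = SparseVec::new();
    for (&i, &x) in a {
        let value = match b.get(&i) {
            Some(&y) => overlap.map(|f| f(x, y)), // present in both inputs
            None => left.eval(x),                 // present only in `a`
        };
        if let Some(v) = value {
            out.insert(i, v);
        }
    }
    for (&i, &y) in b {
        if !a.contains_key(&i) {
            // present only in `b`
            if let Some(v) = right.eval(y) {
                out.insert(i, v);
            }
        }
    }
    out
}
```

In this model, passing OneSided::Empty for both sides keeps only indices stored in both inputs (intersection), while passing OneSided::Identity for both sides copies the unmatched entries through (union).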

Due to the possibility of empty regions, i.e. lack of a value for certain cases, the result of this operation may only feed directly into the output of the linalg.generic operation or into a custom reduction sparse_tensor.reduce operation that follows in the same region.

Example of isEqual applied to intersecting elements only:

%C = bufferization.alloc_tensor...
%0 = linalg.generic #trait
  ins(%A: tensor<?xf64, #SparseVector>,
      %B: tensor<?xf64, #SparseVector>)
  outs(%C: tensor<?xi8, #SparseVector>) {
  ^bb0(%a: f64, %b: f64, %c: i8) :
    %result = sparse_tensor.binary %a, %b : f64, f64 to i8
      overlap={
        ^bb0(%arg0: f64, %arg1: f64):
          %cmp = arith.cmpf "oeq", %arg0, %arg1 : f64
          %ret_i8 = arith.extui %cmp : i1 to i8
          sparse_tensor.yield %ret_i8 : i8
      }
      left={}
      right={}
    linalg.yield %result : i8
} -> tensor<?xi8, #SparseVector>

Example of A+B in upper triangle, A-B in lower triangle:

%C = bufferization.alloc_tensor...
%1 = linalg.generic #trait
  ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64, #CSR>)
  outs(%C: tensor<?x?xf64, #CSR>) {
  ^bb0(%a: f64, %b: f64, %c: f64) :
    %row = linalg.index 0 : index
    %col = linalg.index 1 : index
    %result = sparse_tensor.binary %a, %b : f64, f64 to f64
      overlap={
        ^bb0(%x: f64, %y: f64):
          %cmp = arith.cmpi "uge", %col, %row : index
          %upperTriangleResult = arith.addf %x, %y : f64
          %lowerTriangleResult = arith.subf %x, %y : f64
          %ret = arith.select %cmp, %upperTriangleResult, %lowerTriangleResult : f64
          sparse_tensor.yield %ret : f64
      }
      left=identity
      right={
        ^bb0(%y: f64):
          %cmp = arith.cmpi "uge", %col, %row : index
          %lowerTriangleResult = arith.negf %y : f64
          %ret = arith.select %cmp, %y, %lowerTriangleResult : f64
          sparse_tensor.yield %ret : f64
      }
    linalg.yield %result : f64
} -> tensor<?x?xf64, #CSR>
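
To make the expected result of this example concrete, here is a small plain-Rust model of the same computation. It is a sketch only: the matrix is modelled as a map keyed by (row, col), and SparseMat and add_upper_sub_lower are hypothetical names introduced for illustration.

```rust
use std::collections::BTreeMap;

/// Sparse matrix model keyed by (row, col); only stored entries are present.
pub type SparseMat = BTreeMap<(usize, usize), f64>;

/// Computes A+B where col >= row (upper triangle, diagonal included)
/// and A-B elsewhere, following the three regions of the example.
pub fn add_upper_sub_lower(a: &SparseMat, b: &SparseMat) -> SparseMat {
    let mut out = SparseMat::new();
    for (&(row, col), &x) in a {
        let upper = col >= row; // mirrors the `uge` comparison of %col and %row
        let v = match b.get(&(row, col)) {
            // overlap region: select between x + y and x - y
            Some(&y) => {
                if upper {
                    x + y
                } else {
                    x - y
                }
            }
            // left=identity: the value of A is copied unchanged
            None => x,
        };
        out.insert((row, col), v);
    }
    for (&(row, col), &y) in b {
        if !a.contains_key(&(row, col)) {
            // right region: y in the upper triangle, -y in the lower triangle
            let v = if col >= row { y } else { -y };
            out.insert((row, col), v);
        }
    }
    out
}
```

An entry stored only in B at (2, 0) lies in the lower triangle, so the model yields its negation, which matches A-B when A has no stored value there.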

Example of set difference. Returns a copy of A restricted to the entries whose positions are not also stored in B. The element type of B can differ from that of A because its values are never used, only its sparse structure:

%C = bufferization.alloc_tensor...
%2 = linalg.generic #trait
  ins(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xi32, #CSR>)
  outs(%C: tensor<?x?xf64, #CSR>) {
  ^bb0(%a: f64, %b: i32, %c: f64) :
    %result = sparse_tensor.binary %a, %b : f64, i32 to f64
      overlap={}
      left=identity
      right={}
    linalg.yield %result : f64
} -> tensor<?x?xf64, #CSR>

Implementations

impl<'c> BinaryOperation<'c>

pub fn name() -> &'static str

Returns a name.

pub fn as_operation(&self) -> &Operation<'c>

Returns a generic operation.

pub fn builder(context: &'c Context, location: Location<'c>) -> BinaryOperationBuilder<'c, Unset, Unset, Unset, Unset, Unset, Unset>

Creates a builder.

pub fn output(&self) -> Result<OperationResult<'c, '_>, Error>

pub fn x(&self) -> Result<Value<'c, '_>, Error>

pub fn y(&self) -> Result<Value<'c, '_>, Error>

pub fn overlap_region(&self) -> Result<RegionRef<'c, '_>, Error>

pub fn left_region(&self) -> Result<RegionRef<'c, '_>, Error>

pub fn right_region(&self) -> Result<RegionRef<'c, '_>, Error>

pub fn left_identity(&self) -> Result<Attribute<'c>, Error>

pub fn set_left_identity(&mut self, value: Attribute<'c>)

pub fn remove_left_identity(&mut self) -> Result<(), Error>

pub fn right_identity(&self) -> Result<Attribute<'c>, Error>

pub fn set_right_identity(&mut self, value: Attribute<'c>)

pub fn remove_right_identity(&mut self) -> Result<(), Error>

Trait Implementations

impl<'c> From<BinaryOperation<'c>> for Operation<'c>

fn from(operation: BinaryOperation<'c>) -> Self

Converts to this type from the input type.

impl<'c> TryFrom<Operation<'c>> for BinaryOperation<'c>

type Error = Error

The type returned in the event of a conversion error.

fn try_from(operation: Operation<'c>) -> Result<Self, Self::Error>

Performs the conversion.
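
The From/TryFrom pair above follows a common wrapper pattern: widening to the generic operation always succeeds, while narrowing can fail. The sketch below illustrates that pattern with stand-in types only. GenericOp, TypedBinary and WrongOperation are hypothetical, and the check on the operation name (assumed here to be "sparse_tensor.binary") is an assumption about how such a conversion might be validated, not a statement of melior's implementation.

```rust
use std::convert::TryFrom;

/// Stand-in for a generic operation; only the name matters in this sketch.
#[derive(Debug, Clone, PartialEq)]
pub struct GenericOp {
    pub name: String,
}

/// Stand-in for a typed wrapper such as `BinaryOperation`.
#[derive(Debug, PartialEq)]
pub struct TypedBinary {
    operation: GenericOp,
}

/// Hypothetical error describing a failed narrowing conversion.
#[derive(Debug, PartialEq)]
pub struct WrongOperation {
    pub expected: &'static str,
    pub actual: String,
}

impl TypedBinary {
    /// Assumed operation name used for the check below.
    pub fn name() -> &'static str {
        "sparse_tensor.binary"
    }

    /// Borrows the wrapped generic operation.
    pub fn as_operation(&self) -> &GenericOp {
        &self.operation
    }
}

impl TryFrom<GenericOp> for TypedBinary {
    type Error = WrongOperation;

    /// Narrowing: succeeds only when the generic operation has the expected name.
    fn try_from(operation: GenericOp) -> Result<Self, Self::Error> {
        if operation.name == Self::name() {
            Ok(Self { operation })
        } else {
            Err(WrongOperation {
                expected: Self::name(),
                actual: operation.name,
            })
        }
    }
}

impl From<TypedBinary> for GenericOp {
    /// Widening: always succeeds and returns the wrapped operation.
    fn from(op: TypedBinary) -> Self {
        op.operation
    }
}
```

With this shape, TypedBinary::try_from on an operation named "arith.addf" returns an error that carries both the expected and the actual name, while converting a TypedBinary back with GenericOp::from cannot fail.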

Auto Trait Implementations

impl<'c> RefUnwindSafe for BinaryOperation<'c>

impl<'c> !Send for BinaryOperation<'c>

impl<'c> !Sync for BinaryOperation<'c>

impl<'c> Unpin for BinaryOperation<'c>

impl<'c> UnwindSafe for BinaryOperation<'c>

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.