pub struct ContractionOperation<'c> { /* private fields */ }

A vector.contract operation, which performs a vector contraction.

Computes the sum of products of vector elements along contracting dimension pairs from two vectors of rank M and N, respectively, adds this intermediate result to the accumulator argument of rank K, and returns a vector result of rank K, where K = num_lhs_free_dims + num_rhs_free_dims + num_batch_dims (the dimension types are described below). For K = 0 (no free or batch dimensions), the accumulator and output are scalars.
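For K = 0 the computation reduces to a dot product added to a scalar accumulator. A minimal Rust sketch of that arithmetic, assuming plain f32 slices stand in for the MLIR vectors; contract_dot is a hypothetical helper, not part of this crate's API:

```rust
/// Reference semantics of a K = 0 contraction (dot product) with the
/// default "add" kind: acc + sum_i lhs[i] * rhs[i].
/// `contract_dot` is a hypothetical helper used only for illustration.
fn contract_dot(lhs: &[f32], rhs: &[f32], acc: f32) -> f32 {
    // The single "reduction" dimension must have the same size in lhs and rhs.
    assert_eq!(lhs.len(), rhs.len(), "contracting dimension sizes must match");
    let sum: f32 = lhs.iter().zip(rhs).map(|(a, b)| a * b).sum();
    acc + sum
}
```

For example, contract_dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 10.0) returns 42.0: the products 4 + 10 + 18 sum to 32, and the accumulator adds 10.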

If the operands and the result have element types of different bitwidths, the operands are promoted to the bitwidth of the result before the contraction is performed. For integer types, only signless integers are supported, and the promotion is done by sign extension.
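As an illustration of the promotion rule, the sketch below contracts i8 operands into an i32 result, sign-extending each operand before multiplying. It is a hypothetical helper (contract_dot_i8_to_i32) that models only the K = 0 case; wrapping addition is used as an assumption to mirror fixed-width integer arithmetic rather than panicking on overflow in debug builds:

```rust
/// Sketch of operand promotion for a K = 0 integer contraction: i8 operands
/// are sign-extended to the i32 result width before multiplying.
/// `contract_dot_i8_to_i32` is a hypothetical helper for illustration.
fn contract_dot_i8_to_i32(lhs: &[i8], rhs: &[i8], acc: i32) -> i32 {
    assert_eq!(lhs.len(), rhs.len(), "contracting dimension sizes must match");
    lhs.iter().zip(rhs).fold(acc, |sum, (&a, &b)| {
        // `as i32` on an i8 is a sign extension in Rust: -1i8 becomes -1i32.
        // An i8 * i8 product always fits in i32, so only the sum can wrap.
        sum.wrapping_add((a as i32) * (b as i32))
    })
}
```

With lhs = [-1, 127], rhs = [-128, 2], and acc = 0 the result is 128 + 254 = 382; computed in i8 instead, both products would overflow.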

An iterator type attribute list must be specified, where each element of the list represents an iterator with one of the following types:

  • “reduction”: reduction dimensions are present in the lhs and rhs arguments but not in the output (or the accumulator argument). These are the dimensions along which the vector contraction op computes the sum of products, and the sizes of each contracting dimension pair must match between lhs and rhs.

  • “parallel”: batch dimensions have iterator type “parallel” and are non-contracting dimensions present in the lhs, rhs, and output. The lhs and rhs co-iterate along the batch dimensions, which should be expressed in their indexing maps.

    Free dimensions also have iterator type “parallel” and are non-contracting, non-batch dimensions accessed by either the lhs or the rhs (but not both). The lhs and rhs free dimensions are unrelated to each other and do not co-iterate, which should be expressed in their indexing maps.
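The dimension types above can be read off the indexing maps: an iterator used by lhs and rhs but not the accumulator is a reduction, one used by all three is a batch dimension, and one used by exactly one input plus the accumulator is free. A sketch of that classification, assuming each map is a projected permutation encoded as the list of iterator positions it selects (so (i, j, k) -> (i, k) becomes [0, 2]); DimKind, classify, and result_rank are hypothetical names:

```rust
/// Classification of an iterator (loop dimension) of a contraction.
#[derive(Debug, PartialEq, Clone, Copy)]
enum DimKind {
    Reduction, // in lhs and rhs, not in acc
    Batch,     // in lhs, rhs and acc
    LhsFree,   // in lhs and acc only
    RhsFree,   // in rhs and acc only
}

/// Classify each iterator from the indexing maps, where each map is the list
/// of iterator positions it selects. Hypothetical helper for illustration.
fn classify(num_iters: usize, lhs: &[usize], rhs: &[usize], acc: &[usize]) -> Vec<DimKind> {
    (0..num_iters)
        .map(|d| match (lhs.contains(&d), rhs.contains(&d), acc.contains(&d)) {
            (true, true, false) => DimKind::Reduction,
            (true, true, true) => DimKind::Batch,
            (true, false, true) => DimKind::LhsFree,
            (false, true, true) => DimKind::RhsFree,
            other => panic!("iterator {d} has an invalid access pattern {other:?}"),
        })
        .collect()
}

/// Result rank K = number of lhs free + rhs free + batch dimensions.
fn result_rank(kinds: &[DimKind]) -> usize {
    kinds.iter().filter(|k| **k != DimKind::Reduction).count()
}
```

For the 4D example further down, the maps (c0, b0, c1, f0), (b0, c1, c0, f1), and (b0, f0, f1) over (b0, f0, f1, c0, c1) become [3, 0, 4, 1], [0, 4, 3, 2], and [0, 1, 2]. They classify as one batch, one lhs-free, one rhs-free, and two reduction dimensions, so K = 3.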

An indexing map attribute list must be specified, with one entry each for the lhs, rhs, and acc arguments. Each indexing map specifies a mapping from the iterators in the iterator type list to the dimensions of the corresponding N-D vector.
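A naive evaluator makes the role of the indexing maps concrete: at every point of the iteration space, each map picks out the element of its operand that participates. This sketch assumes dense row-major f32 buffers and maps that are projected permutations, encoded as lists of iterator positions. Operand, apply_map, offset, and contract are hypothetical names, not this crate's API, and only the default “add” kind is modeled:

```rust
/// One operand: dense row-major data, its shape, and its indexing map given
/// as the iterator positions it selects, so `(i, j, k) -> (i, k)` is `[0, 2]`.
struct Operand<'a> {
    data: &'a [f32],
    shape: &'a [usize],
    map: &'a [usize],
}

/// Apply an indexing map to an iteration point.
fn apply_map(map: &[usize], point: &[usize]) -> Vec<usize> {
    map.iter().map(|&d| point[d]).collect()
}

/// Row-major linear offset of `index` in an array of shape `shape`
/// (0 for a rank-0 index, i.e. a scalar).
fn offset(index: &[usize], shape: &[usize]) -> usize {
    index.iter().zip(shape).fold(0, |off, (&i, &s)| off * s + i)
}

/// Naive contraction with the default "add" kind: visit every point of the
/// iteration space and accumulate lhs * rhs into the accumulator element
/// selected by the accumulator's map. `bounds` holds each iterator's size.
fn contract(bounds: &[usize], lhs: &Operand<'_>, rhs: &Operand<'_>, acc: &Operand<'_>) -> Vec<f32> {
    let mut out = acc.data.to_vec();
    let total: usize = bounds.iter().product();
    let mut point = vec![0usize; bounds.len()];
    for _ in 0..total {
        let l = lhs.data[offset(&apply_map(lhs.map, &point), lhs.shape)];
        let r = rhs.data[offset(&apply_map(rhs.map, &point), rhs.shape)];
        out[offset(&apply_map(acc.map, &point), acc.shape)] += l * r;
        // Advance `point` like an odometer: the last iterator varies fastest.
        for d in (0..bounds.len()).rev() {
            point[d] += 1;
            if point[d] < bounds[d] {
                break;
            }
            point[d] = 0;
        }
    }
    out
}
```

With lhs maps [0, 2], rhs [2, 1], and acc [0, 1] over bounds [2, 2, 3], this computes a 2x3 by 3x2 matrix product; with all three maps selecting a single reduction iterator and an empty acc map, it computes a dot product into a scalar.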

An optional kind attribute specifies the combining function between the intermediate result and the accumulator argument of rank K. It can take the values add/mul/min/max for integer and floating-point types, plus and/or/xor for integer types only. The default is “add”.
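The floating-point kinds can be sketched as an elementwise combine of the accumulator with the intermediate result. This models only the step the paragraph names, combining an already computed intermediate result with the accumulator; it does not model how a particular lowering applies the kind inside the reduction, nor NaN handling (Rust's f32::min and f32::max return the other operand when one is NaN). CombiningKind and combine are hypothetical names for illustration:

```rust
/// The kinds listed above that apply to floating-point operands (the
/// integer-only and/or/xor kinds are omitted). Hypothetical enum.
#[derive(Clone, Copy)]
enum CombiningKind {
    Add,
    Mul,
    Min,
    Max,
}

/// Combine the intermediate result with the accumulator, elementwise.
fn combine(kind: CombiningKind, acc: &[f32], intermediate: &[f32]) -> Vec<f32> {
    assert_eq!(acc.len(), intermediate.len());
    acc.iter()
        .zip(intermediate)
        .map(|(&a, &x)| match kind {
            CombiningKind::Add => a + x,
            CombiningKind::Mul => a * x,
            CombiningKind::Min => a.min(x),
            CombiningKind::Max => a.max(x),
        })
        .collect()
}
```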

Example:

// Simple DOT product (K = 0).
#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"]
}
%3 = vector.contract #contraction_trait %0, %1, %2
  : vector<10xf32>, vector<10xf32> into f32

// 2D vector contraction with one contracting dimension (matmul, K = 2).
#contraction_accesses = [
  affine_map<(i, j, k) -> (i, k)>,
  affine_map<(i, j, k) -> (k, j)>,
  affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]
}

%3 = vector.contract #contraction_trait %0, %1, %2
  : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
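The matmul case can be written as three explicit loops whose order matches the iterator list (i, j, k): i and j index the result, and k is the reduction. A sketch with fixed-size arrays standing in for the vectors; matmul_contract is a hypothetical helper, and const generics let the same function cover the 4x3 by 3x7 shapes above:

```rust
/// Reference semantics of the matmul contraction:
/// out[i][j] = acc[i][j] + sum_k lhs[i][k] * rhs[k][j].
/// Iterators (i, j, k) are "parallel", "parallel", "reduction".
fn matmul_contract<const M: usize, const K: usize, const N: usize>(
    lhs: &[[f32; K]; M],
    rhs: &[[f32; N]; K],
    acc: &[[f32; N]; M],
) -> [[f32; N]; M] {
    // Start from the accumulator, then add the sum of products.
    let mut out = *acc;
    for i in 0..M {
        for j in 0..N {
            for k in 0..K {
                out[i][j] += lhs[i][k] * rhs[k][j];
            }
        }
    }
    out
}
```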

// 4D to 3D vector contraction with two contracting dimensions and
// one batch dimension (K = 3).
#contraction_accesses = [
  affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
  affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
  affine_map<(b0, f0, f1, c0, c1) -> (b0, f0, f1)>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["parallel", "parallel", "parallel",
                    "reduction", "reduction"]
}

%4 = vector.contract #contraction_trait %0, %1, %2
    : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
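The shapes in this example are consistent because every iterator gets the same size from each operand that uses it. A sketch of that check, assuming maps encoded as lists of iterator positions; infer_bounds and result_shape are hypothetical helpers, not this crate's API:

```rust
/// Infer the size of every iterator from operand shapes and their indexing
/// maps (each map lists the iterator positions it selects). Returns None if
/// two operands disagree on an iterator's size or an iterator is never used.
fn infer_bounds(num_iters: usize, operands: &[(&[usize], &[usize])]) -> Option<Vec<usize>> {
    let mut bounds: Vec<Option<usize>> = vec![None; num_iters];
    for &(shape, map) in operands {
        for (&iter, &size) in map.iter().zip(shape) {
            let current: Option<usize> = bounds[iter];
            match current {
                None => bounds[iter] = Some(size),
                // Co-iterated dimensions (batch or contracting) must agree.
                Some(existing) if existing != size => return None,
                Some(_) => {}
            }
        }
    }
    // Yields None if some iterator was not accessed by any operand.
    bounds.into_iter().collect()
}

/// Shape of the accumulator/result implied by its indexing map.
fn result_shape(bounds: &[usize], acc_map: &[usize]) -> Vec<usize> {
    acc_map.iter().map(|&d| bounds[d]).collect()
}
```

With the maps above, the lhs and rhs shapes give bounds [8, 15, 5, 7, 16] for (b0, f0, f1, c0, c1), and projecting through the accumulator map yields [8, 15, 5], matching vector<8x15x5xf32>.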

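// Vector contraction with mixed types: lhs/rhs have a different element
// type than the accumulator/result. #contraction_trait here refers to the
// dot-product (K = 0) trait from the first example; the 4D trait above
// does not match these types.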
%5 = vector.contract #contraction_trait %0, %1, %2
  : vector<10xf16>, vector<10xf16> into f32

// Contract with max (K = 0).
#contraction_accesses = [
  affine_map<(i) -> (i)>,
  affine_map<(i) -> (i)>,
  affine_map<(i) -> ()>
]
#contraction_trait = {
  indexing_maps = #contraction_accesses,
  iterator_types = ["reduction"],
  kind = #vector.kind<max>
}
%6 = vector.contract #contraction_trait %0, %1, %2
  : vector<10xf32>, vector<10xf32> into f32

Implementations

impl<'c> ContractionOperation<'c>

pub fn name() -> &'static str

Returns a name.

pub fn as_operation(&self) -> &Operation<'c>

Returns a generic operation.

pub fn builder( context: &'c Context, location: Location<'c> ) -> ContractionOperationBuilder<'c, Unset, Unset, Unset, Unset, Unset>

Creates a builder.

pub fn lhs(&self) -> Result<Value<'c, '_>, Error>

pub fn rhs(&self) -> Result<Value<'c, '_>, Error>

pub fn acc(&self) -> Result<Value<'c, '_>, Error>

pub fn indexing_maps(&self) -> Result<ArrayAttribute<'c>, Error>

pub fn set_indexing_maps(&mut self, value: ArrayAttribute<'c>)

pub fn iterator_types(&self) -> Result<ArrayAttribute<'c>, Error>

pub fn set_iterator_types(&mut self, value: ArrayAttribute<'c>)

pub fn kind(&self) -> Result<Attribute<'c>, Error>

pub fn set_kind(&mut self, value: Attribute<'c>)

pub fn remove_kind(&mut self) -> Result<(), Error>

Trait Implementations

impl<'c> From<ContractionOperation<'c>> for Operation<'c>

fn from(operation: ContractionOperation<'c>) -> Self

Converts to this type from the input type.

impl<'c> TryFrom<Operation<'c>> for ContractionOperation<'c>

type Error = Error

The type returned in the event of a conversion error.

fn try_from(operation: Operation<'c>) -> Result<Self, Self::Error>

Performs the conversion.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.