use super::OperationPassManager;
use crate::{
context::Context, ir::Module, logical_result::LogicalResult, pass::Pass, string_ref::StringRef,
Error,
};
use mlir_sys::{
mlirPassManagerAddOwnedPass, mlirPassManagerCreate, mlirPassManagerDestroy,
mlirPassManagerEnableIRPrinting, mlirPassManagerEnableVerifier,
mlirPassManagerGetAsOpPassManager, mlirPassManagerGetNestedUnder, mlirPassManagerRunOnOp,
MlirPassManager,
};
use std::marker::PhantomData;
pub struct PassManager<'c> {
raw: MlirPassManager,
_context: PhantomData<&'c Context>,
}
impl<'c> PassManager<'c> {
pub fn new(context: &Context) -> Self {
Self {
raw: unsafe { mlirPassManagerCreate(context.to_raw()) },
_context: Default::default(),
}
}
pub fn nested_under(&self, name: &str) -> OperationPassManager {
let name = StringRef::new(name);
unsafe {
OperationPassManager::from_raw(mlirPassManagerGetNestedUnder(self.raw, name.to_raw()))
}
}
pub fn add_pass(&self, pass: Pass) {
unsafe { mlirPassManagerAddOwnedPass(self.raw, pass.to_raw()) }
}
pub fn enable_verifier(&self, enabled: bool) {
unsafe { mlirPassManagerEnableVerifier(self.raw, enabled) }
}
pub fn enable_ir_printing(&self) {
unsafe { mlirPassManagerEnableIRPrinting(self.raw) }
}
pub fn run(&self, module: &mut Module) -> Result<(), Error> {
let result = LogicalResult::from_raw(unsafe {
mlirPassManagerRunOnOp(self.raw, module.as_operation().to_raw())
});
if result.is_success() {
Ok(())
} else {
Err(Error::RunPass)
}
}
pub fn as_operation_pass_manager(&self) -> OperationPassManager {
unsafe { OperationPassManager::from_raw(mlirPassManagerGetAsOpPassManager(self.raw)) }
}
}
impl<'c> Drop for PassManager<'c> {
fn drop(&mut self) {
unsafe { mlirPassManagerDestroy(self.raw) }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
ir::{Location, Module},
pass::{self, transform::register_print_op_stats},
test::create_test_context,
utility::parse_pass_pipeline,
};
use indoc::indoc;
use pretty_assertions::assert_eq;
#[test]
fn new() {
let context = create_test_context();
PassManager::new(&context);
}
#[test]
fn add_pass() {
let context = create_test_context();
PassManager::new(&context).add_pass(pass::conversion::create_func_to_llvm());
}
#[test]
fn enable_verifier() {
let context = create_test_context();
PassManager::new(&context).enable_verifier(true);
}
#[test]
fn run() {
let context = create_test_context();
let manager = PassManager::new(&context);
manager.add_pass(pass::conversion::create_func_to_llvm());
manager
.run(&mut Module::new(Location::unknown(&context)))
.unwrap();
}
#[test]
fn run_on_function() {
let context = create_test_context();
let mut module = Module::parse(
&context,
indoc!(
"
func.func @foo(%arg0 : i32) -> i32 {
%res = arith.addi %arg0, %arg0 : i32
return %res : i32
}
"
),
)
.unwrap();
let manager = PassManager::new(&context);
manager.add_pass(pass::transform::create_print_op_stats());
assert_eq!(manager.run(&mut module), Ok(()));
}
#[test]
fn run_on_function_in_nested_module() {
let context = create_test_context();
let mut module = Module::parse(
&context,
indoc!(
"
func.func @foo(%arg0 : i32) -> i32 {
%res = arith.addi %arg0, %arg0 : i32
return %res : i32
}
module {
func.func @bar(%arg0 : f32) -> f32 {
%res = arith.addf %arg0, %arg0 : f32
return %res : f32
}
}
"
),
)
.unwrap();
let manager = PassManager::new(&context);
manager
.nested_under("func.func")
.add_pass(pass::transform::create_print_op_stats());
assert_eq!(manager.run(&mut module), Ok(()));
let manager = PassManager::new(&context);
manager
.nested_under("builtin.module")
.nested_under("func.func")
.add_pass(pass::transform::create_print_op_stats());
assert_eq!(manager.run(&mut module), Ok(()));
}
#[test]
fn print_pass_pipeline() {
let context = create_test_context();
let manager = PassManager::new(&context);
let function_manager = manager.nested_under("func.func");
function_manager.add_pass(pass::transform::create_print_op_stats());
assert_eq!(
manager.as_operation_pass_manager().to_string(),
"any(func.func(print-op-stats{json=false}))"
);
assert_eq!(
function_manager.to_string(),
"func.func(print-op-stats{json=false})"
);
}
#[test]
fn parse_pass_pipeline_() {
let context = Context::new();
let manager = PassManager::new(&context);
insta::assert_snapshot!(parse_pass_pipeline(
manager.as_operation_pass_manager(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
)
.unwrap_err());
register_print_op_stats();
assert_eq!(
parse_pass_pipeline(
manager.as_operation_pass_manager(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
),
Ok(())
);
assert_eq!(
manager.as_operation_pass_manager().to_string(),
"builtin.module(func.func(print-op-stats{json=false}),\
func.func(print-op-stats{json=false}))"
);
}
}