use super::{operation::OperationRefMut, BlockRef, Location, Operation, OperationRef};
use crate::{
context::{Context, ContextRef},
string_ref::StringRef,
};
use mlir_sys::{
mlirModuleCreateEmpty, mlirModuleCreateParse, mlirModuleDestroy, mlirModuleFromOperation,
mlirModuleGetBody, mlirModuleGetContext, mlirModuleGetOperation, MlirModule,
};
use std::{ffi::CString, marker::PhantomData};
#[derive(Debug)]
pub struct Module<'c> {
raw: MlirModule,
_context: PhantomData<&'c Context>,
}
impl<'c> Module<'c> {
pub fn new(location: Location) -> Self {
unsafe { Self::from_raw(mlirModuleCreateEmpty(location.to_raw())) }
}
pub fn parse(context: &Context, source: &str) -> Option<Self> {
let source = CString::new(source).unwrap();
let source = StringRef::from_c_str(&source);
unsafe { Self::from_option_raw(mlirModuleCreateParse(context.to_raw(), source.to_raw())) }
}
pub fn as_operation(&self) -> OperationRef<'c, '_> {
unsafe { OperationRef::from_raw(mlirModuleGetOperation(self.raw)) }
}
pub fn as_operation_mut(&mut self) -> OperationRefMut<'c, '_> {
unsafe { OperationRefMut::from_raw(mlirModuleGetOperation(self.raw)) }
}
pub fn context(&self) -> ContextRef<'c> {
unsafe { ContextRef::from_raw(mlirModuleGetContext(self.raw)) }
}
pub fn body(&self) -> BlockRef<'c, '_> {
unsafe { BlockRef::from_raw(mlirModuleGetBody(self.raw)) }
}
pub fn from_operation(operation: Operation) -> Option<Self> {
unsafe { Self::from_option_raw(mlirModuleFromOperation(operation.into_raw())) }
}
pub unsafe fn from_raw(raw: MlirModule) -> Self {
Self {
raw,
_context: Default::default(),
}
}
pub unsafe fn from_option_raw(raw: MlirModule) -> Option<Self> {
if raw.ptr.is_null() {
None
} else {
Some(Self::from_raw(raw))
}
}
pub const fn to_raw(&self) -> MlirModule {
self.raw
}
}
impl<'c> Drop for Module<'c> {
fn drop(&mut self) {
unsafe { mlirModuleDestroy(self.raw) };
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
ir::{attribute::StringAttribute, operation::OperationBuilder, Block, Region},
test::create_test_context,
};
#[test]
fn new() {
Module::new(Location::new(&Context::new(), "foo", 42, 42));
}
#[test]
fn context() {
Module::new(Location::new(&Context::new(), "foo", 42, 42)).context();
}
#[test]
fn parse() {
assert!(Module::parse(&Context::new(), "module{}").is_some());
}
#[test]
fn parse_none() {
assert!(Module::parse(&Context::new(), "module{").is_none());
}
#[test]
fn from_operation() {
let context = create_test_context();
let region = Region::new();
region.append_block(Block::new(&[]));
let module = Module::from_operation(
OperationBuilder::new("builtin.module", Location::unknown(&context))
.add_regions([region])
.build()
.unwrap(),
)
.unwrap();
assert!(module.as_operation().verify());
assert_eq!(module.as_operation().to_string(), "module {\n}\n")
}
#[test]
fn from_operation_fail() {
let context = create_test_context();
assert!(Module::from_operation(
OperationBuilder::new("func.func", Location::unknown(&context),)
.build()
.unwrap()
)
.is_none());
}
#[test]
fn set_attribute() {
let context = create_test_context();
let mut module = Module::new(Location::unknown(&context));
module
.as_operation_mut()
.set_attribute("sym_name", StringAttribute::new(&context, "foo").into());
assert!(module.as_operation().verify());
}
}