use crate::{
diagnostic::{Diagnostic, DiagnosticHandlerId},
dialect::{Dialect, DialectRegistry},
logical_result::LogicalResult,
string_ref::StringRef,
};
use mlir_sys::{
mlirContextAppendDialectRegistry, mlirContextAttachDiagnosticHandler, mlirContextCreate,
mlirContextDestroy, mlirContextDetachDiagnosticHandler, mlirContextEnableMultithreading,
mlirContextEqual, mlirContextGetAllowUnregisteredDialects, mlirContextGetNumLoadedDialects,
mlirContextGetNumRegisteredDialects, mlirContextGetOrLoadDialect,
mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects,
mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult,
};
use std::{ffi::c_void, marker::PhantomData};
/// An owned MLIR context.
///
/// The wrapped `MlirContext` is destroyed (via `mlirContextDestroy`) when this value is
/// dropped. Borrowed, non-owning views are represented by [`ContextRef`].
#[derive(Debug)]
pub struct Context {
    // Handle obtained from `mlirContextCreate`; owned exclusively by this value.
    raw: MlirContext,
}
impl Context {
    /// Creates a new MLIR context.
    ///
    /// The context is destroyed when the returned value is dropped.
    pub fn new() -> Self {
        Self {
            // SAFETY: `mlirContextCreate` takes no inputs; the handle it returns is owned by the
            // new `Context` and released exactly once in its `Drop` impl.
            raw: unsafe { mlirContextCreate() },
        }
    }

    /// Returns the number of dialects registered with this context.
    pub fn registered_dialect_count(&self) -> usize {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextGetNumRegisteredDialects(self.raw) as usize }
    }

    /// Returns the number of dialects currently loaded in this context.
    pub fn loaded_dialect_count(&self) -> usize {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextGetNumLoadedDialects(self.raw) as usize }
    }

    /// Returns the dialect whose namespace is `name`, loading it into the context if needed.
    ///
    /// NOTE(review): the MLIR C API documents a null result when `name` is not registered
    /// with the context; confirm how `Dialect::from_raw` handles that case.
    pub fn get_or_load_dialect(&self, name: &str) -> Dialect {
        let name = StringRef::new(name);
        // SAFETY: `self.raw` is live and the string behind `name` is borrowed for the whole call.
        unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect(self.raw, name.to_raw())) }
    }

    /// Appends the contents of `registry` to this context's dialect registry.
    pub fn append_dialect_registry(&self, registry: &DialectRegistry) {
        // SAFETY: both handles belong to live Rust wrappers borrowed for the duration of the call.
        unsafe { mlirContextAppendDialectRegistry(self.raw, registry.to_raw()) }
    }

    /// Loads every dialect available to this context.
    pub fn load_all_available_dialects(&self) {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextLoadAllAvailableDialects(self.raw) }
    }

    /// Enables (`true`) or disables (`false`) multi-threading for this context.
    pub fn enable_multi_threading(&self, enabled: bool) {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextEnableMultithreading(self.raw, enabled) }
    }

    /// Returns whether operations from unregistered dialects are allowed.
    ///
    /// A freshly created context returns `false`.
    pub fn allow_unregistered_dialects(&self) -> bool {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextGetAllowUnregisteredDialects(self.raw) }
    }

    /// Sets whether operations from unregistered dialects are allowed.
    pub fn set_allow_unregistered_dialects(&self, allowed: bool) {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextSetAllowUnregisteredDialects(self.raw, allowed) }
    }

    /// Returns `true` if an operation with the fully qualified `name`
    /// (for example `"builtin.module"`) is registered in this context.
    pub fn is_registered_operation(&self, name: &str) -> bool {
        let name = StringRef::new(name);
        // SAFETY: `self.raw` is live and the string behind `name` is borrowed for the whole call.
        unsafe { mlirContextIsRegisteredOperation(self.raw, name.to_raw()) }
    }

    /// Returns the raw `MlirContext` handle.
    ///
    /// Ownership stays with `self`; the handle must not be destroyed by the caller.
    pub const fn to_raw(&self) -> MlirContext {
        self.raw
    }

    /// Attaches `handler` to this context and returns an id that can be passed to
    /// [`detach_diagnostic_handler`](Self::detach_diagnostic_handler).
    ///
    /// The handler receives each diagnostic that MLIR dispatches to it. Its `bool` return value
    /// is converted through [`LogicalResult`] and handed back to MLIR. The closure is boxed and
    /// ownership passes to MLIR, which frees it through the `destroy` callback. Per the MLIR C
    /// API, that happens once the handler is detached or the context is destroyed.
    ///
    /// A panic inside `handler` aborts the process instead of unwinding into MLIR.
    ///
    /// NOTE(review): `F` is not required to be `'static`, yet the boxed closure may be invoked
    /// for as long as it stays attached — potentially until the context is dropped. A closure
    /// that borrows shorter-lived data can therefore run after that borrow has ended. An
    /// `F: 'static` bound would close this hole but is a breaking API change.
    pub fn attach_diagnostic_handler<F: FnMut(Diagnostic) -> bool>(
        &self,
        handler: F,
    ) -> DiagnosticHandlerId {
        // Trampoline invoked by MLIR for each diagnostic; forwards it to the boxed `F`.
        unsafe extern "C" fn handle<F: FnMut(Diagnostic) -> bool>(
            diagnostic: MlirDiagnostic,
            user_data: *mut c_void,
        ) -> MlirLogicalResult {
            // SAFETY: `user_data` is the `Box<F>` leaked in `attach_diagnostic_handler`; it is
            // only freed by `destroy`.
            let handler = &mut *(user_data as *mut F);
            let diagnostic = Diagnostic::from_raw(diagnostic);

            // Unwinding out of an `extern "C"` function is undefined behavior before Rust 1.81
            // (and an abort from 1.81 on). Catch any panic here and abort explicitly so the
            // outcome is the same on every toolchain.
            let handled =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| handler(diagnostic)))
                    .unwrap_or_else(|_| std::process::abort());

            LogicalResult::from(handled).to_raw()
        }

        // Passed to MLIR as the user-data deleter; frees the boxed `F`.
        unsafe extern "C" fn destroy<F: FnMut(Diagnostic) -> bool>(user_data: *mut c_void) {
            // SAFETY: `user_data` came from `Box::into_raw(Box::<F>::new(..))` below. MLIR is
            // expected to call this at most once per attached handler.
            drop(Box::from_raw(user_data as *mut F));
        }

        let user_data = Box::into_raw(Box::new(handler)) as *mut c_void;

        // SAFETY: `handle::<F>` and `destroy::<F>` both reinterpret `user_data` as `*mut F`,
        // which matches the box created just above.
        unsafe {
            DiagnosticHandlerId::from_raw(mlirContextAttachDiagnosticHandler(
                self.to_raw(),
                Some(handle::<F>),
                user_data,
                Some(destroy::<F>),
            ))
        }
    }

    /// Detaches the diagnostic handler identified by `id`.
    pub fn detach_diagnostic_handler(&self, id: DiagnosticHandlerId) {
        // SAFETY: `self.raw` is the live context owned by `self`.
        unsafe { mlirContextDetachDiagnosticHandler(self.to_raw(), id.to_raw()) }
    }

    /// Borrows this context as a non-owning [`ContextRef`] tied to the lifetime of `&self`.
    pub(crate) fn to_ref(&self) -> ContextRef {
        // SAFETY: the returned reference borrows `self`, so the context outlives it.
        unsafe { ContextRef::from_raw(self.to_raw()) }
    }
}
impl Drop for Context {
    /// Releases the MLIR context owned by this value.
    fn drop(&mut self) {
        let owned = self.raw;
        // SAFETY: `Context` owns its handle, and this is the only place that destroys it.
        unsafe { mlirContextDestroy(owned) }
    }
}
impl Default for Context {
fn default() -> Self {
Self::new()
}
}
impl PartialEq for Context {
    /// Compares the two raw handles with `mlirContextEqual`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.raw, other.raw);
        // SAFETY: both handles are owned by live `Context` values.
        unsafe { mlirContextEqual(lhs, rhs) }
    }
}

impl PartialEq<ContextRef<'_>> for Context {
    /// Compares against a borrowed context by first viewing `self` as a [`ContextRef`].
    fn eq(&self, other: &ContextRef<'_>) -> bool {
        let this = self.to_ref();
        this == *other
    }
}

// Equality is delegated to `mlirContextEqual`, which is treated as a full equivalence.
impl Eq for Context {}
/// A non-owning, copyable reference to an MLIR context.
///
/// Being `Copy`, it has no destructor: dropping a `ContextRef` never destroys the context.
/// The lifetime `'c` records the borrow of the owning [`Context`] through
/// `PhantomData<&'c Context>`.
#[derive(Clone, Copy, Debug)]
pub struct ContextRef<'c> {
    // Raw handle of the referenced context; not owned by this value.
    raw: MlirContext,
    // Zero-sized marker so `ContextRef<'c>` is treated like `&'c Context` by the borrow checker.
    _reference: PhantomData<&'c Context>,
}
impl<'c> ContextRef<'c> {
pub unsafe fn from_raw(raw: MlirContext) -> Self {
Self {
raw,
_reference: Default::default(),
}
}
}
impl PartialEq for ContextRef<'_> {
    /// Compares the two raw handles with `mlirContextEqual`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.raw, other.raw);
        // SAFETY: per the `ContextRef` contract, both referenced contexts are still alive.
        unsafe { mlirContextEqual(lhs, rhs) }
    }
}

impl PartialEq<Context> for ContextRef<'_> {
    /// Compares against an owned context by borrowing it as a [`ContextRef`].
    fn eq(&self, other: &Context) -> bool {
        *self == other.to_ref()
    }
}

// Equality is delegated to `mlirContextEqual`, which is treated as a full equivalence.
impl Eq for ContextRef<'_> {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new() {
        drop(Context::new());
    }

    #[test]
    fn registered_dialect_count() {
        // A fresh context registers exactly one dialect.
        assert_eq!(Context::new().registered_dialect_count(), 1);
    }

    #[test]
    fn loaded_dialect_count() {
        // A fresh context has exactly one dialect loaded.
        assert_eq!(Context::new().loaded_dialect_count(), 1);
    }

    #[test]
    fn append_dialect_registry() {
        let ctx = Context::new();
        let registry = DialectRegistry::new();
        ctx.append_dialect_registry(&registry);
    }

    #[test]
    fn is_registered_operation() {
        let ctx = Context::new();
        let registered = ctx.is_registered_operation("builtin.module");
        assert!(registered);
    }

    #[test]
    fn is_not_registered_operation() {
        // `func.func` is not registered in a fresh context.
        let ctx = Context::new();
        let registered = ctx.is_registered_operation("func.func");
        assert!(!registered);
    }

    #[test]
    fn enable_multi_threading() {
        Context::new().enable_multi_threading(true);
    }

    #[test]
    fn disable_multi_threading() {
        Context::new().enable_multi_threading(false);
    }

    #[test]
    fn allow_unregistered_dialects() {
        // Unregistered dialects are rejected by default.
        assert!(!Context::new().allow_unregistered_dialects());
    }

    #[test]
    fn set_allow_unregistered_dialects() {
        let ctx = Context::new();
        ctx.set_allow_unregistered_dialects(true);
        let allowed = ctx.allow_unregistered_dialects();
        assert!(allowed);
    }

    #[test]
    fn attach_and_detach_diagnostic_handler() {
        let ctx = Context::new();
        let handler_id = ctx.attach_diagnostic_handler(|diag| {
            println!("{}", diag);
            true
        });
        ctx.detach_diagnostic_handler(handler_id);
    }

    #[test]
    fn compare_contexts() {
        let first = Context::new();
        let second = Context::new();

        assert_eq!(&first, &first);
        assert_ne!(&first, &second);
        assert_ne!(&second, &first);
        assert_eq!(&second, &second);
    }

    #[test]
    fn compare_context_refs() {
        let first = Context::new();
        let second = Context::new();
        let first_ref = first.to_ref();
        let second_ref = second.to_ref();

        // Each owned context equals its own borrowed view, in both directions.
        assert_eq!(&first, &first_ref);
        assert_eq!(&first_ref, &first);
        assert_eq!(&second, &second_ref);
        assert_eq!(&second_ref, &second);

        // Views of distinct contexts never compare equal, in either direction.
        assert_ne!(&first, &second_ref);
        assert_ne!(&second_ref, &first);
        assert_ne!(&second, &first_ref);
        assert_ne!(&first_ref, &second);
    }
}