1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
use mlir_sys::{mlirStringRefEqual, MlirStringRef};
use std::{
    ffi::CStr,
    marker::PhantomData,
    slice,
    str::{self, Utf8Error},
};

/// A string reference.
#[derive(Clone, Copy, Debug)]
pub struct StringRef<'a> {
    raw: MlirStringRef,
    _parent: PhantomData<&'a str>,
}

impl<'a> StringRef<'a> {
    /// Creates a string reference.
    pub fn new(string: &'a str) -> Self {
        let string = MlirStringRef {
            data: string.as_bytes().as_ptr() as *const i8,
            length: string.len(),
        };

        unsafe { Self::from_raw(string) }
    }

    /// Converts a C-style string into a string reference.
    pub fn from_c_str(string: &'a CStr) -> Self {
        let string = MlirStringRef {
            data: string.as_ptr(),
            length: string.to_bytes_with_nul().len() - 1,
        };

        unsafe { Self::from_raw(string) }
    }

    /// Converts a string reference into a `str`.
    pub fn as_str(&self) -> Result<&'a str, Utf8Error> {
        unsafe {
            let bytes = slice::from_raw_parts(self.raw.data as *mut u8, self.raw.length);

            str::from_utf8(if bytes[bytes.len() - 1] == 0 {
                &bytes[..bytes.len() - 1]
            } else {
                bytes
            })
        }
    }

    /// Converts a string reference into a raw object.
    pub const fn to_raw(self) -> MlirStringRef {
        self.raw
    }

    /// Creates a string reference from a raw object.
    ///
    /// # Safety
    ///
    /// A raw object must be valid.
    pub unsafe fn from_raw(string: MlirStringRef) -> Self {
        Self {
            raw: string,
            _parent: Default::default(),
        }
    }
}

impl<'a> PartialEq for StringRef<'a> {
    fn eq(&self, other: &Self) -> bool {
        unsafe { mlirStringRefEqual(self.raw, other.raw) }
    }
}

impl<'a> Eq for StringRef<'a> {}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equal() {
        assert_eq!(StringRef::new("foo"), StringRef::new("foo"));
    }

    #[test]
    fn equal_str() {
        assert_eq!(StringRef::new("foo").as_str().unwrap(), "foo");
    }

    #[test]
    fn not_equal() {
        assert_ne!(StringRef::new("foo"), StringRef::new("bar"));
    }
}