Skip to content

CStr8 cleanup and enhancements #506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
- Implemented `core::fmt::Write` for the `Serial` protocol.
- Added the `MemoryProtection` protocol.
- Added `BootServices::get_handle_for_protocol`.
- Added trait `EqStrUntilNul` and implemented it for `CStr16` and `CString16`.
Now you can compare everything that is `AsRef<str>` (such as `String` and `str`
from the standard library) to uefi strings. Please head to the documentation of
`EqStrUntilNul` to find out limitations and further information.
- Added trait `EqStrUntilNul` and implemented it for `CStr8`, `CStr16`, and `CString16`
(CString8 doesn't exist yet). Now you can compare everything that is `AsRef<str>`
(such as `String` and `str` from the standard library) to UEFI strings. Please head to the
documentation of `EqStrUntilNul` to find out limitations and further information.
- Added `BootServices::image_handle` to get the handle of the executing
image. The image is set automatically by the `#[entry]` macro; if a
program does not use that macro then it should call
Expand All @@ -27,6 +27,7 @@
- Added `DiskIo` and `DiskIo2` protocols.
- Added `HardDriveMediaDevicePath` and related types.
- Added `PartialOrd` and `Ord` to the traits derived by `Guid`.
- Added `TryFrom<core::ffi::CStr>` implementation for `CStr8`.

### Fixed

Expand Down
151 changes: 114 additions & 37 deletions src/data_types/strs.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::chars::{Char16, Char8, NUL_16, NUL_8};
use core::ffi::CStr;
use core::fmt;
use core::iter::Iterator;
use core::marker::PhantomData;
Expand Down Expand Up @@ -52,21 +53,30 @@ pub enum FromStrWithBufError {
BufferTooSmall,
}

/// A Latin-1 null-terminated string
/// A null-terminated Latin-1 string.
///
/// This type is largely inspired by `std::ffi::CStr`, see the documentation of
/// `CStr` for more details on its semantics.
/// This type is largely inspired by [`core::ffi::CStr`] with the exception that all characters are
/// guaranteed to be 8 bit long.
///
/// A [`CStr8`] can be constructed from a [`core::ffi::CStr`] via a `try_from` call:
/// ```ignore
/// let cstr8: &CStr8 = TryFrom::try_from(cstr).unwrap();
/// ```
///
/// For convenience, a [`CStr8`] is comparable with [`core::str`] and
/// `alloc::string::String` from the standard library through the trait [`EqStrUntilNul`].
#[repr(transparent)]
#[derive(Eq, PartialEq)]
pub struct CStr8([Char8]);

impl CStr8 {
/// Wraps a raw UEFI string with a safe C string wrapper
/// Takes a raw pointer to a null-terminated Latin-1 string and wraps it in a CStr8 reference.
///
/// # Safety
///
/// The function will start accessing memory from `ptr` until the first
/// null byte. It's the callers responsability to ensure `ptr` points to
/// a valid string, in accessible memory.
/// null byte. It's the caller's responsibility to ensure `ptr` points to
/// a valid null-terminated string in accessible memory.
pub unsafe fn from_ptr<'ptr>(ptr: *const Char8) -> &'ptr Self {
let mut len = 0;
while *ptr.add(len) != NUL_8 {
Expand All @@ -76,7 +86,7 @@ impl CStr8 {
Self::from_bytes_with_nul_unchecked(slice::from_raw_parts(ptr, len + 1))
}

/// Creates a C string wrapper from bytes
/// Creates a CStr8 reference from bytes.
pub fn from_bytes_with_nul(chars: &[u8]) -> Result<&Self, FromSliceWithNulError> {
let nul_pos = chars.iter().position(|&c| c == 0);
if let Some(nul_pos) = nul_pos {
Expand All @@ -89,40 +99,82 @@ impl CStr8 {
}
}

/// Unsafely creates a C string wrapper from bytes
/// Unsafely creates a CStr8 reference from bytes.
///
/// # Safety
///
/// It's the callers responsability to ensure chars is a valid Latin-1
/// It's the caller's responsibility to ensure chars is a valid Latin-1
/// null-terminated string, with no interior null bytes.
pub unsafe fn from_bytes_with_nul_unchecked(chars: &[u8]) -> &Self {
// SAFETY: `CStr8` is `#[repr(transparent)]` over `[Char8]`, so the slice pointer can be
// reinterpreted as a pointer to `Self`. This assumes `Char8` is layout-compatible with `u8`
// (TODO confirm in the `chars` module). No validation happens here: the caller must uphold
// the null-termination contract.
&*(chars as *const [u8] as *const Self)
}

/// Returns the inner pointer to this C string
/// Returns the inner pointer to this CStr8.
pub fn as_ptr(&self) -> *const Char8 {
    // Hand out the address of the first character of the backing slice.
    let chars: &[Char8] = &self.0;
    chars.as_ptr()
}

/// Converts this C string to a slice of bytes
/// Converts this CStr8 to a slice of bytes without the terminating null byte.
pub fn to_bytes(&self) -> &[u8] {
    // The full byte view ends with the terminator; cut off exactly that last byte.
    let with_nul = self.to_bytes_with_nul();
    let end = with_nul.len() - 1;
    &with_nul[..end]
}

/// Converts this C string to a slice of bytes containing the trailing 0 char
/// Converts this CStr8 to a slice of bytes containing the trailing null byte.
pub fn to_bytes_with_nul(&self) -> &[u8] {
// SAFETY: reinterprets the `[Char8]` storage as `[u8]` in place, without copying. This
// assumes `Char8` is layout-compatible with `u8` (TODO confirm in the `chars` module).
unsafe { &*(&self.0 as *const [Char8] as *const [u8]) }
}
}

/// An UCS-2 null-terminated string
impl fmt::Debug for CStr8 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "CStr8({:?})", &self.0)
}
}

impl fmt::Display for CStr8 {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
for c in self.0.iter() {
<Char8 as fmt::Display>::fmt(c, f)?;
}
Ok(())
}
}

impl<StrType: AsRef<str>> EqStrUntilNul<StrType> for CStr8 {
    /// Compares this string character by character with `other`.
    ///
    /// The comparison walks both sequences in lockstep and ends as soon as either side
    /// yields a null character or runs out of characters. It returns `true` when no
    /// mismatch was seen up to that point.
    fn eq_str_until_nul(&self, other: &StrType) -> bool {
        let rhs = other.as_ref().chars();

        // TODO: CStr16 has .iter() implemented, CStr8 not yet
        let lhs = self.0.iter().map(|&c| char::from(c));

        // Pairing the characters one-to-one only works as CStr8 is guaranteed to have a
        // fixed character length.
        lhs.zip(rhs)
            .take_while(|&(l, r)| l != '\0' && r != '\0')
            .all(|(l, r)| l == r)
    }
}

impl<'a> TryFrom<&'a CStr> for &'a CStr8 {
    type Error = FromSliceWithNulError;

    /// Borrows the bytes of `cstr` (terminator included) and validates them through
    /// [`CStr8::from_bytes_with_nul`]; no data is copied.
    fn try_from(cstr: &'a CStr) -> Result<Self, Self::Error> {
        let bytes = cstr.to_bytes_with_nul();
        CStr8::from_bytes_with_nul(bytes)
    }
}

/// An UCS-2 null-terminated string.
///
/// This type is largely inspired by `std::ffi::CStr`, see the documentation of
/// `CStr` for more details on its semantics.
/// This type is largely inspired by [`core::ffi::CStr`] with the exception that all characters are
/// guaranteed to be 16 bit long.
///
/// For convenience, a [CStr16] is comparable with `&str` and `String` from the standard library
/// through the trait [EqStrUntilNul].
/// For convenience, a [`CStr16`] is comparable with [`core::str`] and
/// `alloc::string::String` from the standard library through the trait [`EqStrUntilNul`].
#[derive(Eq, PartialEq)]
#[repr(transparent)]
pub struct CStr16([Char16]);
Expand All @@ -133,7 +185,7 @@ impl CStr16 {
/// # Safety
///
/// The function will start accessing memory from `ptr` until the first
/// null byte. It's the callers responsability to ensure `ptr` points to
/// null byte. It's the caller's responsibility to ensure `ptr` points to
/// a valid string, in accessible memory.
pub unsafe fn from_ptr<'ptr>(ptr: *const Char16) -> &'ptr Self {
let mut len = 0;
Expand Down Expand Up @@ -171,7 +223,7 @@ impl CStr16 {
///
/// # Safety
///
/// It's the callers responsability to ensure chars is a valid UCS-2
/// It's the caller's responsibility to ensure chars is a valid UCS-2
/// null-terminated string, with no interior null bytes.
pub unsafe fn from_u16_with_nul_unchecked(codes: &[u16]) -> &Self {
&*(codes as *const [u16] as *const Self)
Expand Down Expand Up @@ -257,7 +309,7 @@ impl CStr16 {
self.0.len() * 2
}

/// Writes each [`Char16`] as a [´char´] (4 bytes long in Rust language) into the buffer.
/// Writes each [`Char16`] as a [`char`] (4 bytes long in Rust language) into the buffer.
/// It is up to the implementer of [`core::fmt::Write`] to convert the char to a string
/// with proper encoding/charset. For example, in the case of [`alloc::string::String`]
/// all Rust chars (UTF-32) get converted to UTF-8.
Expand Down Expand Up @@ -290,6 +342,7 @@ impl<StrType: AsRef<str>> EqStrUntilNul<StrType> for CStr16 {
.copied()
.map(char::from)
.zip(other.chars())
// this only works as CStr16 is guaranteed to have a fixed character length
.take_while(|(l, r)| *l != '\0' && *r != '\0')
.any(|(l, r)| l != r);

Expand Down Expand Up @@ -478,7 +531,17 @@ where
mod tests {
use super::*;
use crate::alloc_api::string::String;
use uefi_macros::cstr16;
use uefi_macros::{cstr16, cstr8};

// Tests if our CStr8 type can be constructed from a valid core::ffi::CStr
#[test]
fn test_cstr8_from_cstr() {
    let msg = "hello world\0";
    // Build the CStr through the checked constructor: it verifies that the only null
    // byte is the final one, so no raw pointer or `unsafe` block is needed here.
    let cstr = CStr::from_bytes_with_nul(msg.as_bytes()).unwrap();
    let cstr8: &CStr8 = TryFrom::try_from(cstr).unwrap();
    // The comparison must hold in both directions.
    assert!(cstr8.eq_str_until_nul(&msg));
    assert!(msg.eq_str_until_nul(cstr8));
}

#[test]
fn test_cstr16_num_bytes() {
Expand Down Expand Up @@ -565,23 +628,37 @@ mod tests {
);
}

#[test]
fn test_compare() {
let input: &CStr16 = cstr16!("test");
// Code generation helper for the compare tests of our CStrX types against "str" and "String"
// from the standard library.
//
// `$input` is an identifier bound to a CStrX reference that holds the text "test". The macro
// expands to a series of `eq_str_until_nul` assertions against that value.
// The `X` in the macro name is a placeholder for the character width (8 or 16); presumably
// that is why the naming lint is silenced below.
#[allow(non_snake_case)]
macro_rules! test_compare_cstrX {
($input:ident) => {
// UEFI string on the left-hand side, compared against `&str` and `String`.
assert!($input.eq_str_until_nul(&"test"));
assert!($input.eq_str_until_nul(&String::from("test")));

// Same comparisons in the other direction: standard library string on the left.
assert!(String::from("test").eq_str_until_nul($input));
assert!("test".eq_str_until_nul($input));

// Edge cases around null characters inside the Rust string.
// The comparison stops at the first null, so only the leading "te" is compared.
assert!($input.eq_str_until_nul(&"te\0st"));
// A trailing null in the Rust string does not affect the result.
assert!($input.eq_str_until_nul(&"test\0"));
// A different text must not compare equal.
assert!(!$input.eq_str_until_nul(&"hello"));
};
}

#[test]
fn test_compare_cstr8() {
// test various comparisons with different order (left, right)
assert!(input.eq_str_until_nul(&"test"));
assert!(input.eq_str_until_nul(&String::from("test")));

// now other direction
assert!(String::from("test").eq_str_until_nul(input));
assert!("test".eq_str_until_nul(input));

// some more tests
// this is fine: compare until the first null
assert!(input.eq_str_until_nul(&"te\0st"));
// this is fine
assert!(input.eq_str_until_nul(&"test\0"));
assert!(!input.eq_str_until_nul(&"hello"));
let input: &CStr8 = cstr8!("test");
test_compare_cstrX!(input);
}

// Runs the shared comparison assertions against a `CStr16` built by the `cstr16!` macro.
#[test]
fn test_compare_cstr16() {
let input: &CStr16 = cstr16!("test");
test_compare_cstrX!(input);
}
}