Skip to content

Commit

Permalink
Allow windows-result to work on non-Windows platforms (#3082)
Browse files Browse the repository at this point in the history
  • Loading branch information
sivadeilra authored Jun 11, 2024
1 parent 66ad6d9 commit 56fd381
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 82 deletions.
134 changes: 85 additions & 49 deletions crates/libs/result/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ffi::c_void;
#[derive(Clone, PartialEq, Eq)]
pub struct Error {
code: HRESULT,
#[cfg(windows)]
info: Option<ComPtr>,
}

Expand All @@ -13,35 +14,54 @@ impl Error {
pub const fn empty() -> Self {
Self {
code: HRESULT(0),
#[cfg(windows)]
info: None,
}
}

/// Creates a new error object, capturing the stack and other information about the
/// point of failure.
pub fn new<T: AsRef<str>>(code: HRESULT, message: T) -> Self {
let message: Vec<_> = message.as_ref().encode_utf16().collect();

if message.is_empty() {
Self::from_hresult(code)
} else {
unsafe {
RoOriginateErrorW(code.0, message.len() as u32, message.as_ptr());
#[cfg(windows)]
{
let message: Vec<_> = message.as_ref().encode_utf16().collect();
if message.is_empty() {
Self::from_hresult(code)
} else {
unsafe {
RoOriginateErrorW(code.0, message.len() as u32, message.as_ptr());
}
code.into()
}
code.into()
}
#[cfg(not(windows))]
{
let _ = message;
Self::from_hresult(code)
}
}

/// Creates a new error object with an error code, but without additional error information.
pub fn from_hresult(code: HRESULT) -> Self {
Self { code, info: None }
Self {
code,
#[cfg(windows)]
info: None,
}
}

/// Creates a new `Error` from the Win32 error code returned by `GetLastError()`.
pub fn from_win32() -> Self {
Self {
code: HRESULT::from_win32(unsafe { GetLastError() }),
info: None,
#[cfg(windows)]
{
Self {
code: HRESULT::from_win32(unsafe { GetLastError() }),
info: None,
}
}
#[cfg(not(windows))]
{
unimplemented!()
}
}

Expand All @@ -52,49 +72,53 @@ impl Error {

/// The error message describing the error.
pub fn message(&self) -> String {
if let Some(info) = &self.info {
let mut message = BasicString::default();

// First attempt to retrieve the restricted error information.
if let Some(info) = info.cast(&IID_IRestrictedErrorInfo) {
let mut fallback = BasicString::default();
let mut code = 0;

unsafe {
com_call!(
IRestrictedErrorInfo_Vtbl,
info.GetErrorDetails(
&mut fallback as *mut _ as _,
&mut code,
&mut message as *mut _ as _,
&mut BasicString::default() as *mut _ as _
)
);
#[cfg(windows)]
{
if let Some(info) = &self.info {
let mut message = BasicString::default();

// First attempt to retrieve the restricted error information.
if let Some(info) = info.cast(&IID_IRestrictedErrorInfo) {
let mut fallback = BasicString::default();
let mut code = 0;

unsafe {
com_call!(
IRestrictedErrorInfo_Vtbl,
info.GetErrorDetails(
&mut fallback as *mut _ as _,
&mut code,
&mut message as *mut _ as _,
&mut BasicString::default() as *mut _ as _
)
);
}

if message.is_empty() {
message = fallback
};
}

// Next attempt to retrieve the regular error information.
if message.is_empty() {
message = fallback
};
}

// Next attempt to retrieve the regular error information.
if message.is_empty() {
unsafe {
com_call!(
IErrorInfo_Vtbl,
info.GetDescription(&mut message as *mut _ as _)
);
unsafe {
com_call!(
IErrorInfo_Vtbl,
info.GetDescription(&mut message as *mut _ as _)
);
}
}
}

return String::from_utf16_lossy(wide_trim_end(message.as_wide()));
return String::from_utf16_lossy(wide_trim_end(message.as_wide()));
}
}

// Otherwise fallback to a generic error code description.
self.code.message()
}

/// The error object describing the error.
#[cfg(windows)]
pub fn as_ptr(&self) -> *mut c_void {
self.info
.as_ref()
Expand All @@ -109,9 +133,12 @@ unsafe impl Sync for Error {}

impl From<Error> for HRESULT {
fn from(error: Error) -> Self {
if let Some(info) = error.info {
unsafe {
SetErrorInfo(0, info.as_raw());
#[cfg(windows)]
{
if let Some(info) = error.info {
unsafe {
SetErrorInfo(0, info.as_raw());
}
}
}
error.code
Expand All @@ -120,9 +147,15 @@ impl From<Error> for HRESULT {

impl From<HRESULT> for Error {
fn from(code: HRESULT) -> Self {
let mut info = None;
unsafe { GetErrorInfo(0, &mut info as *mut _ as _) };
Self { code, info }
Self {
code,
#[cfg(windows)]
info: {
let mut info = None;
unsafe { GetErrorInfo(0, &mut info as *mut _ as _) };
info
},
}
}
}

Expand All @@ -147,6 +180,7 @@ impl From<alloc::string::FromUtf16Error> for Error {
fn from(_: alloc::string::FromUtf16Error) -> Self {
Self {
code: HRESULT::from_win32(ERROR_NO_UNICODE_TRANSLATION),
#[cfg(windows)]
info: None,
}
}
Expand All @@ -156,6 +190,7 @@ impl From<alloc::string::FromUtf8Error> for Error {
fn from(_: alloc::string::FromUtf8Error) -> Self {
Self {
code: HRESULT::from_win32(ERROR_NO_UNICODE_TRANSLATION),
#[cfg(windows)]
info: None,
}
}
Expand All @@ -165,6 +200,7 @@ impl From<core::num::TryFromIntError> for Error {
fn from(_: core::num::TryFromIntError) -> Self {
Self {
code: HRESULT::from_win32(ERROR_INVALID_DATA),
#[cfg(windows)]
info: None,
}
}
Expand Down
77 changes: 44 additions & 33 deletions crates/libs/result/src/hresult.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,52 @@ impl HRESULT {

/// The error message describing the error.
pub fn message(&self) -> String {
let mut message = HeapString::default();
let mut code = self.0;
let mut module = 0;

let mut flags = FORMAT_MESSAGE_ALLOCATE_BUFFER
| FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS;

unsafe {
if self.0 & 0x1000_0000 == 0x1000_0000 {
code ^= 0x1000_0000;
flags |= FORMAT_MESSAGE_FROM_HMODULE;

module =
LoadLibraryExA(b"ntdll.dll\0".as_ptr(), 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS);
#[cfg(windows)]
{
let mut message = HeapString::default();
let mut code = self.0;
let mut module = 0;

let mut flags = FORMAT_MESSAGE_ALLOCATE_BUFFER
| FORMAT_MESSAGE_FROM_SYSTEM
| FORMAT_MESSAGE_IGNORE_INSERTS;

unsafe {
if self.0 & 0x1000_0000 == 0x1000_0000 {
code ^= 0x1000_0000;
flags |= FORMAT_MESSAGE_FROM_HMODULE;

module = LoadLibraryExA(
b"ntdll.dll\0".as_ptr(),
0,
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS,
);
}

let size = FormatMessageW(
flags,
module as _,
code as _,
0,
&mut message.0 as *mut _ as *mut _,
0,
core::ptr::null(),
);

if !message.0.is_null() && size > 0 {
String::from_utf16_lossy(wide_trim_end(core::slice::from_raw_parts(
message.0,
size as usize,
)))
} else {
String::default()
}
}
}

let size = FormatMessageW(
flags,
module as _,
code as _,
0,
&mut message.0 as *mut _ as *mut _,
0,
core::ptr::null(),
);

if !message.0.is_null() && size > 0 {
String::from_utf16_lossy(wide_trim_end(core::slice::from_raw_parts(
message.0,
size as usize,
)))
} else {
String::default()
}
#[cfg(not(windows))]
{
return format!("0x{:08x}", self.0 as u32);
}
}

Expand Down
5 changes: 5 additions & 0 deletions crates/libs/result/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Learn more about Rust for Windows here: <https://github.com/microsoft/windows-rs
debugger_visualizer(natvis_file = "../.natvis")
)]
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]
#![cfg_attr(not(windows), allow(unused_imports))]

extern crate alloc;

Expand All @@ -16,10 +17,14 @@ use alloc::vec::Vec;
mod bindings;
use bindings::*;

#[cfg(windows)]
mod com;
#[cfg(windows)]
use com::*;

#[cfg(windows)]
mod strings;
#[cfg(windows)]
use strings::*;

mod error;
Expand Down
34 changes: 34 additions & 0 deletions crates/tests/linux/tests/hresult.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// This tests code paths in `windows-result` that are different on non-Windows platforms.
#![cfg(not(windows))]

use windows::core::Error;
use windows::Win32::Foundation::{E_FAIL, S_OK};

#[test]
fn basic_hresult() {
assert!(E_FAIL.is_err());
assert!(S_OK.is_ok());

let ok_message = S_OK.message();
assert_eq!(ok_message, "0x00000000");
}

#[test]
fn error_message_is_not_supported() {
let e = Error::new(S_OK, "this gets ignored");
let message = e.message();
assert_eq!(message, "0x00000000");
}

#[test]
#[should_panic]
fn from_win32_panics() {
// from_win32() is not implemented on non-Windows platforms.
let _e = Error::from_win32();
}

#[test]
fn error_from_hresult() {
let e = Error::from(E_FAIL);
assert_eq!(e.code(), E_FAIL);
}

0 comments on commit 56fd381

Please sign in to comment.