Skip to content

Commit

Permalink
Fix bug in COM interface chain support (microsoft#3060)
Browse files Browse the repository at this point in the history
The definition of a COM interface may inherit from another interface.
These are known as "interface chains". The `#[implement]` macro allows
designers to specify only the minimal set of interface chains that are
needed for a given COM object implementation. The `#[implement]` macro
(and the `#[interface]` macro) work together to pull in the
implementations of all interfaces along the chain.

Unfortunately there is a bug in the implementation of `QueryInterface`
for interface chains. The current `QueryInterface` implementation will
only check the IIDs of the interfaces at the root of the chian, i.e.
the "most-derived" interface. `QueryInterface` will not search the IIDs
of interfaces that are in the inheritance chain.

This bug is demonstrated (detected) by the new unit tests in
`crates/tests/implement_core/src/com_chain.rs`. This PR fixes the bug
by generating an `fn matches()` method that checks the current IID and
then checks the parent interface (if any) by calling its `match()`
method. This fixes the unit test.

Co-authored-by: Arlie Davis <ardavis@microsoft.com>
  • Loading branch information
2 people authored and mati865 committed Jun 15, 2024
1 parent 4e72682 commit 0d0feab
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 2 deletions.
17 changes: 15 additions & 2 deletions crates/libs/interface/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,17 @@ impl Interface {
let parent_vtable_generics = if self.parent_is_iunknown() { quote!(Identity, OFFSET) } else { quote!(Identity, Impl, OFFSET) };
let parent_vtable = self.parent_vtable();

// or_parent_matches will be `|| parent::matches(iid)` if this interface inherits from another
// interface (except for IUnknown) or will be empty if this is not applicable. This is what allows
// QueryInterface to work correctly for all interfaces in an inheritance chain, e.g.
// IFoo3 derives from IFoo2 derives from IFoo.
//
// We avoid matching IUnknown because object identity depends on the uniqueness of the IUnknown pointer.
let or_parent_matches = match parent_vtable.as_ref() {
Some(parent) if !self.parent_is_iunknown() => quote! (|| <#parent>::matches(iid)),
_ => quote!(),
};

let functions = self
.methods
.iter()
Expand Down Expand Up @@ -287,8 +298,10 @@ impl Interface {
Self { base__: #parent_vtable::new::<#parent_vtable_generics>(), #(#entries),* }
}

pub fn matches(iid: &windows_core::GUID) -> bool {
iid == &<#name as ::windows_core::Interface>::IID
#[inline(always)]
pub fn matches(iid: &::windows_core::GUID) -> bool {
*iid == <#name as ::windows_core::Interface>::IID
#or_parent_matches
}
}
}
Expand Down
50 changes: 50 additions & 0 deletions crates/tests/implement_core/src/com_chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use windows_core::*;

#[interface("cccccccc-0000-0000-0000-000000000001")]
unsafe trait IFoo: IUnknown {}

#[interface("cccccccc-0000-0000-0000-000000000002")]
unsafe trait IFoo2: IFoo {}

#[interface("cccccccc-0000-0000-0000-000000000003")]
unsafe trait IFoo3: IFoo2 {}

// ObjectA implements a single interface chain, which consists of 3 different
// interfaces: IFoo3, IFoo2, and IFoo. You do not need to explicitly list all
// of the interfaces in the interface chain. Listing all of the interfaces is
// less efficient because it generates redundant interface chains (pointer
// fields in the the generated ObjectA_Impl type), which will never be used.
#[implement(IFoo3)]
struct ObjectWithChains {}

impl IFoo_Impl for ObjectWithChains {}
impl IFoo2_Impl for ObjectWithChains {}
impl IFoo3_Impl for ObjectWithChains {}

#[test]
fn interface_chain_query() {
let object = ComObject::new(ObjectWithChains {});
let unknown: IUnknown = object.to_interface();
let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo");
let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2");
let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3");
}

// ObjectRedundantChains implements the same interfaces as ObjectWithChains,
// but it defines more than one interface chain. This is unnecessary because it
// is redundant, but we are verifying that this works.
#[implement(IFoo3, IFoo2, IFoo)]
struct ObjectRedundantChains {}

impl IFoo_Impl for ObjectRedundantChains {}
impl IFoo2_Impl for ObjectRedundantChains {}
impl IFoo3_Impl for ObjectRedundantChains {}

#[test]
fn redundant_interface_chains() {
let object = ComObject::new(ObjectRedundantChains {});
let unknown: IUnknown = object.to_interface();
let _foo: IFoo = unknown.cast().expect("QueryInterface for IFoo");
let _foo2: IFoo2 = unknown.cast().expect("QueryInterface for IFoo2");
let _foo3: IFoo3 = unknown.cast().expect("QueryInterface for IFoo3");
}
1 change: 1 addition & 0 deletions crates/tests/implement_core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
#![cfg(test)]

mod com_chain;
mod com_object;

0 comments on commit 0d0feab

Please sign in to comment.