diff --git a/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs b/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs
index 04765f34f..397a6fc12 100644
--- a/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs
+++ b/src/Authoring/WinRT.SourceGenerator/DiagnosticUtils.cs
@@ -87,10 +87,12 @@ private void CheckDeclarations()
foreach (var declaration in syntaxReceiver.Declarations)
{
var model = _context.Compilation.GetSemanticModel(declaration.SyntaxTree);
+ var symbol = model.GetDeclaredSymbol(declaration);
// Check symbol information for whether it is public to properly detect partial types
- // which can leave out modifier.
- if (model.GetDeclaredSymbol(declaration).DeclaredAccessibility != Accessibility.Public)
+ // which can leave out modifier. Also ignore nested types not effectively public
+ if (symbol.DeclaredAccessibility != Accessibility.Public ||
+ (symbol is ITypeSymbol typeSymbol && !typeSymbol.IsPubliclyAccessible()))
{
continue;
}
diff --git a/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs b/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs
new file mode 100644
index 000000000..5a8f443d2
--- /dev/null
+++ b/src/Authoring/WinRT.SourceGenerator/Extensions/SymbolExtensions.cs
@@ -0,0 +1,54 @@
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.CodeAnalysis;
+
+#nullable enable
+
+namespace Generator;
+
+///
+/// Extensions for symbol types.
+///
+internal static class SymbolExtensions
+{
+ ///
+ /// Checks whether a given type symbol is publicly accessible (ie. it's public and not nested in any non public type).
+ ///
+ /// The type symbol to check for public accessibility.
+ /// Whether is publicly accessible.
+ public static bool IsPubliclyAccessible(this ITypeSymbol type)
+ {
+ for (ITypeSymbol? currentType = type; currentType is not null; currentType = currentType.ContainingType)
+ {
+ // If any type in the type hierarchy is not public, the type is not public.
+ // This makes sure to detect public types nested into eg. a private type.
+ if (currentType.DeclaredAccessibility is not Accessibility.Public)
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+
+ ///
+ /// Checks whether a given symbol is an explicit interface implementation of a member of an internal interface (or more than one).
+ ///
+ /// The input member symbol to check.
+ /// Whether is an explicit interface implementation of internal interfaces.
+ public static bool IsExplicitInterfaceImplementationOfInternalInterfaces(this ISymbol symbol)
+ {
+ static bool IsAnyContainingTypePublic(IEnumerable symbols)
+ {
+ return symbols.Any(static symbol => symbol.ContainingType!.IsPubliclyAccessible());
+ }
+
+ return symbol switch
+ {
+ IMethodSymbol { ExplicitInterfaceImplementations: { Length: > 0 } methods } => !IsAnyContainingTypePublic(methods),
+ IPropertySymbol { ExplicitInterfaceImplementations: { Length: > 0 } properties } => !IsAnyContainingTypePublic(properties),
+ IEventSymbol { ExplicitInterfaceImplementations: { Length: > 0 } events } => !IsAnyContainingTypePublic(events),
+ _ => false
+ };
+ }
+}
diff --git a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs
index f52aff52e..cf585ff22 100644
--- a/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs
+++ b/src/Authoring/WinRT.SourceGenerator/WinRTTypeWriter.cs
@@ -1303,21 +1303,38 @@ Symbol GetType(string type, bool isGeneric = false, int genericIndex = -1, bool
private IEnumerable GetInterfaces(INamedTypeSymbol symbol, bool includeInterfacesWithoutMappings = false)
{
- HashSet interfaces = new HashSet();
- foreach (var @interface in symbol.Interfaces)
+ HashSet interfaces = new();
+
+ // Gather all interfaces that are publicly accessible. We specifically need to exclude interfaces
+ // that are not public, as eg. those might be used for additional cloaked WinRT/COM interfaces.
+ // Ignoring them here makes sure that they're not processed to be part of the .winmd file.
+ void GatherPubliclyAccessibleInterfaces(ITypeSymbol symbol)
{
- interfaces.Add(@interface);
- interfaces.UnionWith(@interface.AllInterfaces);
+ foreach (var @interface in symbol.Interfaces)
+ {
+ if (@interface.IsPubliclyAccessible())
+ {
+ _ = interfaces.Add(@interface);
+ }
+
+ // We're not using AllInterfaces on purpose: we only want to gather all interfaces but not
+ // from the base type. That's handled below to skip types that are already WinRT projections.
+ foreach (var @interface2 in @interface.AllInterfaces)
+ {
+ if (@interface2.IsPubliclyAccessible())
+ {
+ _ = interfaces.Add(@interface2);
+ }
+ }
+ }
}
+ GatherPubliclyAccessibleInterfaces(symbol);
+
var baseType = symbol.BaseType;
while (baseType != null && !IsWinRTType(baseType))
{
- interfaces.UnionWith(baseType.Interfaces);
- foreach (var @interface in baseType.Interfaces)
- {
- interfaces.UnionWith(@interface.AllInterfaces);
- }
+ GatherPubliclyAccessibleInterfaces(baseType);
baseType = baseType.BaseType;
}
@@ -2010,6 +2027,13 @@ void AddComponentType(INamedTypeSymbol type, Action visitTypeDeclaration = null)
}
else
{
+ // Special case: skip members that are explicitly implementing internal interfaces.
+ // This allows implementing classic COM internal interfaces with non-WinRT signatures.
+ if (member.IsExplicitInterfaceImplementationOfInternalInterfaces())
+ {
+ continue;
+ }
+
if (member is IMethodSymbol method &&
(method.MethodKind == MethodKind.Ordinary ||
method.MethodKind == MethodKind.ExplicitInterfaceImplementation ||
@@ -2736,12 +2760,19 @@ public void FinalizeGeneration()
}
}
- public bool IsPublic(ISymbol type)
+ public bool IsPublic(ISymbol symbol)
{
- return type.DeclaredAccessibility == Accessibility.Public ||
- type is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
- type is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
- type is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty;
+ // Check that the type has either public accessibility, or is an explicit interface implementation
+ if (symbol.DeclaredAccessibility == Accessibility.Public ||
+ symbol is IMethodSymbol method && !method.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
+ symbol is IPropertySymbol property && !property.ExplicitInterfaceImplementations.IsDefaultOrEmpty ||
+ symbol is IEventSymbol @event && !@event.ExplicitInterfaceImplementations.IsDefaultOrEmpty)
+ {
+ // If we have a containing type, we also check that it's publicly accessible
+ return symbol.ContainingType is not { } containingType || containingType.IsPubliclyAccessible();
+ }
+
+ return false;
}
public void GetNamespaceAndTypename(string qualifiedName, out string @namespace, out string typename)
diff --git a/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest b/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest
index 59cc17e2c..9a907d35f 100644
--- a/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest
+++ b/src/Tests/AuthoringConsumptionTest/AuthoringConsumptionTest.exe.manifest
@@ -74,5 +74,9 @@
name="AuthoringTest.TestClass"
threadingModel="both"
xmlns="urn:schemas-microsoft-com:winrt.v1" />
+
\ No newline at end of file
diff --git a/src/Tests/AuthoringConsumptionTest/pch.h b/src/Tests/AuthoringConsumptionTest/pch.h
index 1eb7dade4..781afc1b9 100644
--- a/src/Tests/AuthoringConsumptionTest/pch.h
+++ b/src/Tests/AuthoringConsumptionTest/pch.h
@@ -4,6 +4,7 @@
// conflict with Storyboard::GetCurrentTime
#undef GetCurrentTime
+#include
#include
#include
diff --git a/src/Tests/AuthoringConsumptionTest/test.cpp b/src/Tests/AuthoringConsumptionTest/test.cpp
index d8bbf3c23..585078bc3 100644
--- a/src/Tests/AuthoringConsumptionTest/test.cpp
+++ b/src/Tests/AuthoringConsumptionTest/test.cpp
@@ -639,4 +639,38 @@ TEST(AuthoringTest, PartialClass)
EXPECT_EQ(partialStruct.X, 3);
EXPECT_EQ(partialStruct.Y, 4);
EXPECT_EQ(partialStruct.Z, 5);
+}
+
+TEST(AuthoringTest, MixedWinRTClassicCOM)
+{
+ TestMixedWinRTCOMWrapper wrapper;
+
+ // Normal WinRT methods work as you'd expect
+ EXPECT_EQ(wrapper.HelloWorld(), L"Hello from mixed WinRT/COM");
+
+ // Verify we can grab the internal interface
+ IID internalInterface1Iid;
+ check_hresult(IIDFromString(L"{C7850559-8FF2-4E54-A237-6ED813F20CDC}", &internalInterface1Iid));
+ winrt::com_ptr<::IUnknown> unknown1 = wrapper.as<::IUnknown>();
+ winrt::com_ptr<::IUnknown> internalInterface1;
+ EXPECT_EQ(unknown1->QueryInterface(internalInterface1Iid, internalInterface1.put_void()), S_OK);
+
+ // Verify we can grab the nested public interface (in an internal type)
+ IID internalInterface2Iid;
+ check_hresult(IIDFromString(L"{8A08E18A-8D20-4E7C-9242-857BFE1E3159}", &internalInterface2Iid));
+ winrt::com_ptr<::IUnknown> unknown2 = wrapper.as<::IUnknown>();
+ winrt::com_ptr<::IUnknown> internalInterface2;
+ EXPECT_EQ(unknown2->QueryInterface(internalInterface2Iid, internalInterface2.put_void()), S_OK);
+
+ typedef int (__stdcall* GetNumber)(void*, int*);
+
+ int number;
+
+ // Validate the first call on IInternalInterface1
+ EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface1.get()))[3])(internalInterface1.get(), &number), S_OK);
+ EXPECT_EQ(number, 42);
+
+ // Validate the second call on IInternalInterface2
+ EXPECT_EQ(reinterpret_cast((*reinterpret_cast(internalInterface2.get()))[3])(internalInterface2.get(), &number), S_OK);
+ EXPECT_EQ(number, 123);
}
\ No newline at end of file
diff --git a/src/Tests/AuthoringTest/Program.cs b/src/Tests/AuthoringTest/Program.cs
index 0c1a90f26..1e490cc82 100644
--- a/src/Tests/AuthoringTest/Program.cs
+++ b/src/Tests/AuthoringTest/Program.cs
@@ -9,12 +9,17 @@
using System.ComponentModel.DataAnnotations;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Input;
using Windows.Foundation;
using Windows.Foundation.Collections;
using Windows.Foundation.Metadata;
+using Windows.Graphics.Effects;
+using WinRT;
+using WinRT.Interop;
#pragma warning disable CA1416
@@ -1569,6 +1574,142 @@ public partial struct PartialStruct
{
public double Z;
}
+
+ public sealed class TestMixedWinRTCOMWrapper : IGraphicsEffectSource, IPublicInterface, IInternalInterface1, SomeInternalType.IInternalInterface2
+ {
+ public string HelloWorld()
+ {
+ return "Hello from mixed WinRT/COM";
+ }
+
+ unsafe int IInternalInterface1.GetNumber(int* value)
+ {
+ *value = 42;
+
+ return 0;
+ }
+
+ unsafe int SomeInternalType.IInternalInterface2.GetNumber(int* value)
+ {
+ *value = 123;
+
+ return 0;
+ }
+ }
+
+ public interface IPublicInterface
+ {
+ string HelloWorld();
+ }
+
+ // Internal, classic COM interface
+ [global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
+ [WindowsRuntimeType]
+ [WindowsRuntimeHelperType(typeof(IInternalInterface1))]
+ internal unsafe interface IInternalInterface1
+ {
+ int GetNumber(int* value);
+
+ [global::System.Runtime.InteropServices.Guid("C7850559-8FF2-4E54-A237-6ED813F20CDC")]
+ public struct Vftbl
+ {
+ public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();
+
+ private static IntPtr InitVtbl()
+ {
+ Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));
+
+ lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
+ lpVtbl->GetNumber = &GetNumberFromAbi;
+
+ return (IntPtr)lpVtbl;
+ }
+
+ private IUnknownVftbl IUnknownVftbl;
+ private delegate* unmanaged[Stdcall] GetNumber;
+
+ [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
+ private static int GetNumberFromAbi(void* thisPtr, int* value)
+ {
+ try
+ {
+ return ComWrappersSupport.FindObject((IntPtr)thisPtr).GetNumber(value);
+ }
+ catch (Exception e)
+ {
+ ExceptionHelpers.SetErrorInfo(e);
+
+ return Marshal.GetHRForException(e);
+ }
+ }
+ }
+ }
+
+ internal struct SomeInternalType
+ {
+ // Nested, classic COM interface
+ [global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
+ [WindowsRuntimeType]
+ [WindowsRuntimeHelperType(typeof(IInternalInterface2))]
+ public unsafe interface IInternalInterface2
+ {
+ int GetNumber(int* value);
+
+ [global::System.Runtime.InteropServices.Guid("8A08E18A-8D20-4E7C-9242-857BFE1E3159")]
+ public struct Vftbl
+ {
+ public static readonly IntPtr AbiToProjectionVftablePtr = InitVtbl();
+
+ private static IntPtr InitVtbl()
+ {
+ Vftbl* lpVtbl = (Vftbl*)ComWrappersSupport.AllocateVtableMemory(typeof(Vftbl), sizeof(Vftbl));
+
+ lpVtbl->IUnknownVftbl = IUnknownVftbl.AbiToProjectionVftbl;
+ lpVtbl->GetNumber = &GetNumberFromAbi;
+
+ return (IntPtr)lpVtbl;
+ }
+
+ private IUnknownVftbl IUnknownVftbl;
+ private delegate* unmanaged[Stdcall] GetNumber;
+
+ [UnmanagedCallersOnly(CallConvs = new[] { typeof(CallConvStdcall) })]
+ private static int GetNumberFromAbi(void* thisPtr, int* value)
+ {
+ try
+ {
+ return ComWrappersSupport.FindObject((IntPtr)thisPtr).GetNumber(value);
+ }
+ catch (Exception e)
+ {
+ ExceptionHelpers.SetErrorInfo(e);
+
+ return Marshal.GetHRForException(e);
+ }
+ }
+ }
+ }
+ }
+}
+
+namespace ABI.AuthoringTest
+{
+ internal static class IInternalInterface1Methods
+ {
+ public static Guid IID => typeof(global::AuthoringTest.IInternalInterface1).GUID;
+
+ public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.IInternalInterface1.Vftbl.AbiToProjectionVftablePtr;
+ }
+
+ internal struct SomeInternalType
+ {
+ internal static class IInternalInterface2Methods
+ {
+ public static Guid IID => typeof(global::AuthoringTest.SomeInternalType.IInternalInterface2).GUID;
+
+ public static IntPtr AbiToProjectionVftablePtr => global::AuthoringTest.SomeInternalType.IInternalInterface2.Vftbl.AbiToProjectionVftablePtr;
+ }
+ }
}
namespace AnotherNamespace