Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add string marshallers for ANSI and platform-defined #288

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
359 changes: 332 additions & 27 deletions DllImportGenerator/DllImportGenerator.IntegrationTests/StringTests.cs

Large diffs are not rendered by default.

15 changes: 5 additions & 10 deletions DllImportGenerator/DllImportGenerator.UnitTests/Compiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
yield return new[] { CodeSnippets.ArrayParameterWithSizeParam<UIntPtr>(isByRef: true) };
yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<char>(CharSet.Unicode) };
yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Unicode) };
//yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Ansi) };
//yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Auto) };
yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Ansi) };
yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Auto) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<bool>(UnmanagedType.Bool) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<bool>(UnmanagedType.VariantBool) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<bool>(UnmanagedType.I1) };
Expand All @@ -78,9 +78,10 @@ public static IEnumerable<object[]> CodeSnippetsToCompile_NoDiagnostics()
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPWStr) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPTStr) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPUTF8Str) };
//yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPStr) };
yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPStr) };
yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPWStr) };
yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPUTF8Str) };
yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPStr) };
//yield return new[] { CodeSnippets.EnumParameters };
yield return new[] { CodeSnippets.PreserveSigFalseVoidReturn };
yield return new[] { CodeSnippets.PreserveSigFalse<byte>() };
Expand Down Expand Up @@ -135,17 +136,11 @@ public static IEnumerable<object[]> CodeSnippetsToCompile_WithDiagnostics()
yield return new[] { CodeSnippets.BasicParametersAndModifiers<string[]>() };
yield return new[] { CodeSnippets.BasicParametersAndModifiers<IntPtr[]>() };
yield return new[] { CodeSnippets.BasicParametersAndModifiers<UIntPtr[]>() };

yield return new[] { CodeSnippets.ArrayParameterWithSizeParam<float>(isByRef: false) };
yield return new[] { CodeSnippets.ArrayParameterWithSizeParam<double>(isByRef: false) };
yield return new[] { CodeSnippets.ArrayParameterWithSizeParam<bool>(isByRef: false) };

yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Ansi) };
yield return new[] { CodeSnippets.BasicParametersAndModifiersWithCharSet<string>(CharSet.Auto) };

yield return new[] { CodeSnippets.MarshalAsParametersAndModifiers<string>(UnmanagedType.LPStr) };
yield return new[] { CodeSnippets.ArrayParameterWithNestedMarshalInfo<string>(UnmanagedType.LPStr) };

yield return new[] { CodeSnippets.EnumParameters };

yield return new[] { CodeSnippets.PreserveSigFalse<byte[]>() };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ private static ReferenceAssemblies GetReferenceAssemblies(TargetFramework target
TargetFramework.Framework => ReferenceAssemblies.NetFramework.Net48.Default,
TargetFramework.Standard => ReferenceAssemblies.NetStandard.NetStandard21,
TargetFramework.Core => ReferenceAssemblies.NetCore.NetCoreApp31,
TargetFramework.Net => ReferenceAssemblies.NetCore.NetCoreApp50,
TargetFramework.Net => ReferenceAssemblies.Net.Net50,
_ => ReferenceAssemblies.Default
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Workspaces" Version="$(CompilerPlatformVersion)" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.Analyzers" Version="3.3.0" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Analyzer.Testing.XUnit" Version="1.0.1-beta1.20418.1" PrivateAssets="all" />
<PackageReference Include="Microsoft.CodeAnalysis.CSharp.Analyzer.Testing.XUnit" Version="1.0.1-beta1.20478.1" PrivateAssets="all" />
<PackageReference Include="Microsoft.Net.Compilers.Toolset" Version="$(CompilerPlatformVersion)">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public static async Task<Compilation> CreateCompilationWithReferenceAssemblies(s
public static (ReferenceAssemblies, MetadataReference) GetReferenceAssemblies()
{
// TODO: When .NET 5.0 releases, we can simplify this.
var referenceAssemblies = ReferenceAssemblies.NetCore.NetCoreApp50;
var referenceAssemblies = ReferenceAssemblies.Net.Net50;

// Include the assembly containing the new attribute and all of its references.
// [TODO] Remove once the attribute has been added to the BCL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ public override IEnumerable<StatementSyntax> Generate(TypePositionInfo info, Stu
VariableDeclaration(
AsNativeType(info),
SingletonSeparatedList(VariableDeclarator(nativeIdentifier))));

if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup))
yield return conditionalAllocSetup;

break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind != RefKind.Out)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,39 @@ namespace Microsoft.Interop
{
internal abstract class ConditionalStackallocMarshallingGenerator : IMarshallingGenerator
{
private static string GetAllocationMarkerIdentifier(string managedIdentifier) => $"{managedIdentifier}__allocated";
protected static string GetAllocationMarkerIdentifier(string managedIdentifier) => $"{managedIdentifier}__allocated";

private static string GetByteLengthIdentifier(string managedIdentifier) => $"{managedIdentifier}__bytelen";

private static string GetStackAllocIdentifier(string managedIdentifier) => $"{managedIdentifier}__stackptr";

protected bool UsesConditionalStackAlloc(TypePositionInfo info, StubCodeContext context)
{
return context.CanUseAdditionalTemporaryState
&& context.StackSpaceUsable
&& (!info.IsByRef || info.RefKind == RefKind.In)
&& !info.IsManagedReturnPosition;
}

protected bool TryGenerateSetupSyntax(TypePositionInfo info, StubCodeContext context, out StatementSyntax statement)
{
statement = EmptyStatement();

if (!UsesConditionalStackAlloc(info, context))
return false;

string allocationMarkerIdentifier = GetAllocationMarkerIdentifier(context.GetIdentifiers(info).managed);

// bool <allocationMarker> = false;
statement = LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.BoolKeyword)),
SingletonSeparatedList(
VariableDeclarator(allocationMarkerIdentifier)
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));
return true;
}

protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
TypePositionInfo info,
StubCodeContext context,
Expand All @@ -39,8 +66,8 @@ protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
VariableDeclarator(byteLenIdentifier)
.WithInitializer(EqualsValueClause(
GenerateByteLengthCalculationExpression(info, context))))));
if (!context.CanUseAdditionalTemporaryState || !context.StackSpaceUsable || (info.IsByRef && info.RefKind != RefKind.In))

if (!UsesConditionalStackAlloc(info, context))
{
List<StatementSyntax> statements = new List<StatementSyntax>();
if (allocationRequiresByteLength)
Expand All @@ -58,13 +85,6 @@ protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
Block(statements));
yield break;
}
// <allocationMarkerIdentifier> = false;
yield return LocalDeclarationStatement(
VariableDeclaration(
PredefinedType(Token(SyntaxKind.BoolKeyword)),
SingletonSeparatedList(
VariableDeclarator(allocationMarkerIdentifier)
.WithInitializer(EqualsValueClause(LiteralExpression(SyntaxKind.FalseLiteralExpression))))));

// Code block for stackalloc if number of bytes is below threshold size
var marshalOnStack = Block(
Expand Down Expand Up @@ -94,6 +114,7 @@ protected IEnumerable<StatementSyntax> GenerateConditionalAllocationSyntax(
// if (<byteLen> > <StackAllocBytesThreshold>)
// {
// <allocationStatement>;
// <allocationMarker> = true;
// }
// else
// {
Expand Down Expand Up @@ -135,7 +156,7 @@ protected StatementSyntax GenerateConditionalAllocationFreeSyntax(
{
(string managedIdentifier, string nativeIdentifier) = context.GetIdentifiers(info);
string allocationMarkerIdentifier = GetAllocationMarkerIdentifier(managedIdentifier);
if (!context.CanUseAdditionalTemporaryState || (info.IsByRef && info.RefKind != RefKind.In))
if (!UsesConditionalStackAlloc(info, context))
{
return ExpressionStatement(GenerateFreeExpression(info, context));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ namespace Microsoft.Interop
{
internal static class MarshallerHelpers
{
public static readonly ExpressionSyntax IsWindows = InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName("System.OperatingSystem"),
IdentifierName("IsWindows")));

public static readonly TypeSyntax InteropServicesMarshalType = ParseTypeName(TypeNames.System_Runtime_InteropServices_Marshal);

public static ForStatementSyntax GetForLoop(string collectionIdentifier, string indexerIdentifier)
{
// for(int <indexerIdentifier> = 0; <indexerIdentifier> < <collectionIdentifier>.Length; ++<indexerIdentifier>)
Expand Down Expand Up @@ -41,5 +49,48 @@ public static ForStatementSyntax GetForLoop(string collectionIdentifier, string
SyntaxKind.PreIncrementExpression,
IdentifierName(indexerIdentifier))));
}

public static class StringMarshaller
{
public static ExpressionSyntax AllocationExpression(CharEncoding encoding, string managedIdentifier)
{
string methodName = encoding switch
{
CharEncoding.Utf8 => "StringToCoTaskMemUTF8",
CharEncoding.Utf16 => "StringToCoTaskMemUni",
CharEncoding.Ansi => "StringToCoTaskMemAnsi",
_ => throw new System.ArgumentOutOfRangeException(nameof(encoding))
};

// Marshal.StringToCoTaskMemUTF8(<managed>)
// or
// Marshal.StringToCoTaskMemUni(<managed>)
// or
// Marshal.StringToCoTaskMemAnsi(<managed>)
return InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
InteropServicesMarshalType,
IdentifierName(methodName)),
ArgumentList(
SingletonSeparatedList<ArgumentSyntax>(
Argument(IdentifierName(managedIdentifier)))));
}

public static ExpressionSyntax FreeExpression(string nativeIdentifier)
{
// Marshal.FreeCoTaskMem((IntPtr)<nativeIdentifier>)
return InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
InteropServicesMarshalType,
IdentifierName("FreeCoTaskMem")),
ArgumentList(SingletonSeparatedList(
Argument(
CastExpression(
ParseTypeName("System.IntPtr"),
IdentifierName(nativeIdentifier))))));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,13 @@ internal class MarshallingGenerators
public static readonly ByteBoolMarshaller ByteBool = new ByteBoolMarshaller();
public static readonly WinBoolMarshaller WinBool = new WinBoolMarshaller();
public static readonly VariantBoolMarshaller VariantBool = new VariantBoolMarshaller();

public static readonly Utf16CharMarshaller Utf16Char = new Utf16CharMarshaller();
public static readonly Utf16StringMarshaller Utf16String = new Utf16StringMarshaller();
public static readonly Utf8StringMarshaller Utf8String = new Utf8StringMarshaller();
public static readonly AnsiStringMarshaller AnsiString = new AnsiStringMarshaller(Utf8String);
public static readonly PlatformDefinedStringMarshaller PlatformDefinedString = new PlatformDefinedStringMarshaller(Utf16String, Utf8String);

public static readonly Forwarder Forwarder = new Forwarder();
public static readonly BlittableMarshaller Blittable = new BlittableMarshaller();
public static readonly DelegateMarshaller Delegate = new DelegateMarshaller();
Expand Down Expand Up @@ -259,6 +263,8 @@ private static IMarshallingGenerator CreateStringMarshaller(TypePositionInfo inf
{
switch (marshalAsInfo.UnmanagedType)
{
case UnmanagedType.LPStr:
return AnsiString;
case UnmanagedType.LPTStr:
case UnmanagedType.LPWStr:
return Utf16String;
Expand All @@ -270,10 +276,14 @@ private static IMarshallingGenerator CreateStringMarshaller(TypePositionInfo inf
{
switch (marshalStringInfo.CharEncoding)
{
case CharEncoding.Ansi:
return AnsiString;
case CharEncoding.Utf16:
return Utf16String;
case CharEncoding.Utf8:
return Utf8String;
case CharEncoding.PlatformDefined:
return PlatformDefinedString;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ public override IEnumerable<StatementSyntax> Generate(TypePositionInfo info, Stu
VariableDeclaration(
AsNativeType(info),
SingletonSeparatedList(VariableDeclarator(nativeIdentifier))));

if (TryGenerateSetupSyntax(info, context, out StatementSyntax conditionalAllocSetup))
yield return conditionalAllocSetup;

break;
case StubCodeContext.Stage.Marshal:
if (info.RefKind != RefKind.Out)
Expand Down
Loading