From e9876f36bda716892f6a3878b896991e1be02b81 Mon Sep 17 00:00:00 2001 From: Jeremy Koritzinsky Date: Fri, 11 Jun 2021 17:16:07 -0700 Subject: [PATCH] Implement marshallers for Span and ReadOnlySpan (#1222) * Implement marshallers for Span/ReadOnlySpan. * Add tests for span marshalling. * Remove unused code. * PR feedback. * Use dotnet/runtime style. * Fix allocation size in ReadOnlySpanMarshaller. --- .../Ancillary.Interop/SpanMarshallers.cs | 393 ++++++++++++++++++ .../SpanTests.cs | 160 +++++++ 2 files changed, 553 insertions(+) create mode 100644 DllImportGenerator/Ancillary.Interop/SpanMarshallers.cs create mode 100644 DllImportGenerator/DllImportGenerator.IntegrationTests/SpanTests.cs diff --git a/DllImportGenerator/Ancillary.Interop/SpanMarshallers.cs b/DllImportGenerator/Ancillary.Interop/SpanMarshallers.cs new file mode 100644 index 000000000000..dccfa6a23ad5 --- /dev/null +++ b/DllImportGenerator/Ancillary.Interop/SpanMarshallers.cs @@ -0,0 +1,393 @@ + +using System.Diagnostics; +using System.Runtime.CompilerServices; + +namespace System.Runtime.InteropServices.GeneratedMarshalling +{ + [GenericContiguousCollectionMarshaller] + public unsafe ref struct ReadOnlySpanMarshaller + { + private ReadOnlySpan _managedSpan; + private readonly int _sizeOfNativeElement; + private IntPtr _allocatedMemory; + + public ReadOnlySpanMarshaller(int sizeOfNativeElement) + : this() + { + _sizeOfNativeElement = sizeOfNativeElement; + } + + public ReadOnlySpanMarshaller(ReadOnlySpan managed, int sizeOfNativeElement) + { + _allocatedMemory = default; + _sizeOfNativeElement = sizeOfNativeElement; + if (managed.Length == 0) + { + _managedSpan = default; + NativeValueStorage = default; + return; + } + _managedSpan = managed; + _sizeOfNativeElement = sizeOfNativeElement; + int spaceToAllocate = managed.Length * sizeOfNativeElement; + _allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)_allocatedMemory, spaceToAllocate); + } + + public ReadOnlySpanMarshaller(ReadOnlySpan managed, Span stackSpace, int sizeOfNativeElement) + { + _allocatedMemory = default; + _sizeOfNativeElement = sizeOfNativeElement; + if (managed.Length == 0) + { + _managedSpan = default; + NativeValueStorage = default; + return; + } + _managedSpan = managed; + int spaceToAllocate = managed.Length * sizeOfNativeElement; + if (spaceToAllocate <= stackSpace.Length) + { + NativeValueStorage = stackSpace[0..spaceToAllocate]; + } + else + { + _allocatedMemory = Marshal.AllocCoTaskMem(spaceToAllocate); + NativeValueStorage = new Span((void*)_allocatedMemory, spaceToAllocate); + } + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of array parameters doesn't + /// blow the stack since this is a new optimization in the code-generated interop. + /// + public const int StackBufferSize = 0x200; + + public Span ManagedValues => MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(_managedSpan), _managedSpan.Length); + + public Span NativeValueStorage { get; private set; } + + public ref byte GetPinnableReference() => ref MemoryMarshal.GetReference(NativeValueStorage); + + public void SetUnmarshalledCollectionLength(int length) + { + _managedSpan = new T[length]; + } + + public byte* Value + { + get + { + Debug.Assert(_managedSpan.IsEmpty || _allocatedMemory != IntPtr.Zero); + return (byte*)_allocatedMemory; + } + set + { + if (value == null) + { + _managedSpan = null; + NativeValueStorage = default; + } + else + { + _allocatedMemory = (IntPtr)value; + NativeValueStorage = new Span(value, _managedSpan.Length * _sizeOfNativeElement); + } + } + } + + public ReadOnlySpan ToManaged() => _managedSpan; + + public void FreeNative() + { + Marshal.FreeCoTaskMem(_allocatedMemory); + } + } + + [GenericContiguousCollectionMarshaller] + public unsafe ref struct SpanMarshaller + { + private ReadOnlySpanMarshaller _inner; + + public SpanMarshaller(int sizeOfNativeElement) + : this() + { + _inner = new ReadOnlySpanMarshaller(sizeOfNativeElement); + } + + public SpanMarshaller(Span managed, int sizeOfNativeElement) + { + _inner = new ReadOnlySpanMarshaller(managed, sizeOfNativeElement); + } + + public SpanMarshaller(Span managed, Span stackSpace, int sizeOfNativeElement) + { + _inner = new ReadOnlySpanMarshaller(managed, stackSpace, sizeOfNativeElement); + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small arrays to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of array parameters doesn't + /// blow the stack since this is a new optimization in the code-generated interop. + /// + public const int StackBufferSize = ReadOnlySpanMarshaller.StackBufferSize; + + public Span ManagedValues => _inner.ManagedValues; + + public Span NativeValueStorage + { + get => _inner.NativeValueStorage; + } + + public ref byte GetPinnableReference() => ref _inner.GetPinnableReference(); + + public void SetUnmarshalledCollectionLength(int length) + { + _inner.SetUnmarshalledCollectionLength(length); + } + + public byte* Value + { + get => _inner.Value; + set => _inner.Value = value; + } + + public Span ToManaged() + { + ReadOnlySpan managedInner = _inner.ToManaged(); + return MemoryMarshal.CreateSpan(ref MemoryMarshal.GetReference(managedInner), managedInner.Length); + } + + public void FreeNative() + { + _inner.FreeNative(); + } + } + + [GenericContiguousCollectionMarshaller] + public unsafe ref struct NeverNullSpanMarshaller + { + private SpanMarshaller _inner; + + public NeverNullSpanMarshaller(int sizeOfNativeElement) + : this() + { + _inner = new SpanMarshaller(sizeOfNativeElement); + } + + public NeverNullSpanMarshaller(Span managed, int sizeOfNativeElement) + { + _inner = new SpanMarshaller(managed, sizeOfNativeElement); + } + + public NeverNullSpanMarshaller(Span managed, Span stackSpace, int sizeOfNativeElement) + { + _inner = new SpanMarshaller(managed, stackSpace, sizeOfNativeElement); + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small spans to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of span parameters doesn't + /// blow the stack. + /// + public const int StackBufferSize = SpanMarshaller.StackBufferSize; + + public Span ManagedValues => _inner.ManagedValues; + + public Span NativeValueStorage + { + get => _inner.NativeValueStorage; + } + + public ref byte GetPinnableReference() + { + if (_inner.ManagedValues.Length == 0) + { + return ref *(byte*)0xa5a5a5a5; + } + return ref _inner.GetPinnableReference(); + } + + public void SetUnmarshalledCollectionLength(int length) + { + _inner.SetUnmarshalledCollectionLength(length); + } + + public byte* Value + { + get + { + if (_inner.ManagedValues.Length == 0) + { + return (byte*)0x1; + } + return _inner.Value; + } + + set => _inner.Value = value; + } + + public Span ToManaged() => _inner.ToManaged(); + + public void FreeNative() + { + _inner.FreeNative(); + } + } + + [GenericContiguousCollectionMarshaller] + public unsafe ref struct NeverNullReadOnlySpanMarshaller + { + private ReadOnlySpanMarshaller _inner; + + public NeverNullReadOnlySpanMarshaller(int sizeOfNativeElement) + : this() + { + _inner = new ReadOnlySpanMarshaller(sizeOfNativeElement); + } + + public NeverNullReadOnlySpanMarshaller(ReadOnlySpan managed, int sizeOfNativeElement) + { + _inner = new ReadOnlySpanMarshaller(managed, sizeOfNativeElement); + } + + public NeverNullReadOnlySpanMarshaller(ReadOnlySpan managed, Span stackSpace, int sizeOfNativeElement) + { + _inner = new ReadOnlySpanMarshaller(managed, stackSpace, sizeOfNativeElement); + } + + /// + /// Stack-alloc threshold set to 256 bytes to enable small spans to be passed on the stack. + /// Number kept small to ensure that P/Invokes with a lot of span parameters doesn't + /// blow the stack. + /// + public const int StackBufferSize = SpanMarshaller.StackBufferSize; + + public Span ManagedValues => _inner.ManagedValues; + + public Span NativeValueStorage + { + get => _inner.NativeValueStorage; + } + + public ref byte GetPinnableReference() + { + if (_inner.ManagedValues.Length == 0) + { + return ref *(byte*)0xa5a5a5a5; + } + return ref _inner.GetPinnableReference(); + } + + public void SetUnmarshalledCollectionLength(int length) + { + _inner.SetUnmarshalledCollectionLength(length); + } + + public byte* Value + { + get + { + if (_inner.ManagedValues.Length == 0) + { + return (byte*)0x1; + } + return _inner.Value; + } + + set => _inner.Value = value; + } + + public ReadOnlySpan ToManaged() => _inner.ToManaged(); + + public void FreeNative() + { + _inner.FreeNative(); + } + } + + [GenericContiguousCollectionMarshaller] + public unsafe ref struct DirectSpanMarshaller + where T : unmanaged + { + private int _unmarshalledLength; + private T* _allocatedMemory; + private Span _data; + + public DirectSpanMarshaller(int sizeOfNativeElement) + :this() + { + // This check is not exhaustive, but it will catch the majority of cases. + if (typeof(T) == typeof(bool) || typeof(T) == typeof(char) || Unsafe.SizeOf() != sizeOfNativeElement) + { + throw new ArgumentException("This marshaller only supports blittable element types. The provided type parameter must be blittable", nameof(T)); + } + } + + public DirectSpanMarshaller(Span managed, int sizeOfNativeElement) + :this(sizeOfNativeElement) + { + if (managed.Length == 0) + { + return; + } + + int spaceToAllocate = managed.Length * Unsafe.SizeOf(); + _allocatedMemory = (T*)Marshal.AllocCoTaskMem(spaceToAllocate); + _data = managed; + } + + public DirectSpanMarshaller(Span managed, Span stackSpace, int sizeOfNativeElement) + :this(sizeOfNativeElement) + { + Debug.Assert(stackSpace.IsEmpty); + _data = managed; + } + + /// + /// Stack-alloc threshold set to 0 so that the generator can use the constructor that takes a stackSpace to let the marshaller know that the original data span can be used and safely pinned. + /// + public const int StackBufferSize = 0; + + public Span ManagedValues => _data; + + public Span NativeValueStorage => _allocatedMemory != null + ? new Span(_allocatedMemory, _data.Length * Unsafe.SizeOf()) + : MemoryMarshal.Cast(_data); + + public ref T GetPinnableReference() => ref _data.GetPinnableReference(); + + public void SetUnmarshalledCollectionLength(int length) + { + _unmarshalledLength = length; + } + + public T* Value + { + get + { + Debug.Assert(_data.IsEmpty || _allocatedMemory != null); + return _allocatedMemory; + } + set + { + // We don't save the pointer assigned here to be freed + // since this marshaller passes back the actual memory span from native code + // back to managed code. + _allocatedMemory = null; + _data = new Span(value, _unmarshalledLength); + } + } + + public Span ToManaged() + { + return _data; + } + + public void FreeNative() + { + Marshal.FreeCoTaskMem((IntPtr)_allocatedMemory); + } + } +} \ No newline at end of file diff --git a/DllImportGenerator/DllImportGenerator.IntegrationTests/SpanTests.cs b/DllImportGenerator/DllImportGenerator.IntegrationTests/SpanTests.cs new file mode 100644 index 000000000000..53421dbc47d1 --- /dev/null +++ b/DllImportGenerator/DllImportGenerator.IntegrationTests/SpanTests.cs @@ -0,0 +1,160 @@ +using SharedTypes; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.InteropServices.GeneratedMarshalling; +using System.Text; + +using Xunit; + +namespace DllImportGenerator.IntegrationTests +{ + partial class NativeExportsNE + { + public partial class Span + { + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int Sum([MarshalUsing(typeof(SpanMarshaller))] Span values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int SumNeverNull([MarshalUsing(typeof(NeverNullSpanMarshaller))] Span values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array")] + public static partial int SumNeverNull([MarshalUsing(typeof(NeverNullReadOnlySpanMarshaller))] ReadOnlySpan values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "sum_int_array_ref")] + public static partial int SumInArray([MarshalUsing(typeof(SpanMarshaller))] in Span values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] + public static partial void Duplicate([MarshalUsing(typeof(SpanMarshaller), CountElementName = "numValues")] ref Span values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "duplicate_int_array")] + public static partial void DuplicateRaw([MarshalUsing(typeof(DirectSpanMarshaller), CountElementName = "numValues")] ref Span values, int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array")] + [return: MarshalUsing(typeof(SpanMarshaller), CountElementName = "numValues")] + public static partial Span CreateRange(int start, int end, out int numValues); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "create_range_array_out")] + public static partial void CreateRange_Out(int start, int end, out int numValues, [MarshalUsing(typeof(SpanMarshaller), CountElementName = "numValues")] out Span res); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "get_long_bytes")] + [return: MarshalUsing(typeof(SpanMarshaller), ConstantElementCount = sizeof(long))] + public static partial Span GetLongBytes(long l); + + [GeneratedDllImport(NativeExportsNE_Binary, EntryPoint = "and_all_members")] + [return: MarshalAs(UnmanagedType.U1)] + public static partial bool AndAllMembers([MarshalUsing(typeof(SpanMarshaller))] Span pArray, int length); + } + } + + public class SpanTests + { + [Fact] + public void BlittableElementSpanMarshalledToNativeAsExpected() + { + var list = new int[] { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Span.Sum(list, list.Length)); + } + + [Fact] + public void DefaultBlittableElementSpanMarshalledToNativeAsExpected() + { + Assert.Equal(-1, NativeExportsNE.Span.Sum(default, 0)); + } + + [Fact] + public void NeverNullSpanMarshallerMarshalsDefaultAsNonNull() + { + Assert.Equal(0, NativeExportsNE.Span.SumNeverNull(Span.Empty, 0)); + } + + [Fact] + public void NeverNullReadOnlySpanMarshallerMarshalsDefaultAsNonNull() + { + Assert.Equal(0, NativeExportsNE.Span.SumNeverNull(ReadOnlySpan.Empty, 0)); + } + + [Fact] + public void BlittableElementSpanInParameter() + { + var list = new int[] { 1, 5, 79, 165, 32, 3 }; + Assert.Equal(list.Sum(), NativeExportsNE.Span.SumInArray(list, list.Length)); + } + + [Fact] + public void BlittableElementSpanRefParameter() + { + var list = new int[] { 1, 5, 79, 165, 32, 3 }; + Span newSpan = list; + NativeExportsNE.Span.Duplicate(ref newSpan, list.Length); + Assert.Equal((IEnumerable)list, newSpan.ToArray()); + } + + [Fact] + public unsafe void DirectSpanMarshaller() + { + var list = new int[] { 1, 5, 79, 165, 32, 3 }; + Span newSpan = list; + NativeExportsNE.Span.DuplicateRaw(ref newSpan, list.Length); + Assert.Equal((IEnumerable)list, newSpan.ToArray()); + Marshal.FreeCoTaskMem((IntPtr)Unsafe.AsPointer(ref newSpan.GetPinnableReference())); + } + + [Fact] + public void BlittableElementSpanReturnedFromNative() + { + int start = 5; + int end = 20; + + IEnumerable expected = Enumerable.Range(start, end - start); + Assert.Equal(expected, NativeExportsNE.Collections.CreateRange(start, end, out _)); + + Span res; + NativeExportsNE.Span.CreateRange_Out(start, end, out _, out res); + Assert.Equal(expected, res.ToArray()); + } + + [Fact] + public void NullBlittableElementSpanReturnedFromNative() + { + Assert.Null(NativeExportsNE.Collections.CreateRange(1, 0, out _)); + + Span res; + NativeExportsNE.Span.CreateRange_Out(1, 0, out _, out res); + Assert.True(res.IsEmpty); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void SpanWithSimpleNonBlittableTypeMarshalling(bool result) + { + var boolValues = new BoolStruct[] + { + new BoolStruct + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct + { + b1 = true, + b2 = true, + b3 = true, + }, + new BoolStruct + { + b1 = true, + b2 = true, + b3 = result, + }, + }; + + Assert.Equal(result, NativeExportsNE.Span.AndAllMembers(boolValues, boolValues.Length)); + } + } +}