diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs index 6a60224305a6b1..a54dc8405f1ab6 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Runtime.CompilerServices; namespace System.Runtime.InteropServices { @@ -39,5 +40,40 @@ public static ref TValue GetValueRefOrNullRef(DictionaryItems should not be added to or removed from the while the ref is in use. public static ref TValue? GetValueRefOrAddDefault(Dictionary dictionary, TKey key, out bool exists) where TKey : notnull => ref Dictionary.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists); + + /// + /// Sets the count of the to the specified value. + /// + /// The list to set the count of. + /// The value to set the list's count to. + /// + /// is . + /// + /// + /// is negative. + /// + /// + /// When increasing the count, uninitialized data is being exposed. + /// + public static void SetCount(List list, int count) + { + if (count < 0) + { + ThrowHelper.ThrowArgumentOutOfRangeException_NeedNonNegNum(nameof(count)); + } + + list._version++; + + if (count > list.Capacity) + { + list.Grow(count); + } + else if (count < list._size && RuntimeHelpers.IsReferenceOrContainsReferences()) + { + Array.Clear(list._items, count, list._size - count); + } + + list._size = count; + } } } diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 6e091f27cc3e70..9beac1d41e9148 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -483,6 +483,7 @@ public static partial class CollectionsMarshal public static System.Span AsSpan(System.Collections.Generic.List? list) { throw null; } public static ref TValue GetValueRefOrNullRef(System.Collections.Generic.Dictionary dictionary, TKey key) where TKey : notnull { throw null; } public static ref TValue? GetValueRefOrAddDefault(System.Collections.Generic.Dictionary dictionary, TKey key, out bool exists) where TKey : notnull { throw null; } + public static void SetCount(System.Collections.Generic.List list, int count) { throw null; } } [System.AttributeUsageAttribute(System.AttributeTargets.Class, Inherited=false)] public sealed partial class ComDefaultInterfaceAttribute : System.Attribute diff --git a/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/CollectionsMarshalTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/CollectionsMarshalTests.cs index 876c0681bc6648..8a3ec2207da486 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/CollectionsMarshalTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System.Runtime.InteropServices.UnitTests/System/Runtime/InteropServices/CollectionsMarshalTests.cs @@ -505,5 +505,62 @@ private class IntAsObject public int Value; public int Property { get; set; } } + + [Fact] + public void ListSetCount() + { + List list = null!; + Assert.Throws(() => CollectionsMarshal.SetCount(list, 3)); + + Assert.Throws(() => CollectionsMarshal.SetCount(list, -1)); + + list = new(); + Assert.Throws(() => CollectionsMarshal.SetCount(list, -1)); + + CollectionsMarshal.SetCount(list, 5); + Assert.Equal(5, list.Count); + + list = new() { 1, 2, 3, 4, 5 }; + ref int intRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list)); + // make sure that size decrease preserves content + CollectionsMarshal.SetCount(list, 3); + Assert.Equal(3, list.Count); + Assert.Throws(() => list[3]); + SequenceEquals(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3 }); + Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list)))); + + // make sure that size increase preserves content and doesn't clear + CollectionsMarshal.SetCount(list, 5); + SequenceEquals(CollectionsMarshal.AsSpan(list), new int[] { 1, 2, 3, 4, 5 }); + Assert.True(Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list)))); + + // make sure that reallocations preserve content + int newCount = list.Capacity * 2; + CollectionsMarshal.SetCount(list, newCount); + Assert.Equal(newCount, list.Count); + SequenceEquals(CollectionsMarshal.AsSpan(list)[..3], new int[] { 1, 2, 3 }); + Assert.True(!Unsafe.AreSame(ref intRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(list)))); + + List listReference = new() { "a", "b", "c", "d", "e" }; + ref string stringRef = ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference)); + CollectionsMarshal.SetCount(listReference, 3); + // verify that reference types aren't cleared + SequenceEquals(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c" }); + Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference)))); + CollectionsMarshal.SetCount(listReference, 5); + // verify that removed reference types are cleared + SequenceEquals(CollectionsMarshal.AsSpan(listReference), new string[] { "a", "b", "c", null, null }); + Assert.True(Unsafe.AreSame(ref stringRef, ref MemoryMarshal.GetReference(CollectionsMarshal.AsSpan(listReference)))); + + static void SequenceEquals(ReadOnlySpan actual, ReadOnlySpan expected) + { + Assert.Equal(actual.Length, expected.Length); + + for (int i = 0; i < actual.Length; i++) + { + Assert.Equal(actual[i], expected[i]); + } + } + } } }