diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs index eb7a7bf2875423..d5dee5002151f1 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs @@ -89,6 +89,17 @@ public static void ComparerImplementations_Dictionary_WithWellKnownStringCompare expectedInternalComparerTypeBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture, expectedInternalComparerTypeAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); + + // CollectionsMarshal.GetValueRefOrAddDefault + + RunCollectionTestCommon( + () => new Dictionary(StringComparer.Ordinal), + (dictionary, key) => CollectionsMarshal.GetValueRefOrAddDefault(dictionary, key, out _) = null, + (dictionary, key) => dictionary.ContainsKey(key), + dictionary => dictionary.Comparer, + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); static void RunDictionaryTest( IEqualityComparer equalityComparer, diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs index c08031b58d63d9..bc6eb587e80bd3 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs @@ -497,6 +497,9 @@ private int Initialize(int capacity) private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) { + // NOTE: this method is mirrored in CollectionsMarshal.GetValueRefOrAddDefault below. + // If you make any changes here, make sure to keep that version in sync as well. + if (key == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); @@ -681,6 +684,190 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) return true; } + /// + /// A helper class containing APIs exposed through . + /// These methods are relatively niche and only used in specific scenarios, so adding them in a separate type avoids + /// the additional overhead on each instantiation, especially in AOT scenarios. + /// + internal static class CollectionsMarshalHelper + { + /// + public static ref TValue? GetValueRefOrAddDefault(Dictionary dictionary, TKey key, out bool exists) + { + // NOTE: this method is mirrored by Dictionary.TryInsert above. + // If you make any changes here, make sure to keep that version in sync as well. + + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } + + if (dictionary._buckets == null) + { + dictionary.Initialize(0); + } + Debug.Assert(dictionary._buckets != null); + + Entry[]? entries = dictionary._entries; + Debug.Assert(entries != null, "expected entries to be non-null"); + + IEqualityComparer? comparer = dictionary._comparer; + uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key)); + + uint collisionCount = 0; + ref int bucket = ref dictionary.GetBucket(hashCode); + int i = bucket - 1; // Value in _buckets is 1-based + + if (comparer == null) + { + if (typeof(TKey).IsValueType) + { + // ValueType: Devirtualize with EqualityComparer.Default intrinsic + while (true) + { + // Should be a while loop https://github.com/dotnet/runtime/issues/9422 + // Test uint in if rather than loop condition to drop range check for following array access + if ((uint)i >= (uint)entries.Length) + { + break; + } + + if (entries[i].hashCode == hashCode && EqualityComparer.Default.Equals(entries[i].key, key)) + { + exists = true; + + return ref entries[i].value!; + } + + i = entries[i].next; + + collisionCount++; + if (collisionCount > (uint)entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + } + } + } + else + { + // Object type: Shared Generic, EqualityComparer.Default won't devirtualize + // https://github.com/dotnet/runtime/issues/10050 + // So cache in a local rather than get EqualityComparer per loop iteration + EqualityComparer defaultComparer = EqualityComparer.Default; + while (true) + { + // Should be a while loop https://github.com/dotnet/runtime/issues/9422 + // Test uint in if rather than loop condition to drop range check for following array access + if ((uint)i >= (uint)entries.Length) + { + break; + } + + if (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key)) + { + exists = true; + + return ref entries[i].value!; + } + + i = entries[i].next; + + collisionCount++; + if (collisionCount > (uint)entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + } + } + } + } + else + { + while (true) + { + // Should be a while loop https://github.com/dotnet/runtime/issues/9422 + // Test uint in if rather than loop condition to drop range check for following array access + if ((uint)i >= (uint)entries.Length) + { + break; + } + + if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) + { + exists = true; + + return ref entries[i].value!; + } + + i = entries[i].next; + + collisionCount++; + if (collisionCount > (uint)entries.Length) + { + // The chain of entries forms a loop; which means a concurrent update has happened. + // Break out of the loop and throw, rather than looping forever. + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + } + } + } + + int index; + if (dictionary._freeCount > 0) + { + index = dictionary._freeList; + Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow"); + dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next; + dictionary._freeCount--; + } + else + { + int count = dictionary._count; + if (count == entries.Length) + { + dictionary.Resize(); + bucket = ref dictionary.GetBucket(hashCode); + } + index = count; + dictionary._count = count + 1; + entries = dictionary._entries; + } + + ref Entry entry = ref entries![index]; + entry.hashCode = hashCode; + entry.next = bucket - 1; // Value in _buckets is 1-based + entry.key = key; + entry.value = default!; + bucket = index + 1; // Value in _buckets is 1-based + dictionary._version++; + + // Value types never rehash + if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) + { + // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing + // i.e. EqualityComparer.Default. + dictionary.Resize(entries.Length, true); + + exists = false; + + // At this point the entries array has been resized, so the current reference we have is no longer valid. + // We're forced to do a new lookup and return an updated reference to the new entry instance. This new + // lookup is guaranteed to always find a value though and it will never return a null reference here. + ref TValue? value = ref dictionary.FindValue(key)!; + + Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here"); + + return ref value; + } + + exists = false; + + return ref entry.value!; + } + } + public virtual void OnDeserialization(object? sender) { HashHelpers.SerializationInfoTable.TryGetValue(this, out SerializationInfo? siInfo); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs index a64f44683323e2..6a60224305a6b1 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/InteropServices/CollectionsMarshal.cs @@ -29,5 +29,15 @@ public static Span AsSpan(List? list) /// public static ref TValue GetValueRefOrNullRef(Dictionary dictionary, TKey key) where TKey : notnull => ref dictionary.FindValue(key); + + /// + /// Gets a ref to a in the , adding a new entry with a default value if it does not exist in the . + /// + /// The dictionary to get the ref to from. + /// The key used for lookup. + /// Whether or not a new entry for the given key was added to the dictionary. + /// Items should not be added to or removed from the while the ref is in use. + public static ref TValue? GetValueRefOrAddDefault(Dictionary dictionary, TKey key, out bool exists) where TKey : notnull + => ref Dictionary.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists); } } diff --git a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs index 374a3f14c55fc4..e6c03d03d24181 100644 --- a/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs +++ b/src/libraries/System.Runtime.InteropServices/ref/System.Runtime.InteropServices.cs @@ -177,6 +177,7 @@ public static partial class CollectionsMarshal { public static System.Span AsSpan(System.Collections.Generic.List? list) { throw null; } public static ref TValue GetValueRefOrNullRef(System.Collections.Generic.Dictionary dictionary, TKey key) where TKey : notnull { throw null; } + public static ref TValue? GetValueRefOrAddDefault(System.Collections.Generic.Dictionary dictionary, TKey key, out bool exists) where TKey : notnull { throw null; } } [System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)] public sealed partial class ComAliasNameAttribute : System.Attribute diff --git a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs index 5c5116b2a72004..164d98287ad9cc 100644 --- a/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs +++ b/src/libraries/System.Runtime.InteropServices/tests/System/Runtime/InteropServices/CollectionsMarshalTests.cs @@ -299,6 +299,201 @@ public void GetValueRefOrNullRefLinkBreaksOnResize() Assert.Equal(50, dict.Count); } + [Fact] + public void GetValueRefOrAddDefaultValueType() + { + // This test is the same as the one for GetValueRefOrNullRef, but it uses + // GetValueRefOrAddDefault instead, and also checks for incorrect additions. + // The two APIs should behave the same when values already exist. + var dict = new Dictionary + { + { 1, default }, + { 2, default } + }; + + Assert.Equal(2, dict.Count); + + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + var itemVal = dict[1]; + itemVal.Value = 1; + itemVal.Property = 2; + + // Does not change values in dictionary + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + CollectionsMarshal.GetValueRefOrAddDefault(dict, 1, out bool exists).Value = 3; + + Assert.True(exists); + Assert.Equal(2, dict.Count); + + CollectionsMarshal.GetValueRefOrAddDefault(dict, 1, out exists).Property = 4; + + Assert.True(exists); + Assert.Equal(2, dict.Count); + Assert.Equal(3, dict[1].Value); + Assert.Equal(4, dict[1].Property); + + ref var itemRef = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, 2, out exists); + + Assert.True(exists); + Assert.Equal(2, dict.Count); + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 5; + itemRef.Property = 6; + + Assert.Equal(5, itemRef.Value); + Assert.Equal(6, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + itemRef = new() { Value = 7, Property = 8 }; + + Assert.Equal(7, itemRef.Value); + Assert.Equal(8, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + // Check for correct additions + + ref var entry3Ref = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, 3, out exists); + + Assert.False(exists); + Assert.Equal(3, dict.Count); + Assert.False(Unsafe.IsNullRef(ref entry3Ref)); + Assert.True(EqualityComparer.Default.Equals(entry3Ref, default)); + + entry3Ref.Property = 42; + entry3Ref.Value = 12345; + + var value3 = dict[3]; + + Assert.Equal(42, value3.Property); + Assert.Equal(12345, value3.Value); + } + + [Fact] + public void GetValueRefOrAddDefaultClass() + { + var dict = new Dictionary + { + { 1, new() }, + { 2, new() } + }; + + Assert.Equal(2, dict.Count); + + Assert.Equal(0, dict[1].Value); + Assert.Equal(0, dict[1].Property); + + var itemVal = dict[1]; + itemVal.Value = 1; + itemVal.Property = 2; + + // Does change values in dictionary + Assert.Equal(1, dict[1].Value); + Assert.Equal(2, dict[1].Property); + + CollectionsMarshal.GetValueRefOrAddDefault(dict, 1, out bool exists).Value = 3; + + Assert.True(exists); + Assert.Equal(2, dict.Count); + + CollectionsMarshal.GetValueRefOrAddDefault(dict, 1, out exists).Property = 4; + + Assert.True(exists); + Assert.Equal(2, dict.Count); + Assert.Equal(3, dict[1].Value); + Assert.Equal(4, dict[1].Property); + + ref var itemRef = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, 2, out exists); + + Assert.True(exists); + Assert.Equal(2, dict.Count); + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 5; + itemRef.Property = 6; + + Assert.Equal(5, itemRef.Value); + Assert.Equal(6, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + itemRef = new() { Value = 7, Property = 8 }; + + Assert.Equal(7, itemRef.Value); + Assert.Equal(8, itemRef.Property); + Assert.Equal(dict[2].Value, itemRef.Value); + Assert.Equal(dict[2].Property, itemRef.Property); + + // Check for correct additions + + ref var entry3Ref = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, 3, out exists); + + Assert.False(exists); + Assert.Equal(3, dict.Count); + Assert.False(Unsafe.IsNullRef(ref entry3Ref)); + Assert.Null(entry3Ref); + + entry3Ref = new() { Value = 12345, Property = 42 }; + + var value3 = dict[3]; + + Assert.Equal(42, value3.Property); + Assert.Equal(12345, value3.Value); + } + + [Fact] + public void GetValueRefOrAddDefaultLinkBreaksOnResize() + { + var dict = new Dictionary + { + { 1, new() } + }; + + Assert.Equal(1, dict.Count); + + ref var itemRef = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, 1, out bool exists); + + Assert.True(exists); + Assert.Equal(1, dict.Count); + Assert.Equal(0, itemRef.Value); + Assert.Equal(0, itemRef.Property); + + itemRef.Value = 1; + itemRef.Property = 2; + + Assert.Equal(1, itemRef.Value); + Assert.Equal(2, itemRef.Property); + Assert.Equal(dict[1].Value, itemRef.Value); + Assert.Equal(dict[1].Property, itemRef.Property); + + // Resize + dict.EnsureCapacity(100); + for (int i = 2; i <= 50; i++) + { + dict.Add(i, new()); + } + + itemRef.Value = 3; + itemRef.Property = 4; + + Assert.Equal(3, itemRef.Value); + Assert.Equal(4, itemRef.Property); + + // Check connection broken + Assert.NotEqual(dict[1].Value, itemRef.Value); + Assert.NotEqual(dict[1].Property, itemRef.Property); + + Assert.Equal(50, dict.Count); + } + private struct Struct { public int Value;