Algorithms.cs
// Force-disables the vectorized suffix search implementations so you can test/benchmark the scalar one
// #define FORCE_SCALAR_IMPLEMENTATION
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics.Arm;
using System.Runtime.Intrinsics.Wasm;
using System.Runtime.Intrinsics.X86;
using System.Runtime.Intrinsics;
namespace SimdDictionary {
public partial class VectorizedDictionary<TKey, TValue> {
// Extracting all this logic into each caller improves codegen slightly + reduces code size slightly, but the
// duplication reduces maintainability, so I'm pretty happy doing this instead.
// We rely on inlining to cause this struct to completely disappear, and its fields to become registers or individual locals.
// Will never fail as long as buckets isn't 0-length. You don't need to call Advance before your first loop iteration.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ref Bucket NewEnumerator (uint hashCode, out LoopingBucketEnumerator result) {
Unsafe.SkipInit(out result);
var buckets = new Span<Bucket>(_Buckets);
var initialIndex = BucketIndexForHashCode(hashCode, buckets);
Debug.Assert(buckets.Length > 0);
// This is calculated by BucketIndexForHashCode, so it won't be out of range, but FastMod can produce a bad index
// if the container is resized concurrently, so let the span bounds-check it for us.
ref var initialBucket = ref buckets[initialIndex];
result.buckets = buckets;
result.index = result.initialIndex = initialIndex;
return ref initialBucket;
}
private ref struct LoopingBucketEnumerator {
// The size of this struct is REALLY important! Adding even a single field to this will add stack spills to critical loops.
// FIXME: This span being a field puts pressure on the JIT to do recursive struct decomposition; I'm not sure it always does
internal Span<Bucket> buckets;
internal int index, initialIndex;
[Obsolete("Use VectorizedDictionary.NewEnumerator")]
public LoopingBucketEnumerator () {
}
/// <summary>
/// Walks forward through buckets, wrapping around at the end of the container.
/// Never visits a bucket twice.
/// </summary>
/// <returns>The next bucket, or NullRef if you have visited every bucket exactly once.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ref Bucket Advance () {
// Operating on the index field directly is harmless as long as the enumerator struct gets decomposed, which it appears to be.
// Caching index in a local and writing it back at the end increases generated code size, so it's not worth it.
if (++index >= buckets.Length)
index = 0;
if (index == initialIndex)
return ref Unsafe.NullRef<Bucket>();
else
return ref buckets[index];
}
/// <summary>
/// Walks back through the buckets you have previously visited.
/// </summary>
/// <returns>Each bucket you previously visited, exactly once, in reverse order, then NullRef.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public ref Bucket Retreat () {
if (index == initialIndex)
return ref Unsafe.NullRef<Bucket>();
if (--index < 0)
index = buckets.Length - 1;
return ref buckets[index];
}
/// <summary>
/// Indicates whether the enumerator has moved away from its original location and retreating is possible.
/// </summary>
public bool HasMoved {
[MethodImpl(MethodImplOptions.AggressiveInlining)]
get => index != initialIndex;
}
}
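// Illustrative sketch (not an actual call site from this file) of how NewEnumerator and Advance are meant to
// combine into a probe loop; the real lookup/insert paths elsewhere in this class do additional work per bucket:
//
//   ref var bucket = ref NewEnumerator(hashCode, out LoopingBucketEnumerator enumerator);
//   while (!Unsafe.IsNullRef(ref bucket)) {
//       // ... scan this bucket for the key ...
//       bucket = ref enumerator.Advance();
//   }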
/// <summary>
/// Visits every bucket in the container exactly once.
/// </summary>
// Callback is passed by-ref so it can be used to store results from the enumeration operation
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void EnumerateBuckets<TCallback> (Span<Bucket> buckets, ref TCallback callback)
where TCallback : struct, IBucketCallback {
// FIXME: Using a foreach on this span produces an imul-per-iteration for some reason.
ref Bucket bucket = ref MemoryMarshal.GetReference(buckets),
lastBucket = ref Unsafe.Add(ref bucket, buckets.Length - 1);
while (true) {
var ok = callback.Bucket(ref bucket);
if (ok && !Unsafe.AreSame(ref bucket, ref lastBucket))
bucket = ref Unsafe.Add(ref bucket, 1);
else
break;
}
}
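// Hypothetical callback sketch (IBucketCallback itself is declared elsewhere in this project; the shape assumed
// here matches how it is invoked above: a bool Bucket(ref Bucket) member, where returning false stops the
// enumeration early):
//
//   private struct PairCountingCallback : IBucketCallback {
//       public int Total;
//       public bool Bucket (ref Bucket bucket) {
//           Total += bucket.Count;
//           return true; // keep enumerating
//       }
//   }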
/// <summary>
/// Visits every key/value pair in the container exactly once.
/// </summary>
// Callback is passed by-ref so it can be used to store results from the enumeration operation
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void EnumeratePairs<TCallback> (Span<Bucket> buckets, ref TCallback callback)
where TCallback : struct, IPairCallback {
// FIXME: Using a foreach on this span produces an imul-per-iteration for some reason.
ref Bucket bucket = ref MemoryMarshal.GetReference(buckets),
lastBucket = ref Unsafe.Add(ref bucket, buckets.Length - 1);
while (true) {
ref var pair = ref bucket.Pairs.Pair0;
// FIXME: Awkward construction to prevent pair from ever becoming an invalid reference for a full bucket
int i = 0, c = bucket.Count;
if (i < c) {
iteration:
if (!callback.Pair(ref pair))
return;
if (++i < c) {
pair = ref Unsafe.Add(ref pair, 1);
goto iteration;
}
}
if (!Unsafe.AreSame(ref bucket, ref lastBucket))
bucket = ref Unsafe.Add(ref bucket, 1);
else
return;
}
}
/// <summary>
/// Scans the suffix table for the bucket for suffixes that match the provided search vector.
/// </summary>
/// <param name="bucket">The bucket to scan.</param>
/// <param name="searchVector">A search vector (all bytes must be the desired suffix)</param>
/// <param name="bucketCount">bucket.Count</param>
/// <returns>The location of the first match, or 32.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindSuffixInBucket (ref Bucket bucket, Vector128<byte> searchVector, int bucketCount) {
#if !FORCE_SCALAR_IMPLEMENTATION
if (Sse2.IsSupported) {
return BitOperations.TrailingZeroCount(Sse2.MoveMask(Sse2.CompareEqual(searchVector, bucket.Suffixes)));
} else if (AdvSimd.Arm64.IsSupported) {
// Completely untested
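// ANDing the all-ones/all-zeros comparison result with per-lane powers of two and then horizontally summing
// each 8-lane half reconstructs an SSE2 MoveMask-style bitmask on ARM.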
var laneBits = AdvSimd.And(
AdvSimd.CompareEqual(searchVector, bucket.Suffixes),
Vector128.Create(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128)
);
var moveMask = AdvSimd.Arm64.AddAcross(laneBits.GetLower()).ToScalar() |
(AdvSimd.Arm64.AddAcross(laneBits.GetUpper()).ToScalar() << 8);
return BitOperations.TrailingZeroCount(moveMask);
} else if (PackedSimd.IsSupported) {
// Completely untested
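// PackedSimd.Bitmask corresponds to WASM's i8x16.bitmask, the direct analogue of SSE2 MoveMask used above.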
return BitOperations.TrailingZeroCount(PackedSimd.Bitmask(PackedSimd.CompareEqual(searchVector, bucket.Suffixes)));
} else {
#else
{
#endif
if (false) {
// Hand-unrolled scan of multiple bytes at a time. If a bucket contains 9 or more items, we will erroneously
// check lanes 15 and 16 (which contain the count and cascade count), but finding a false match there is harmless
// We could do this 4 bytes at a time instead, but that isn't actually faster
// This produces larger code than a chain of ifs.
var wideHaystack = (UInt64*)Unsafe.AsPointer(ref bucket);
for (int i = 0; i < bucketCount; i += 8, wideHaystack += 1) {
// Doing a xor this way basically performs a vectorized compare of all the lanes, and we can test the result with
// a == 0 check on the low 8 bits, which is a single 'test rNNb' instruction on x86/x64
var matchMask = *wideHaystack ^ searchVector.AsUInt64()[0];
if (Step(ref matchMask))
return i;
if (Step(ref matchMask))
return i + 1;
if (Step(ref matchMask))
return i + 2;
if (Step(ref matchMask))
return i + 3;
if (Step(ref matchMask))
return i + 4;
if (Step(ref matchMask))
return i + 5;
if (Step(ref matchMask))
return i + 6;
if (Step(ref matchMask))
return i + 7;
}
} else if (true) {
// Hand-unrolling the search into four comparisons per loop iteration is a significant performance improvement
// for a moderate code size penalty (733b -> 826b; 399usec -> 321usec, vs BCL's 421b and 270usec)
// If a bucket contains 13 or more items we will erroneously check lanes 14/15/16 but this is harmless.
var haystack = (byte*)Unsafe.AsPointer(ref bucket);
for (int i = 0; i < bucketCount; i += 4, haystack += 4) {
if (haystack[0] == searchVector[0])
return i;
if (haystack[1] == searchVector[0])
return i + 1;
if (haystack[2] == searchVector[0])
return i + 2;
if (haystack[3] == searchVector[0])
return i + 3;
}
} else {
var haystack = (byte*)Unsafe.AsPointer(ref bucket);
for (int i = 0; i < bucketCount; i++, haystack++)
if (*haystack == searchVector[0])
return i;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
static bool Step (ref UInt64 matchMask) {
if ((matchMask & 0xFF) == 0)
return true;
matchMask >>= 8;
return false;
}
return 32;
}
}
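// Sketch of the intended usage (hypothetical; the real lookup paths elsewhere in this class chain this together
// with FindKeyInBucket below):
//
//   var searchVector = Vector128.Create(suffix); // Vector128.Create(byte) broadcasts the 1-byte suffix into every lane
//   int indexInBucket = FindSuffixInBucket(ref bucket, searchVector, bucket.Count);
//   // A result >= bucket.Count (including the 'no match' value of 32) means no occupied slot matched;
//   // FindKeyInBucket tolerates that by returning a null ref.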
/// <summary>
/// Walks backwards through previously-visited buckets, adjusting their cascade counter upward or downward.
/// </summary>
/// <param name="enumerator">The enumerator that was used to visit buckets.</param>
/// <param name="increase">true to increase cascade counts (you added something), false to decrease (you removed something).</param>
// In the common case this method does nothing (the early-out below), but inlining it lets the JIT make better
// decisions about stack size and register usage at the call site.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private void AdjustCascadeCounts (
LoopingBucketEnumerator enumerator, bool increase
) {
// Early-out before doing setup work since in the common case we won't have cascaded out of a bucket at all
if (!enumerator.HasMoved)
return;
// We may have cascaded out of a previous bucket; if so, scan backwards and update
// the cascade count for every bucket we previously scanned.
ref Bucket bucket = ref enumerator.Retreat();
while (!Unsafe.IsNullRef(ref bucket)) {
// FIXME: Track number of times we cascade out of a bucket for string rehashing anti-DoS mitigation!
var cascadeCount = bucket.CascadeCount;
if (increase) {
// Never overflow (wrap around) the counter
if (cascadeCount < DegradedCascadeCount)
bucket.CascadeCount = (ushort)(cascadeCount + 1);
} else {
if (cascadeCount == 0)
ThrowCorrupted();
// If the cascade counter hit the maximum, it's possible the actual cascade count through here is higher,
// so it's no longer safe to decrement. This is a very rare scenario, but it permanently degrades the table.
// TODO: Track this and trigger a rehash once too many buckets are in this state + dict is mostly empty.
else if (cascadeCount < DegradedCascadeCount)
bucket.CascadeCount = (ushort)(cascadeCount - 1);
}
bucket = ref enumerator.Retreat();
}
}
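// Sketch of the intended call sites (hypothetical; the actual insert/remove paths live elsewhere in this class):
//
//   AdjustCascadeCounts(enumerator, increase: true);   // after inserting into a bucket we cascaded into
//   AdjustCascadeCounts(enumerator, increase: false);  // after removing a key that had cascaded past buckets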
#pragma warning disable CS8619
// These have to be structs so that the JIT will specialize callers instead of sharing canonical (__Canon) code for them
private struct DefaultComparerKeySearcher : IKeySearcher {
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static uint GetHashCode (IEqualityComparer<TKey>? comparer, TKey key) {
return FinalizeHashCode(unchecked((uint)EqualityComparer<TKey>.Default.GetHashCode(key!)));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe ref Pair FindKeyInBucket (
// We have to use UnscopedRef to allow lazy initialization
[UnscopedRef] ref Bucket bucket, int indexInBucket, int bucketCount,
IEqualityComparer<TKey>? comparer, TKey needle, out int matchIndexInBucket
) {
Unsafe.SkipInit(out matchIndexInBucket);
Debug.Assert(indexInBucket >= 0);
int count = bucketCount - indexInBucket;
if (count <= 0)
return ref Unsafe.NullRef<Pair>();
ref Pair pair = ref Unsafe.Add(ref bucket.Pairs.Pair0, indexInBucket);
while (true) {
if (EqualityComparer<TKey>.Default.Equals(needle, pair.Key)) {
// We could optimize out the bucketCount local to prevent a stack spill in some cases by doing
// Unsafe.ByteOffset(...) / sizeof(Pair), but the potential idiv is extremely painful
matchIndexInBucket = bucketCount - count;
return ref pair;
}
// NOTE: --count <= 0 produces an extra 'test' opcode
if (--count == 0)
return ref Unsafe.NullRef<Pair>();
else
pair = ref Unsafe.Add(ref pair, 1);
}
}
}
private struct ComparerKeySearcher : IKeySearcher {
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static uint GetHashCode (IEqualityComparer<TKey>? comparer, TKey key) {
return FinalizeHashCode(unchecked((uint)comparer!.GetHashCode(key!)));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe ref Pair FindKeyInBucket (
// We have to use UnscopedRef to allow lazy initialization
[UnscopedRef] ref Bucket bucket, int indexInBucket, int bucketCount,
IEqualityComparer<TKey>? comparer, TKey needle, out int matchIndexInBucket
) {
Unsafe.SkipInit(out matchIndexInBucket);
Debug.Assert(indexInBucket >= 0);
Debug.Assert(comparer != null);
int count = bucketCount - indexInBucket;
if (count <= 0)
return ref Unsafe.NullRef<Pair>();
ref Pair pair = ref Unsafe.Add(ref bucket.Pairs.Pair0, indexInBucket);
// FIXME: This loop spills two values to/from the stack every iteration, and it's not clear which.
// The ValueType-with-default-comparer one doesn't.
while (true) {
if (comparer.Equals(needle, pair.Key)) {
// We could optimize out the bucketCount local to prevent a stack spill in some cases by doing
// Unsafe.ByteOffset(...) / sizeof(Pair), but the potential idiv is extremely painful
matchIndexInBucket = bucketCount - count;
return ref pair;
}
// NOTE: --count <= 0 produces an extra 'test' opcode
if (--count == 0)
return ref Unsafe.NullRef<Pair>();
else
pair = ref Unsafe.Add(ref pair, 1);
}
}
}
#pragma warning restore CS8619
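// Hypothetical sketch of how these searchers are meant to be consumed (the real lookup paths are elsewhere in
// this class): a caller that is generic over the searcher type, e.g.
//
//   private ref Pair FindKey<TKeySearcher> (TKey key, ...) where TKeySearcher : struct, IKeySearcher { ... }
//
// gets a fully specialized method body per searcher struct instead of one shared canonical body.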
}
}