@@ -45,6 +45,7 @@ struct WgpDetails {
45
45
std::array<int32_t , 3 > maxWorkgroupSizes;
46
46
uint32_t maxThreadSize;
47
47
uint32_t maxWorkgroupMemoryBytes;
48
+ std::array<int32_t , 3 > maxWorkgroupCounts;
48
49
};
49
50
50
51
// Chip level feature/limit details
@@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
106
107
MMAOpsArrayAttr::get (context, mmaAttrs),
107
108
DenseI32ArrayAttr::get (context, subgroupSizes),
108
109
DenseI32ArrayAttr::get (context, wgp->maxWorkgroupSizes ),
109
- wgp->maxThreadSize , wgp->maxWorkgroupMemoryBytes , DictionaryAttr{});
110
+ wgp->maxThreadSize , wgp->maxWorkgroupMemoryBytes ,
111
+ DenseI32ArrayAttr::get (context, wgp->maxWorkgroupCounts ),
112
+ DictionaryAttr{});
110
113
111
114
TargetChipAttr targetChip;
112
115
if (details.chip )
@@ -118,6 +121,10 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
118
121
119
122
// ===----------------------------------------------------------------------===//
120
123
// Known AMD target details
124
+ //
125
+ // Note: the max workgroup size is given as signed int32 max because MLIR's
126
+ // `index` is signed and the workgroup ID is sign-extended, not zero-extended,
127
+ // to 64-bits.
121
128
// ===----------------------------------------------------------------------===//
122
129
123
130
const WgpDetails *getCDNA3WgpDetails () {
@@ -129,11 +136,17 @@ const WgpDetails *getCDNA3WgpDetails() {
129
136
MMAIntrinsic::MFMA_I32_16x16x32_I8,
130
137
MMAIntrinsic::MFMA_I32_32x32x16_I8,
131
138
};
132
- static const WgpDetails cdna3Wgp = {
133
- allComputeBits, allStorageBits, allSubgroupOps,
134
- allDotProductOps, ARRAY_SIZE (cdna3MMAOps), cdna3MMAOps,
135
- {64 , 64 }, {1024 , 1024 , 1024 }, 1024 ,
136
- 64 * 1024 };
139
+ static const WgpDetails cdna3Wgp = {allComputeBits,
140
+ allStorageBits,
141
+ allSubgroupOps,
142
+ allDotProductOps,
143
+ ARRAY_SIZE (cdna3MMAOps),
144
+ cdna3MMAOps,
145
+ {64 , 64 },
146
+ {1024 , 1024 , 1024 },
147
+ 1024 ,
148
+ 64 * 1024 ,
149
+ {0x7fffffff , 0x7fffffff , 0x7fffffff }};
137
150
return &cdna3Wgp;
138
151
}
139
152
@@ -142,11 +155,17 @@ const WgpDetails *getCDNA2WgpDetails() {
142
155
MMAIntrinsic::MFMA_F32_16x16x16_F16,
143
156
MMAIntrinsic::MFMA_F32_32x32x8_F16,
144
157
};
145
- static const WgpDetails cdna2Wgp = {
146
- allComputeBits, allStorageBits, allSubgroupOps,
147
- allDotProductOps, ARRAY_SIZE (cdna2MMAOps), cdna2MMAOps,
148
- {64 , 64 }, {1024 , 1024 , 1024 }, 1024 ,
149
- 64 * 1024 };
158
+ static const WgpDetails cdna2Wgp = {allComputeBits,
159
+ allStorageBits,
160
+ allSubgroupOps,
161
+ allDotProductOps,
162
+ ARRAY_SIZE (cdna2MMAOps),
163
+ cdna2MMAOps,
164
+ {64 , 64 },
165
+ {1024 , 1024 , 1024 },
166
+ 1024 ,
167
+ 64 * 1024 ,
168
+ {0x7fffffff , 0x7fffffff , 0x7fffffff }};
150
169
return &cdna2Wgp;
151
170
}
152
171
@@ -155,11 +174,17 @@ const WgpDetails *getCDNA1WgpDetails() {
155
174
MMAIntrinsic::MFMA_F32_16x16x16_F16,
156
175
MMAIntrinsic::MFMA_F32_32x32x8_F16,
157
176
};
158
- static const WgpDetails cdna1Wgp = {
159
- allComputeBits, allStorageBits, allSubgroupOps,
160
- allDotProductOps, ARRAY_SIZE (cdna1MMAOps), cdna1MMAOps,
161
- {64 , 64 }, {1024 , 1024 , 1024 }, 1024 ,
162
- 64 * 1024 };
177
+ static const WgpDetails cdna1Wgp = {allComputeBits,
178
+ allStorageBits,
179
+ allSubgroupOps,
180
+ allDotProductOps,
181
+ ARRAY_SIZE (cdna1MMAOps),
182
+ cdna1MMAOps,
183
+ {64 , 64 },
184
+ {1024 , 1024 , 1024 },
185
+ 1024 ,
186
+ 64 * 1024 ,
187
+ {0x7fffffff , 0x7fffffff , 0x7fffffff }};
163
188
return &cdna1Wgp;
164
189
}
165
190
@@ -168,27 +193,39 @@ const WgpDetails *getRDNA3WgpDetails() {
168
193
MMAIntrinsic::WMMA_F32_16x16x16_F16,
169
194
MMAIntrinsic::WMMA_F16_16x16x16_F16,
170
195
};
171
- static const WgpDetails rdna3Wgp = {
172
- allComputeBits, allStorageBits, allSubgroupOps,
173
- allDotProductOps, ARRAY_SIZE (rdna3MMAOps), rdna3MMAOps,
174
- {32 , 64 }, {1024 , 1024 , 1024 }, 1024 ,
175
- 64 * 1024 };
196
+ static const WgpDetails rdna3Wgp = {allComputeBits,
197
+ allStorageBits,
198
+ allSubgroupOps,
199
+ allDotProductOps,
200
+ ARRAY_SIZE (rdna3MMAOps),
201
+ rdna3MMAOps,
202
+ {32 , 64 },
203
+ {1024 , 1024 , 1024 },
204
+ 1024 ,
205
+ 64 * 1024 ,
206
+ {0x7fffffff , 0x7fffffff , 0x7fffffff }};
176
207
return &rdna3Wgp;
177
208
}
178
209
179
210
const WgpDetails *getRDNA2WgpDetails () {
180
211
static const WgpDetails rdna2Wgp = {
181
- allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
182
- /* mmaCount=*/ 0 , /* mmaOps=*/ nullptr , {32 , 64 }, {1024 , 1024 , 1024 },
183
- 1024 , 64 * 1024 };
212
+ allComputeBits, allStorageBits,
213
+ allSubgroupOps, allDotProductOps,
214
+ /* mmaCount=*/ 0 ,
215
+ /* mmaOps=*/ nullptr , {32 , 64 },
216
+ {1024 , 1024 , 1024 }, 1024 ,
217
+ 64 * 1024 , {0x7fffffff , 0x7fffffff , 0x7fffffff }};
184
218
return &rdna2Wgp;
185
219
}
186
220
187
221
const WgpDetails *getRDNA1WgpDetails () {
188
222
static const WgpDetails rdna1Wgp = {
189
- allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
190
- /* mmaCount=*/ 0 , /* mmaOps=*/ nullptr , {32 , 64 }, {1024 , 1024 , 1024 },
191
- 1024 , 64 * 1024 };
223
+ allComputeBits, allStorageBits,
224
+ allSubgroupOps, DotProductOps::None,
225
+ /* mmaCount=*/ 0 ,
226
+ /* mmaOps=*/ nullptr , {32 , 64 },
227
+ {1024 , 1024 , 1024 }, 1024 ,
228
+ 64 * 1024 , {0x7fffffff , 0x7fffffff , 0x7fffffff }};
192
229
return &rdna1Wgp;
193
230
}
194
231
@@ -281,7 +318,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
281
318
static const WgpDetails wgp = {
282
319
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
283
320
/* mmaCount=*/ 0 , /* mmaOps=*/ nullptr , {32 , 32 },
284
- {1024 , 1024 , 1024 }, 1024 , 32 * 1024 };
321
+ {1024 , 1024 , 1024 }, 1024 , 32 * 1024 ,
322
+ // Note: These values have not been checked and may be higher
323
+ {0xffff , 0xffff , 0xffff }};
285
324
// clang-format on
286
325
287
326
return TargetDetails{&wgp, nullptr };
@@ -302,7 +341,9 @@ const WgpDetails *getValhallWgpDetails() {
302
341
static const WgpDetails valhallWgp = {
303
342
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
304
343
/* mmaCount=*/ 0 , /* mmaOps=*/ nullptr , {16 , 16 }, {512 , 512 , 512 },
305
- 512 , 32 * 1024 };
344
+ 512 , 32 * 1024 ,
345
+ // Note: These values have not been checked and may be higher
346
+ {0xffff , 0xffff , 0xffff }};
306
347
// clang-format on
307
348
return &valhallWgp;
308
349
}
@@ -358,11 +399,17 @@ const WgpDetails *getAmpereWgpDetails() {
358
399
MMAIntrinsic::WMMA_F32_16x16x16_F16,
359
400
MMAIntrinsic::WMMA_F16_16x16x16_F16,
360
401
};
361
- static const WgpDetails ampereWgp = {
362
- allComputeBits, allStorageBits, allSubgroupOps,
363
- allDotProductOps, ARRAY_SIZE (mmaOps), mmaOps,
364
- {32 , 32 }, {1024 , 1024 , 1024 }, 1024 ,
365
- 163 * 1024 };
402
+ static const WgpDetails ampereWgp = {allComputeBits,
403
+ allStorageBits,
404
+ allSubgroupOps,
405
+ allDotProductOps,
406
+ ARRAY_SIZE (mmaOps),
407
+ mmaOps,
408
+ {32 , 32 },
409
+ {1024 , 1024 , 1024 },
410
+ 1024 ,
411
+ 163 * 1024 ,
412
+ {0x7fffffff , 0xffff , 0xffff }};
366
413
return &ereWgp;
367
414
}
368
415
@@ -371,11 +418,17 @@ const WgpDetails *getTuringWgpDetails() {
371
418
MMAIntrinsic::WMMA_F32_16x16x16_F16,
372
419
MMAIntrinsic::WMMA_F16_16x16x16_F16,
373
420
};
374
- static const WgpDetails turingWgp = {
375
- allComputeBits, allStorageBits, allSubgroupOps,
376
- allDotProductOps, ARRAY_SIZE (mmaOps), mmaOps,
377
- {32 , 32 }, {1024 , 1024 , 1024 }, 1024 ,
378
- 64 * 1024 };
421
+ static const WgpDetails turingWgp = {allComputeBits,
422
+ allStorageBits,
423
+ allSubgroupOps,
424
+ allDotProductOps,
425
+ ARRAY_SIZE (mmaOps),
426
+ mmaOps,
427
+ {32 , 32 },
428
+ {1024 , 1024 , 1024 },
429
+ 1024 ,
430
+ 64 * 1024 ,
431
+ {0x7fffffff , 0xffff , 0xffff }};
379
432
return &turingWgp;
380
433
}
381
434
@@ -388,7 +441,8 @@ const WgpDetails *getVoltaWgpDetails() {
388
441
static const WgpDetails voltaWgp = {
389
442
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
390
443
ARRAY_SIZE (mmaOps), mmaOps, {32 , 32 }, {1024 , 1024 , 1024 },
391
- 1024 , 96 * 1024 };
444
+ 1024 , 96 * 1024 ,
445
+ {0x7fffffff , 0xffff , 0xffff }};
392
446
// clang-format on
393
447
return &voltaWgp;
394
448
}
@@ -398,7 +452,8 @@ const WgpDetails *getPascalWgpDetails() {
398
452
static const WgpDetails pascalWgp = {
399
453
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
400
454
0 , nullptr , // Pascal does not have tensor core support.
401
- {32 , 32 }, {1024 , 1024 , 1024 }, 1024 , 48 * 1024 };
455
+ {32 , 32 }, {1024 , 1024 , 1024 }, 1024 , 48 * 1024 ,
456
+ {0x7fffffff , 0xffff , 0xffff }};
402
457
// clang-format on
403
458
return &pascalWgp;
404
459
}
@@ -479,7 +534,9 @@ const WgpDetails *getAdrenoWgpDetails() {
479
534
computeBitwdiths, storageBitwidths, allSubgroupOps,
480
535
allDotProductOps, /* mmaCount=*/ 0 , /* mmaOps=*/ nullptr ,
481
536
{64 , 64 }, {1024 , 1024 , 1024 }, 1024 ,
482
- 32 * 1024 };
537
+ 32 * 1024 ,
538
+ // Note: These values have not been checked and may be higher
539
+ {0xffff , 0xffff , 0xffff }};
483
540
// clang-format on
484
541
return &adrenoWgp;
485
542
}
@@ -545,7 +602,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
545
602
computeBitwdiths, storageBitwidths, SubgroupOps::None,
546
603
DotProductOps::None, /* mmaCount=*/ 0 , /* mmaOps=*/ nullptr ,
547
604
{64 , 64 }, {128 , 128 , 64 }, 128 ,
548
- 16 * 1024 };
605
+ 16 * 1024 ,
606
+ {0xffff , 0xffff , 0xffff }};
549
607
// clang-format on
550
608
return &androidWgp;
551
609
}
@@ -645,7 +703,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
645
703
computeBitwdiths, storageBitwidths, SubgroupOps::None,
646
704
DotProductOps::None, /* mmaCount=*/ 0 , /* mmaOps=*/ nullptr ,
647
705
{32 , 32 }, {128 , 128 , 64 }, 128 ,
648
- 16 * 1024 };
706
+ 16 * 1024 ,
707
+ {0xffff , 0xffff , 0xffff }};
649
708
// clang-format on
650
709
651
710
return createTargetAttr (
0 commit comments