Skip to content

Commit 300af39

Browse files
authored
[codegen] Add max_workgroup_counts to TargetWgpAttr (iree-org#17771)
This commit adds a max_workgroup_counts to the workgroup processor information attribute and sets values for the known targets. Some of these values may be underestimates as I was not able to locate information on their values. This field is added so that we can annotate calls to workgroup.id and workgroup.count with upper bounds, enabling range inference and strength reduction. Note that in some cases (for instance, AMD) we give a max_workgroup_counts value lower than what is actually supported because a grid dimension greater than int32_max would be sign-extended to a negative number to meet the 64-bit nature of `index`. (This PR is split out of iree-org#17707) Signed-off-by: Krzysztof Drewniak <Krzysztof.Drewniak@amd.com>
1 parent 9b05f17 commit 300af39

25 files changed

+242
-114
lines changed

compiler/plugins/target/MetalSPIRV/test/smoketest.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ module attributes {
66
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
77
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
88
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
9-
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
9+
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
10+
max_workgroup_counts = [65535, 65535, 65535]>>
1011
}>
1112
]> : !hal.device
1213
]

compiler/plugins/target/ROCM/test/target_device_features.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
// GFX942-SAME: subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
99
// GFX942-SAME: mma = [<MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>],
1010
// GFX942-SAME: subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
11-
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>,
11+
// GFX942-SAME: max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
12+
// GFX942-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
1213
// GFX942-SAME: chip = <wgp_count = 304>>
1314

1415
// GFX940: target = #iree_gpu.target<arch = "gfx940",

compiler/plugins/target/VulkanSPIRV/test/smoketest.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ module attributes {
66
#hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
77
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
88
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32, 32],
9-
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
9+
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
10+
max_workgroup_counts = [65535, 65535, 65535]>>
1011
}>
1112
]> : !hal.device
1213
]

compiler/plugins/target/WebGPUSPIRV/test/smoketest.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ module attributes {
77
#hal.executable.target<"webgpu-spirv", "webgpu-wgsl-fb", {
88
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.0,cap:Shader,ext:SPV_KHR_storage_buffer_storage_class", wgp = <
99
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], subgroup_size_choices = [32],
10-
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384>>
10+
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
11+
max_workgroup_counts = [65535, 65535, 65535]>>
1112
}>
1213
]> : !hal.device
1314
]

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td

+2
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,8 @@ def IREEGPU_TargetWgpAttr : AttrDef<IREEGPU_Dialect, "TargetWgp"> {
290290
"uint32_t":$max_thread_count_per_workgroup,
291291
// The maximal number of shared memory bytes we can allocate per workgroup.
292292
"uint32_t":$max_workgroup_memory_bytes,
293+
// The maximum number of workgroups per X/Y/Z dimension in a dispatch.
294+
"DenseI32ArrayAttr":$max_workgroup_counts,
293295

294296
// An optional extra dict
295297
// This field allows to inject more features/limits not supported in the

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/test/target_attrs.mlir

+8-4
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,17 @@ func.func @test_target_wgp() attributes {
1111
// CHECK-SAME: subgroup_size_choices = [32, 64],
1212
// CHECK-SAME: max_workgroup_sizes = [1024, 1024, 1024],
1313
// CHECK-SAME: max_thread_count_per_workgroup = 1024,
14-
// CHECK-SAME: max_workgroup_memory_bytes = 65536>
14+
// CHECK-SAME: max_workgroup_memory_bytes = 65536,
15+
// CHECK-SAME: max_workgroup_counts = [2147483647, 2147483647, 2147483647]>
1516
wgp = #iree_gpu.target_wgp<
1617
compute = fp16|fp32|int8, storage = b16|b32,
1718
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
1819
mma = [<MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>],
1920
subgroup_size_choices = [32, 64],
2021
max_workgroup_sizes = [1024, 1024, 1024],
2122
max_thread_count_per_workgroup = 1024,
22-
max_workgroup_memory_bytes = 65536
23+
max_workgroup_memory_bytes = 65536,
24+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
2325
>
2426
} { return }
2527

@@ -37,7 +39,8 @@ func.func @test_target_wgp_none() attributes {
3739
subgroup_size_choices = [32],
3840
max_workgroup_sizes = [1024, 1024, 1024],
3941
max_thread_count_per_workgroup = 1024,
40-
max_workgroup_memory_bytes = 65536
42+
max_workgroup_memory_bytes = 65536,
43+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]
4144
>
4245
} { return }
4346

@@ -67,7 +70,8 @@ func.func @test_target() attributes {
6770
subgroup_size_choices = [32, 64],
6871
max_workgroup_sizes = [1024, 1024, 1024],
6972
max_thread_count_per_workgroup = 1024,
70-
max_workgroup_memory_bytes = 65536>,
73+
max_workgroup_memory_bytes = 65536,
74+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>,
7175
chip = <wgp_count = 304>
7276
>
7377
} { return }

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.cpp

+103-44
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct WgpDetails {
4545
std::array<int32_t, 3> maxWorkgroupSizes;
4646
uint32_t maxThreadSize;
4747
uint32_t maxWorkgroupMemoryBytes;
48+
std::array<int32_t, 3> maxWorkgroupCounts;
4849
};
4950

5051
// Chip level feature/limit details
@@ -106,7 +107,9 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
106107
MMAOpsArrayAttr::get(context, mmaAttrs),
107108
DenseI32ArrayAttr::get(context, subgroupSizes),
108109
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupSizes),
109-
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes, DictionaryAttr{});
110+
wgp->maxThreadSize, wgp->maxWorkgroupMemoryBytes,
111+
DenseI32ArrayAttr::get(context, wgp->maxWorkgroupCounts),
112+
DictionaryAttr{});
110113

111114
TargetChipAttr targetChip;
112115
if (details.chip)
@@ -118,6 +121,10 @@ TargetAttr createTargetAttr(const TargetDetails &details, StringRef arch,
118121

119122
//===----------------------------------------------------------------------===//
120123
// Known AMD target details
124+
//
125+
// Note: the max workgroup size is given as signed int32 max because MLIR's
126+
// `index` is signed and the workgroup ID is sign-extended, not zero-extended,
127+
// to 64-bits.
121128
//===----------------------------------------------------------------------===//
122129

123130
const WgpDetails *getCDNA3WgpDetails() {
@@ -129,11 +136,17 @@ const WgpDetails *getCDNA3WgpDetails() {
129136
MMAIntrinsic::MFMA_I32_16x16x32_I8,
130137
MMAIntrinsic::MFMA_I32_32x32x16_I8,
131138
};
132-
static const WgpDetails cdna3Wgp = {
133-
allComputeBits, allStorageBits, allSubgroupOps,
134-
allDotProductOps, ARRAY_SIZE(cdna3MMAOps), cdna3MMAOps,
135-
{64, 64}, {1024, 1024, 1024}, 1024,
136-
64 * 1024};
139+
static const WgpDetails cdna3Wgp = {allComputeBits,
140+
allStorageBits,
141+
allSubgroupOps,
142+
allDotProductOps,
143+
ARRAY_SIZE(cdna3MMAOps),
144+
cdna3MMAOps,
145+
{64, 64},
146+
{1024, 1024, 1024},
147+
1024,
148+
64 * 1024,
149+
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
137150
return &cdna3Wgp;
138151
}
139152

@@ -142,11 +155,17 @@ const WgpDetails *getCDNA2WgpDetails() {
142155
MMAIntrinsic::MFMA_F32_16x16x16_F16,
143156
MMAIntrinsic::MFMA_F32_32x32x8_F16,
144157
};
145-
static const WgpDetails cdna2Wgp = {
146-
allComputeBits, allStorageBits, allSubgroupOps,
147-
allDotProductOps, ARRAY_SIZE(cdna2MMAOps), cdna2MMAOps,
148-
{64, 64}, {1024, 1024, 1024}, 1024,
149-
64 * 1024};
158+
static const WgpDetails cdna2Wgp = {allComputeBits,
159+
allStorageBits,
160+
allSubgroupOps,
161+
allDotProductOps,
162+
ARRAY_SIZE(cdna2MMAOps),
163+
cdna2MMAOps,
164+
{64, 64},
165+
{1024, 1024, 1024},
166+
1024,
167+
64 * 1024,
168+
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
150169
return &cdna2Wgp;
151170
}
152171

@@ -155,11 +174,17 @@ const WgpDetails *getCDNA1WgpDetails() {
155174
MMAIntrinsic::MFMA_F32_16x16x16_F16,
156175
MMAIntrinsic::MFMA_F32_32x32x8_F16,
157176
};
158-
static const WgpDetails cdna1Wgp = {
159-
allComputeBits, allStorageBits, allSubgroupOps,
160-
allDotProductOps, ARRAY_SIZE(cdna1MMAOps), cdna1MMAOps,
161-
{64, 64}, {1024, 1024, 1024}, 1024,
162-
64 * 1024};
177+
static const WgpDetails cdna1Wgp = {allComputeBits,
178+
allStorageBits,
179+
allSubgroupOps,
180+
allDotProductOps,
181+
ARRAY_SIZE(cdna1MMAOps),
182+
cdna1MMAOps,
183+
{64, 64},
184+
{1024, 1024, 1024},
185+
1024,
186+
64 * 1024,
187+
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
163188
return &cdna1Wgp;
164189
}
165190

@@ -168,27 +193,39 @@ const WgpDetails *getRDNA3WgpDetails() {
168193
MMAIntrinsic::WMMA_F32_16x16x16_F16,
169194
MMAIntrinsic::WMMA_F16_16x16x16_F16,
170195
};
171-
static const WgpDetails rdna3Wgp = {
172-
allComputeBits, allStorageBits, allSubgroupOps,
173-
allDotProductOps, ARRAY_SIZE(rdna3MMAOps), rdna3MMAOps,
174-
{32, 64}, {1024, 1024, 1024}, 1024,
175-
64 * 1024};
196+
static const WgpDetails rdna3Wgp = {allComputeBits,
197+
allStorageBits,
198+
allSubgroupOps,
199+
allDotProductOps,
200+
ARRAY_SIZE(rdna3MMAOps),
201+
rdna3MMAOps,
202+
{32, 64},
203+
{1024, 1024, 1024},
204+
1024,
205+
64 * 1024,
206+
{0x7fffffff, 0x7fffffff, 0x7fffffff}};
176207
return &rdna3Wgp;
177208
}
178209

179210
const WgpDetails *getRDNA2WgpDetails() {
180211
static const WgpDetails rdna2Wgp = {
181-
allComputeBits, allStorageBits, allSubgroupOps, allDotProductOps,
182-
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
183-
1024, 64 * 1024};
212+
allComputeBits, allStorageBits,
213+
allSubgroupOps, allDotProductOps,
214+
/*mmaCount=*/0,
215+
/*mmaOps=*/nullptr, {32, 64},
216+
{1024, 1024, 1024}, 1024,
217+
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
184218
return &rdna2Wgp;
185219
}
186220

187221
const WgpDetails *getRDNA1WgpDetails() {
188222
static const WgpDetails rdna1Wgp = {
189-
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
190-
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 64}, {1024, 1024, 1024},
191-
1024, 64 * 1024};
223+
allComputeBits, allStorageBits,
224+
allSubgroupOps, DotProductOps::None,
225+
/*mmaCount=*/0,
226+
/*mmaOps=*/nullptr, {32, 64},
227+
{1024, 1024, 1024}, 1024,
228+
64 * 1024, {0x7fffffff, 0x7fffffff, 0x7fffffff}};
192229
return &rdna1Wgp;
193230
}
194231

@@ -281,7 +318,9 @@ std::optional<TargetDetails> getAppleTargetDetails() {
281318
static const WgpDetails wgp = {
282319
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
283320
/*mmaCount=*/0, /*mmaOps=*/nullptr, {32, 32},
284-
{1024, 1024, 1024}, 1024, 32 * 1024};
321+
{1024, 1024, 1024}, 1024, 32 * 1024,
322+
// Note: These values have not been checked and may be higher
323+
{0xffff, 0xffff, 0xffff}};
285324
// clang-format on
286325

287326
return TargetDetails{&wgp, nullptr};
@@ -302,7 +341,9 @@ const WgpDetails *getValhallWgpDetails() {
302341
static const WgpDetails valhallWgp = {
303342
computeBitwdiths, allStorageBits, allSubgroupOps, allDotProductOps,
304343
/*mmaCount=*/0, /*mmaOps=*/nullptr, {16, 16}, {512, 512, 512},
305-
512, 32 * 1024};
344+
512, 32 * 1024,
345+
// Note: These values have not been checked and may be higher
346+
{0xffff, 0xffff, 0xffff}};
306347
// clang-format on
307348
return &valhallWgp;
308349
}
@@ -358,11 +399,17 @@ const WgpDetails *getAmpereWgpDetails() {
358399
MMAIntrinsic::WMMA_F32_16x16x16_F16,
359400
MMAIntrinsic::WMMA_F16_16x16x16_F16,
360401
};
361-
static const WgpDetails ampereWgp = {
362-
allComputeBits, allStorageBits, allSubgroupOps,
363-
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
364-
{32, 32}, {1024, 1024, 1024}, 1024,
365-
163 * 1024};
402+
static const WgpDetails ampereWgp = {allComputeBits,
403+
allStorageBits,
404+
allSubgroupOps,
405+
allDotProductOps,
406+
ARRAY_SIZE(mmaOps),
407+
mmaOps,
408+
{32, 32},
409+
{1024, 1024, 1024},
410+
1024,
411+
163 * 1024,
412+
{0x7fffffff, 0xffff, 0xffff}};
366413
return &ampereWgp;
367414
}
368415

@@ -371,11 +418,17 @@ const WgpDetails *getTuringWgpDetails() {
371418
MMAIntrinsic::WMMA_F32_16x16x16_F16,
372419
MMAIntrinsic::WMMA_F16_16x16x16_F16,
373420
};
374-
static const WgpDetails turingWgp = {
375-
allComputeBits, allStorageBits, allSubgroupOps,
376-
allDotProductOps, ARRAY_SIZE(mmaOps), mmaOps,
377-
{32, 32}, {1024, 1024, 1024}, 1024,
378-
64 * 1024};
421+
static const WgpDetails turingWgp = {allComputeBits,
422+
allStorageBits,
423+
allSubgroupOps,
424+
allDotProductOps,
425+
ARRAY_SIZE(mmaOps),
426+
mmaOps,
427+
{32, 32},
428+
{1024, 1024, 1024},
429+
1024,
430+
64 * 1024,
431+
{0x7fffffff, 0xffff, 0xffff}};
379432
return &turingWgp;
380433
}
381434

@@ -388,7 +441,8 @@ const WgpDetails *getVoltaWgpDetails() {
388441
static const WgpDetails voltaWgp = {
389442
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
390443
ARRAY_SIZE(mmaOps), mmaOps, {32, 32}, {1024, 1024, 1024},
391-
1024, 96 * 1024};
444+
1024, 96 * 1024,
445+
{0x7fffffff, 0xffff, 0xffff}};
392446
// clang-format on
393447
return &voltaWgp;
394448
}
@@ -398,7 +452,8 @@ const WgpDetails *getPascalWgpDetails() {
398452
static const WgpDetails pascalWgp = {
399453
allComputeBits, allStorageBits, allSubgroupOps, DotProductOps::None,
400454
0, nullptr, // Pascal does not have tensor core support.
401-
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024};
455+
{32, 32}, {1024, 1024, 1024}, 1024, 48 * 1024,
456+
{0x7fffffff, 0xffff, 0xffff}};
402457
// clang-format on
403458
return &pascalWgp;
404459
}
@@ -479,7 +534,9 @@ const WgpDetails *getAdrenoWgpDetails() {
479534
computeBitwdiths, storageBitwidths, allSubgroupOps,
480535
allDotProductOps, /*mmaCount=*/0, /*mmaOps=*/nullptr,
481536
{64, 64}, {1024, 1024, 1024}, 1024,
482-
32 * 1024};
537+
32 * 1024,
538+
// Note: These values have not been checked and may be higher
539+
{0xffff, 0xffff, 0xffff}};
483540
// clang-format on
484541
return &adrenoWgp;
485542
}
@@ -545,7 +602,8 @@ const WgpDetails *getAndroidBaseline2022WgpDetails() {
545602
computeBitwdiths, storageBitwidths, SubgroupOps::None,
546603
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
547604
{64, 64}, {128, 128, 64}, 128,
548-
16 * 1024};
605+
16 * 1024,
606+
{0xffff, 0xffff, 0xffff}};
549607
// clang-format on
550608
return &androidWgp;
551609
}
@@ -645,7 +703,8 @@ TargetAttr getWebGPUTargetDetails(MLIRContext *context) {
645703
computeBitwdiths, storageBitwidths, SubgroupOps::None,
646704
DotProductOps::None, /*mmaCount=*/0, /*mmaOps=*/nullptr,
647705
{32, 32}, {128, 128, 64}, 128,
648-
16 * 1024};
706+
16 * 1024,
707+
{0xffff, 0xffff, 0xffff}};
649708
// clang-format on
650709

651710
return createTargetAttr(

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ func.func @conv_nhwc() {
9393
subgroup = shuffle|arithmetic, dot = dp4xi8toi32,
9494
mma = [],
9595
subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024],
96-
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536>>
96+
max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536,
97+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>
9798
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb", {iree.gpu.target = #target}>
9899
func.func @matmul_256x256x256() attributes {hal.executable.target = #executable_target_rocm_hsaco_fb} {
99100
%cst = arith.constant 0.000000e+00 : f32

0 commit comments

Comments
 (0)