diff --git a/docs/SPIR-V.rst b/docs/SPIR-V.rst index 91aa9e7468..259fc9cf8f 100644 --- a/docs/SPIR-V.rst +++ b/docs/SPIR-V.rst @@ -3840,9 +3840,9 @@ loading from SPIR-V builtin variable ``SubgroupSize`` and ``SubgroupLocalInvocationId`` respectively, the rest are translated into SPIR-V group operations with ``Subgroup`` scope according to the following chart: -============= ============================ =================================== ====================== +============= ============================ =================================== ============================== Wave Category Wave Intrinsics SPIR-V Opcode SPIR-V Group Operation -============= ============================ =================================== ====================== +============= ============================ =================================== ============================== Query ``WaveIsFirstLane()`` ``OpGroupNonUniformElect`` Vote ``WaveActiveAnyTrue()`` ``OpGroupNonUniformAny`` Vote ``WaveActiveAllTrue()`` ``OpGroupNonUniformAll`` @@ -3866,7 +3866,12 @@ Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap`` Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap`` Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast`` N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV`` -============= ============================ =================================== ====================== +Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV`` +Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV`` +Multiprefix ``WaveMultiPrefixBitAnd()`` ``OpGroupNonUniformBitwiseAnd`` ``PartitionedExclusiveScanNV`` +Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformBitwiseOr`` ``PartitionedExclusiveScanNV`` +Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformBitwiseXor`` ``PartitionedExclusiveScanNV`` +============= ============================ =================================== ============================== The 
Implicit ``vk`` Namespace ============================= diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index d0c72ceb47..0dac692f82 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -8709,6 +8709,18 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) { callExpr, translateWaveOp(hlslOpcode, retType, srcLoc), spv::GroupOperation::ExclusiveScan); } break; + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixUSum: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixSum: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixUProduct: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixProduct: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitAnd: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitOr: + case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitXor: { + const auto retType = callExpr->getCallReturnType(astContext); + retVal = processWaveReductionOrPrefix( + callExpr, translateWaveOp(hlslOpcode, retType, srcLoc), + spv::GroupOperation::PartitionedExclusiveScanNV); + } break; case hlsl::IntrinsicOp::IOP_WavePrefixCountBits: retVal = processWaveCountBits(callExpr, spv::GroupOperation::ExclusiveScan); break; @@ -9517,6 +9529,13 @@ spv::Op SpirvEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type, WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax); WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin); WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin); + WAVE_OP_CASE_INT_FLOAT(MultiPrefixUSum, IAdd, FAdd); + WAVE_OP_CASE_INT_FLOAT(MultiPrefixSum, IAdd, FAdd); + WAVE_OP_CASE_INT_FLOAT(MultiPrefixUProduct, IMul, FMul); + WAVE_OP_CASE_INT_FLOAT(MultiPrefixProduct, IMul, FMul); + WAVE_OP_CASE_INT(MultiPrefixBitAnd, BitwiseAnd); + WAVE_OP_CASE_INT(MultiPrefixBitOr, BitwiseOr); + WAVE_OP_CASE_INT(MultiPrefixBitXor, BitwiseXor); default: // Only Simple Wave Ops are handled here. 
break; @@ -9568,14 +9587,33 @@ SpirvInstruction *SpirvEmitter::processWaveReductionOrPrefix( // // WavePrefixProduct( value) // WavePrefixSum( value) - assert(callExpr->getNumArgs() == 1); + // + // WaveMultiPrefixSum( val, uint4 mask ) + // WaveMultiPrefixProduct( val, uint4 mask ) + // WaveMultiPrefixBitAnd( val, uint4 mask ) + // WaveMultiPrefixBitOr( val, uint4 mask ) + // WaveMultiPrefixBitXor( val, uint4 mask ) + + bool isMultiPrefix = + groupOp == spv::GroupOperation::PartitionedExclusiveScanNV; + assert(callExpr->getNumArgs() == (isMultiPrefix ? 2 : 1)); + featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation", callExpr->getExprLoc()); - auto *predicate = doExpr(callExpr->getArg(0)); + + // Keep the operands in owning storage: an ArrayRef assigned from a braced + // list would dangle once the temporary list is destroyed. + llvm::SmallVector<SpirvInstruction *, 2> operands; + operands.push_back(doExpr(callExpr->getArg(0))); + if (isMultiPrefix) { + // Multiprefix intrinsics pass the mask as the second operand. + operands.push_back(doExpr(callExpr->getArg(1))); + } + const QualType retType = callExpr->getCallReturnType(astContext); return spvBuilder.createGroupNonUniformOp( - opcode, retType, spv::Scope::Subgroup, {predicate}, - callExpr->getExprLoc(), llvm::Optional(groupOp)); + opcode, retType, spv::Scope::Subgroup, operands, callExpr->getExprLoc(), + llvm::Optional(groupOp)); } SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) { diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.h b/tools/clang/lib/SPIRV/SpirvEmitter.h index 10e80b9e71..a8db190cc6 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.h +++ b/tools/clang/lib/SPIRV/SpirvEmitter.h @@ -615,7 +615,8 @@ class SpirvEmitter : public ASTConsumer { SpirvInstruction *processWaveCountBits(const CallExpr *, spv::GroupOperation groupOp); - /// Processes SM6.0 wave reduction or scan/prefix intrinsic calls. + /// Processes SM6.0 wave reduction or scan/prefix and SM6.5 wave multiprefix + /// intrinsic calls. 
SpirvInstruction *processWaveReductionOrPrefix(const CallExpr *, spv::Op op, spv::GroupOperation groupOp); diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index be15965673..91f936edf5 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -684,19 +684,12 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp( case spv::Op::OpGroupNonUniformBallotBitCount: case spv::Op::OpGroupNonUniformBallotFindLSB: case spv::Op::OpGroupNonUniformBallotFindMSB: - case spv::Op::OpGroupNonUniformIAdd: - case spv::Op::OpGroupNonUniformFAdd: - case spv::Op::OpGroupNonUniformIMul: - case spv::Op::OpGroupNonUniformFMul: case spv::Op::OpGroupNonUniformSMin: case spv::Op::OpGroupNonUniformUMin: case spv::Op::OpGroupNonUniformFMin: case spv::Op::OpGroupNonUniformSMax: case spv::Op::OpGroupNonUniformUMax: case spv::Op::OpGroupNonUniformFMax: - case spv::Op::OpGroupNonUniformBitwiseAnd: - case spv::Op::OpGroupNonUniformBitwiseOr: - case spv::Op::OpGroupNonUniformBitwiseXor: case spv::Op::OpGroupNonUniformLogicalAnd: case spv::Op::OpGroupNonUniformLogicalOr: case spv::Op::OpGroupNonUniformLogicalXor: @@ -715,6 +708,17 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp( assert(operandsVec.size() == 2); break; + // Group non-uniform operations with a required and optional operand. + case spv::Op::OpGroupNonUniformIAdd: + case spv::Op::OpGroupNonUniformFAdd: + case spv::Op::OpGroupNonUniformIMul: + case spv::Op::OpGroupNonUniformFMul: + case spv::Op::OpGroupNonUniformBitwiseAnd: + case spv::Op::OpGroupNonUniformBitwiseOr: + case spv::Op::OpGroupNonUniformBitwiseXor: + assert(operandsVec.size() >= 1 && operandsVec.size() <= 2); + break; + // Unexpected opcode. 
default: assert(false && "Unexpected Group non-uniform opcode"); diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.multiprefix.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.multiprefix.hlsl deleted file mode 100644 index 642dbd2525..0000000000 --- a/tools/clang/test/CodeGenSPIRV/intrinsics.multiprefix.hlsl +++ /dev/null @@ -1,23 +0,0 @@ -// RUN: not %dxc -T ps_6_5 -E main -fcgl %s -spirv 2>&1 | FileCheck %s - -StructuredBuffer g_mask; - -uint main(uint input : ATTR0) : SV_Target { - uint4 mask = g_mask[0]; - - uint res = uint4(0, 0, 0, 0); -// CHECK: 10:10: error: WaveMultiPrefixBitAnd intrinsic function unimplemented - res += WaveMultiPrefixBitAnd(input, mask); -// CHECK: 12:10: error: WaveMultiPrefixBitOr intrinsic function unimplemented - res += WaveMultiPrefixBitOr(input, mask); -// CHECK: 14:10: error: WaveMultiPrefixBitXor intrinsic function unimplemented - res += WaveMultiPrefixBitXor(input, mask); -// CHECK: 16:10: error: WaveMultiPrefixProduct intrinsic function unimplemented - res += WaveMultiPrefixProduct(input, mask); -// CHECK: 18:10: error: WaveMultiPrefixSum intrinsic function unimplemented - res += WaveMultiPrefixSum(input, mask); -// CHECK: 20:12: error: WaveMultiPrefixCountBits intrinsic function unimplemented - res.x += WaveMultiPrefixCountBits((input.x == 1), mask); - - return res; -} diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.hlsl new file mode 100644 index 0000000000..7b15e17558 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.hlsl @@ -0,0 +1,49 @@ +// RUN: %dxc -E main -T ps_6_5 -spirv -O0 -fspv-target-env=vulkan1.1 %s | FileCheck %s +// RUN: not %dxc -E main -T ps_6_5 -spirv -O0 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR + +// CHECK-ERROR: error: Vulkan 1.1 is required for Wave Operation but not permitted to use + +// CHECK: OpCapability GroupNonUniformPartitionedNV +// CHECK: OpExtension 
"SPV_NV_shader_subgroup_partitioned" + +StructuredBuffer g_mask; + +uint4 main(int4 input0 : ATTR0, uint4 input1 : ATTR1) : SV_Target { + uint4 mask = g_mask[0]; + +// CHECK: [[input0:%[0-9]+]] = OpLoad %v4int %input0 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformIMul %v4int %uint_3 PartitionedExclusiveScanNV [[input0]] [[mask]] + int4 res = WaveMultiPrefixProduct(input0, mask); + +// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformIMul %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]] + res += WaveMultiPrefixProduct(input1, mask); + +// CHECK: [[input0:%[0-9]+]] = OpLoad %v4int %input0 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformIAdd %v4int %uint_3 PartitionedExclusiveScanNV [[input0]] [[mask]] + res += WaveMultiPrefixSum(input0, mask); + +// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformIAdd %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]] + res += WaveMultiPrefixSum(input1, mask); + +// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseAnd %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]] + res += WaveMultiPrefixBitAnd(input1, mask); + +// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseOr %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]] + res += WaveMultiPrefixBitOr(input1, mask); + +// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1 +// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask +// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseXor %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]] + res += 
WaveMultiPrefixBitXor(input1, mask); + return res; +} diff --git a/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.unimplemented.hlsl b/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.unimplemented.hlsl new file mode 100644 index 0000000000..f9df3466af --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.unimplemented.hlsl @@ -0,0 +1,13 @@ +// RUN: not %dxc -T ps_6_5 -E main -fcgl %s -spirv 2>&1 | FileCheck %s + +StructuredBuffer g_mask; + +uint main(uint input : ATTR0) : SV_Target { + uint4 mask = g_mask[0]; + + uint res = uint4(0, 0, 0, 0); +// CHECK: error: WaveMultiPrefixCountBits intrinsic function unimplemented + res.x += WaveMultiPrefixCountBits((input.x == 1), mask); + + return res; +}