Skip to content

Commit

Permalink
[SPIR-V] Implement WaveMutliPrefix*
Browse files Browse the repository at this point in the history
Implements the Shader Model 6.5 WaveMultiPrefix* intrinsic functions
using the group operation from SPV_NV_shader_subgroup_partitioned,
PartitionedExclusiveScanNV, which performs a partitioned exclusive scan
operation across a subset of invocations ("lanes") in a subgroup
("wave"). The subset of the partition is determined by the provided
ballot ("mask") parameter, which follows the same requirements for
valid partitioning and active invocations/lanes as the HLSL parameter.

Note that WaveMultiPrefixCountBits remains unimplemented because it does
not directly map to a SPIR-V GroupNonUniformArithmetic instruction that
accepts the PartitionedExclusiveScanNV Group Operation.

DirectX Spec: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_ShaderModel6_5.html#wavemultiprefix-functions
SPIR-V Extension: https://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/NV/SPV_NV_shader_subgroup_partitioned.html

Depends on microsoft#6596
Fixes microsoft#6600
  • Loading branch information
sudonatalie committed May 10, 2024
1 parent 9376983 commit d46fc48
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 38 deletions.
11 changes: 8 additions & 3 deletions docs/SPIR-V.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3840,9 +3840,9 @@ loading from SPIR-V builtin variable ``SubgroupSize`` and
``SubgroupLocalInvocationId`` respectively, the rest are translated into SPIR-V
group operations with ``Subgroup`` scope according to the following chart:

============= ============================ =================================== ======================
============= ============================ =================================== ==============================
Wave Category Wave Intrinsics SPIR-V Opcode SPIR-V Group Operation
============= ============================ =================================== ======================
============= ============================ =================================== ==============================
Query ``WaveIsFirstLane()`` ``OpGroupNonUniformElect``
Vote ``WaveActiveAnyTrue()`` ``OpGroupNonUniformAny``
Vote ``WaveActiveAllTrue()`` ``OpGroupNonUniformAll``
Expand All @@ -3866,7 +3866,12 @@ Quad ``QuadReadAcrossY()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadAcrossDiagonal()`` ``OpGroupNonUniformQuadSwap``
Quad ``QuadReadLaneAt()`` ``OpGroupNonUniformQuadBroadcast``
N/A ``WaveMatch()`` ``OpGroupNonUniformPartitionNV``
============= ============================ =================================== ======================
Multiprefix ``WaveMultiPrefixSum()`` ``OpGroupNonUniform*Add`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixProduct()`` ``OpGroupNonUniform*Mul`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixBitAnd()`` ``OpGroupNonUniformLogicalAnd`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixBitOr()`` ``OpGroupNonUniformLogicalOr`` ``PartitionedExclusiveScanNV``
Multiprefix ``WaveMultiPrefixBitXor()`` ``OpGroupNonUniformLogicalXor`` ``PartitionedExclusiveScanNV``
============= ============================ =================================== ==============================

The Implicit ``vk`` Namespace
=============================
Expand Down
46 changes: 42 additions & 4 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8709,6 +8709,18 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
callExpr, translateWaveOp(hlslOpcode, retType, srcLoc),
spv::GroupOperation::ExclusiveScan);
} break;
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixUSum:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixSum:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixUProduct:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixProduct:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitAnd:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitOr:
case hlsl::IntrinsicOp::IOP_WaveMultiPrefixBitXor: {
const auto retType = callExpr->getCallReturnType(astContext);
retVal = processWaveReductionOrPrefix(
callExpr, translateWaveOp(hlslOpcode, retType, srcLoc),
spv::GroupOperation::PartitionedExclusiveScanNV);
} break;
case hlsl::IntrinsicOp::IOP_WavePrefixCountBits:
retVal = processWaveCountBits(callExpr, spv::GroupOperation::ExclusiveScan);
break;
Expand Down Expand Up @@ -9517,6 +9529,13 @@ spv::Op SpirvEmitter::translateWaveOp(hlsl::IntrinsicOp op, QualType type,
WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMax, SMax, UMax, FMax);
WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveUMin, SMin, UMin, FMin);
WAVE_OP_CASE_SINT_UINT_FLOAT(ActiveMin, SMin, UMin, FMin);
WAVE_OP_CASE_INT_FLOAT(MultiPrefixUSum, IAdd, FAdd);
WAVE_OP_CASE_INT_FLOAT(MultiPrefixSum, IAdd, FAdd);
WAVE_OP_CASE_INT_FLOAT(MultiPrefixUProduct, IMul, FMul);
WAVE_OP_CASE_INT_FLOAT(MultiPrefixProduct, IMul, FMul);
WAVE_OP_CASE_INT(MultiPrefixBitAnd, BitwiseAnd);
WAVE_OP_CASE_INT(MultiPrefixBitOr, BitwiseOr);
WAVE_OP_CASE_INT(MultiPrefixBitXor, BitwiseXor);
default:
// Only Simple Wave Ops are handled here.
break;
Expand Down Expand Up @@ -9568,14 +9587,33 @@ SpirvInstruction *SpirvEmitter::processWaveReductionOrPrefix(
//
// <type> WavePrefixProduct(<type> value)
// <type> WavePrefixSum(<type> value)
assert(callExpr->getNumArgs() == 1);
//
// <type> WaveMultiPrefixSum( <type> val, uint4 mask )
// <type> WaveMultiPrefixProduct( <type> val, uint4 mask )
// <int_type> WaveMultiPrefixBitAnd( <int_type> val, uint4 mask )
// <int_type> WaveMultiPrefixBitOr( <int_type> val, uint4 mask )
// <int_type> WaveMultiPrefixBitXor( <int_type> val, uint4 mask )

bool isMultiPrefix =
groupOp == spv::GroupOperation::PartitionedExclusiveScanNV;
assert(callExpr->getNumArgs() == (isMultiPrefix ? 2 : 1));

featureManager.requestTargetEnv(SPV_ENV_VULKAN_1_1, "Wave Operation",
callExpr->getExprLoc());
auto *predicate = doExpr(callExpr->getArg(0));

llvm::ArrayRef<SpirvInstruction *> operands;
auto *value = doExpr(callExpr->getArg(0));
if (isMultiPrefix) {
SpirvInstruction *mask = doExpr(callExpr->getArg(1));
operands = {value, mask};
} else {
operands = {value};
}

const QualType retType = callExpr->getCallReturnType(astContext);
return spvBuilder.createGroupNonUniformOp(
opcode, retType, spv::Scope::Subgroup, {predicate},
callExpr->getExprLoc(), llvm::Optional<spv::GroupOperation>(groupOp));
opcode, retType, spv::Scope::Subgroup, operands, callExpr->getExprLoc(),
llvm::Optional<spv::GroupOperation>(groupOp));
}

SpirvInstruction *SpirvEmitter::processWaveBroadcast(const CallExpr *callExpr) {
Expand Down
3 changes: 2 additions & 1 deletion tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,8 @@ class SpirvEmitter : public ASTConsumer {
SpirvInstruction *processWaveCountBits(const CallExpr *,
spv::GroupOperation groupOp);

/// Processes SM6.0 wave reduction or scan/prefix intrinsic calls.
/// Processes SM6.0 wave reduction or scan/prefix and SM6.5 wave multiprefix
/// intrinsic calls.
SpirvInstruction *processWaveReductionOrPrefix(const CallExpr *, spv::Op op,
spv::GroupOperation groupOp);

Expand Down
18 changes: 11 additions & 7 deletions tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,19 +684,12 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
case spv::Op::OpGroupNonUniformBallotBitCount:
case spv::Op::OpGroupNonUniformBallotFindLSB:
case spv::Op::OpGroupNonUniformBallotFindMSB:
case spv::Op::OpGroupNonUniformIAdd:
case spv::Op::OpGroupNonUniformFAdd:
case spv::Op::OpGroupNonUniformIMul:
case spv::Op::OpGroupNonUniformFMul:
case spv::Op::OpGroupNonUniformSMin:
case spv::Op::OpGroupNonUniformUMin:
case spv::Op::OpGroupNonUniformFMin:
case spv::Op::OpGroupNonUniformSMax:
case spv::Op::OpGroupNonUniformUMax:
case spv::Op::OpGroupNonUniformFMax:
case spv::Op::OpGroupNonUniformBitwiseAnd:
case spv::Op::OpGroupNonUniformBitwiseOr:
case spv::Op::OpGroupNonUniformBitwiseXor:
case spv::Op::OpGroupNonUniformLogicalAnd:
case spv::Op::OpGroupNonUniformLogicalOr:
case spv::Op::OpGroupNonUniformLogicalXor:
Expand All @@ -715,6 +708,17 @@ SpirvGroupNonUniformOp::SpirvGroupNonUniformOp(
assert(operandsVec.size() == 2);
break;

// Group non-uniform operations with a required and optional operand.
case spv::Op::OpGroupNonUniformIAdd:
case spv::Op::OpGroupNonUniformFAdd:
case spv::Op::OpGroupNonUniformIMul:
case spv::Op::OpGroupNonUniformFMul:
case spv::Op::OpGroupNonUniformBitwiseAnd:
case spv::Op::OpGroupNonUniformBitwiseOr:
case spv::Op::OpGroupNonUniformBitwiseXor:
assert(operandsVec.size() >= 1 && operandsVec.size() <= 2);
break;

// Unexpected opcode.
default:
assert(false && "Unexpected Group non-uniform opcode");
Expand Down
23 changes: 0 additions & 23 deletions tools/clang/test/CodeGenSPIRV/intrinsics.multiprefix.hlsl

This file was deleted.

49 changes: 49 additions & 0 deletions tools/clang/test/CodeGenSPIRV/intrinsics.sm6_5.multiprefix.hlsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: %dxc -E main -T ps_6_5 -spirv -O0 -fspv-target-env=vulkan1.1 %s | FileCheck %s
// RUN: not %dxc -E main -T ps_6_5 -spirv -O0 %s 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

// CHECK-ERROR: error: Vulkan 1.1 is required for Wave Operation but not permitted to use

// CHECK: OpCapability GroupNonUniformPartitionedNV
// CHECK: OpExtension "SPV_NV_shader_subgroup_partitioned"

StructuredBuffer<uint4> g_mask;

uint4 main(int4 input0 : ATTR0, uint4 input1 : ATTR1) : SV_Target {
uint4 mask = g_mask[0];

// CHECK: [[input0:%[0-9]+]] = OpLoad %v4int %input0
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformIMul %v4int %uint_3 PartitionedExclusiveScanNV [[input0]] [[mask]]
int4 res = WaveMultiPrefixProduct(input0, mask);

// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformIMul %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]]
res += WaveMultiPrefixProduct(input1, mask);

// CHECK: [[input0:%[0-9]+]] = OpLoad %v4int %input0
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformIAdd %v4int %uint_3 PartitionedExclusiveScanNV [[input0]] [[mask]]
res += WaveMultiPrefixSum(input0, mask);

// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformIAdd %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]]
res += WaveMultiPrefixSum(input1, mask);

// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseAnd %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]]
res += WaveMultiPrefixBitAnd(input1, mask);

// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseOr %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]]
res += WaveMultiPrefixBitOr(input1, mask);

// CHECK: [[input1:%[0-9]+]] = OpLoad %v4uint %input1
// CHECK: [[mask:%[0-9]+]] = OpLoad %v4uint %mask
// CHECK: {{%[0-9]+}} = OpGroupNonUniformBitwiseXor %v4uint %uint_3 PartitionedExclusiveScanNV [[input1]] [[mask]]
res += WaveMultiPrefixBitXor(input1, mask);
return res;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: not %dxc -T ps_6_5 -E main -fcgl %s -spirv 2>&1 | FileCheck %s

StructuredBuffer<uint4> g_mask;

uint main(uint input : ATTR0) : SV_Target {
uint4 mask = g_mask[0];

uint res = uint4(0, 0, 0, 0);
// CHECK: error: WaveMultiPrefixCountBits intrinsic function unimplemented
res.x += WaveMultiPrefixCountBits((input.x == 1), mask);

return res;
}

0 comments on commit d46fc48

Please sign in to comment.