Skip to content

Commit

Permalink
Add NV coop matrix as well.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Jun 25, 2024
1 parent 21178a6 commit 8770000
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 15 deletions.
1 change: 1 addition & 0 deletions source/opt/aggressive_dead_code_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,7 @@ void AggressiveDCEPass::InitExtensions() {
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"
});
// clang-format on
Expand Down
34 changes: 20 additions & 14 deletions source/opt/folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ bool IsValidResult(T val) {
}
}

// Returns true if `type` is a cooperative matrix.
bool IsCooperativeMatrix(const analysis::Type* type) {
return type->kind() == analysis::Type::kCooperativeMatrixKHR ||
type->kind() == analysis::Type::kCooperativeMatrixNV;
}

const analysis::Constant* ConstInput(
const std::vector<const analysis::Constant*>& constants) {
return constants[0] ? constants[0] : constants[1];
Expand Down Expand Up @@ -314,7 +320,7 @@ FoldingRule ReciprocalFDiv() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -400,7 +406,7 @@ FoldingRule MergeNegateMulDivArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -466,7 +472,7 @@ FoldingRule MergeNegateAddSubArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -702,7 +708,7 @@ FoldingRule MergeMulMulArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -761,7 +767,7 @@ FoldingRule MergeMulDivArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -839,7 +845,7 @@ FoldingRule MergeMulNegateArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -884,7 +890,7 @@ FoldingRule MergeDivDivArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -962,7 +968,7 @@ FoldingRule MergeDivMulArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1109,7 +1115,7 @@ FoldingRule MergeSubNegateArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1162,7 +1168,7 @@ FoldingRule MergeAddAddArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1215,7 +1221,7 @@ FoldingRule MergeAddSubArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1280,7 +1286,7 @@ FoldingRule MergeSubAddArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1351,7 +1357,7 @@ FoldingRule MergeSubSubArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down Expand Up @@ -1449,7 +1455,7 @@ FoldingRule MergeGenericAddSubArithmetic() {
const analysis::Type* type =
context->get_type_mgr()->GetType(inst->type_id());

if (type->kind() == analysis::Type::kCooperativeMatrixKHR) {
if (IsCooperativeMatrix(type)) {
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion source/opt/local_access_chain_convert_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model",
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives",
"SPV_KHR_cooperative_matrix"});
"SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"});
}

bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
Expand Down
1 change: 1 addition & 0 deletions source/opt/local_single_block_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"});
}

Expand Down
1 change: 1 addition & 0 deletions source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_NV_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
Expand Down
1 change: 1 addition & 0 deletions source/opt/mem_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
case spv::Op::OpTypeSampler:
case spv::Op::OpTypeSampledImage:
case spv::Op::OpTypePointer:
case spv::Op::OpTypeCooperativeMatrixNV:
case spv::Op::OpTypeCooperativeMatrixKHR:
return true;
default:
Expand Down

0 comments on commit 8770000

Please sign in to comment.