diff --git a/source/opt/aggressive_dead_code_elim_pass.cpp b/source/opt/aggressive_dead_code_elim_pass.cpp index e9c831e544..44432399ca 100644 --- a/source/opt/aggressive_dead_code_elim_pass.cpp +++ b/source/opt/aggressive_dead_code_elim_pass.cpp @@ -1005,6 +1005,7 @@ void AggressiveDCEPass::InitExtensions() { "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", + "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix" }); // clang-format on diff --git a/source/opt/folding_rules.cpp b/source/opt/folding_rules.cpp index 41f535d85d..2ebc385cb4 100644 --- a/source/opt/folding_rules.cpp +++ b/source/opt/folding_rules.cpp @@ -112,6 +112,12 @@ bool IsValidResult(T val) { } } +// Returns true if `type` is a cooperative matrix. +bool IsCooperativeMatrix(const analysis::Type* type) { + return type->kind() == analysis::Type::kCooperativeMatrixKHR || + type->kind() == analysis::Type::kCooperativeMatrixNV; +} + const analysis::Constant* ConstInput( const std::vector& constants) { return constants[0] ? constants[0] : constants[1]; @@ -314,7 +320,7 @@ FoldingRule ReciprocalFDiv() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -400,7 +406,7 @@ FoldingRule MergeNegateMulDivArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -466,7 +472,7 @@ FoldingRule MergeNegateAddSubArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -702,7 +708,7 @@ FoldingRule MergeMulMulArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -761,7 +767,7 @@ FoldingRule MergeMulDivArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -839,7 +845,7 @@ FoldingRule MergeMulNegateArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -884,7 +890,7 @@ FoldingRule MergeDivDivArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -962,7 +968,7 @@ FoldingRule MergeDivMulArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1109,7 +1115,7 @@ FoldingRule MergeSubNegateArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1162,7 +1168,7 @@ FoldingRule MergeAddAddArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1215,7 +1221,7 @@ FoldingRule MergeAddSubArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::Kind::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1280,7 +1286,7 @@ FoldingRule MergeSubAddArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1351,7 +1357,7 @@ FoldingRule MergeSubSubArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } @@ -1449,7 +1455,7 @@ FoldingRule MergeGenericAddSubArithmetic() { const analysis::Type* type = context->get_type_mgr()->GetType(inst->type_id()); - if (type->kind() == analysis::Type::kCooperativeMatrixKHR) { + if (IsCooperativeMatrix(type)) { return false; } diff --git a/source/opt/local_access_chain_convert_pass.cpp b/source/opt/local_access_chain_convert_pass.cpp index 987e9a6b70..174e7d86da 100644 --- a/source/opt/local_access_chain_convert_pass.cpp +++ b/source/opt/local_access_chain_convert_pass.cpp @@ -429,7 +429,7 @@ void LocalAccessChainConvertPass::InitExtensions() { "SPV_KHR_fragment_shader_barycentric", "SPV_KHR_vulkan_memory_model", "SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", - "SPV_KHR_cooperative_matrix"}); + "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"}); } bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds( diff --git a/source/opt/local_single_block_elim_pass.cpp b/source/opt/local_single_block_elim_pass.cpp index dd0b594a9d..3a7d25a4ba 100644 --- a/source/opt/local_single_block_elim_pass.cpp +++ b/source/opt/local_single_block_elim_pass.cpp @@ -292,6 +292,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() { "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", + "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"}); } diff --git a/source/opt/local_single_store_elim_pass.cpp b/source/opt/local_single_store_elim_pass.cpp index aa7a756955..7dfc4adfc8 100644 --- a/source/opt/local_single_store_elim_pass.cpp +++ b/source/opt/local_single_store_elim_pass.cpp @@ -142,6 +142,7 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() { "SPV_EXT_shader_atomic_float_add", "SPV_EXT_fragment_shader_interlock", "SPV_NV_compute_shader_derivatives", + "SPV_NV_cooperative_matrix", "SPV_KHR_cooperative_matrix"}); } bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) { diff --git a/source/opt/mem_pass.cpp b/source/opt/mem_pass.cpp index 6e59bf0f8d..80ec8da18e 100644 --- a/source/opt/mem_pass.cpp +++ b/source/opt/mem_pass.cpp @@ -43,6 +43,7 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const { case spv::Op::OpTypeSampler: case spv::Op::OpTypeSampledImage: case spv::Op::OpTypePointer: + case spv::Op::OpTypeCooperativeMatrixNV: case spv::Op::OpTypeCooperativeMatrixKHR: return true; default: