diff --git a/source/opt/inline_exhaustive_pass.cpp b/source/opt/inline_exhaustive_pass.cpp index bef45017f2..9cdea43d7d 100644 --- a/source/opt/inline_exhaustive_pass.cpp +++ b/source/opt/inline_exhaustive_pass.cpp @@ -55,6 +55,11 @@ Pass::Status InlineExhaustivePass::InlineExhaustive(Function* func) { } } } + + if (modified) { + FixDebugDeclares(func); + } + return (modified ? Status::SuccessWithChange : Status::SuccessWithoutChange); } diff --git a/source/opt/inline_pass.cpp b/source/opt/inline_pass.cpp index 318643341a..193e276ff2 100644 --- a/source/opt/inline_pass.cpp +++ b/source/opt/inline_pass.cpp @@ -30,6 +30,8 @@ namespace { constexpr int kSpvFunctionCallFunctionId = 2; constexpr int kSpvFunctionCallArgumentId = 3; constexpr int kSpvReturnValueId = 0; +constexpr int kSpvDebugDeclareVarInIdx = 3; +constexpr int kSpvAccessChainBaseInIdx = 0; } // namespace uint32_t InlinePass::AddPointerToType(uint32_t type_id, @@ -858,5 +860,68 @@ void InlinePass::InitializeInline() { InlinePass::InlinePass() {} +void InlinePass::FixDebugDeclares(Function* func) { + std::map access_chains; + std::vector debug_declare_insts; + + func->ForEachInst([&access_chains, &debug_declare_insts](Instruction* inst) { + if (inst->opcode() == spv::Op::OpAccessChain) { + access_chains[inst->result_id()] = inst; + } + if (inst->GetCommonDebugOpcode() == CommonDebugInfoDebugDeclare) { + debug_declare_insts.push_back(inst); + } + }); + + for (auto& inst : debug_declare_insts) { + FixDebugDeclare(inst, access_chains); + } +} + +void InlinePass::FixDebugDeclare( + Instruction* dbg_declare_inst, + const std::map& access_chains) { + do { + uint32_t var_id = + dbg_declare_inst->GetSingleWordInOperand(kSpvDebugDeclareVarInIdx); + + // The def-use chains are not kept up to date while inlining, so we need to + // get the variable by traversing the functions. + auto it = access_chains.find(var_id); + if (it == access_chains.end()) { + return; + } + Instruction* access_chain = it->second; + + // If the variable id in the debug declare is an access chain, it is + // invalid. it needs to be fixed up. The debug declare will be updated so + // that its Var operand becomes the base of the access chain. The indexes of + // the access chain are prepended before the indexes of the debug declare. + + std::vector operands; + for (int i = 0; i < kSpvDebugDeclareVarInIdx; i++) { + operands.push_back(dbg_declare_inst->GetInOperand(i)); + } + + uint32_t access_chain_base = + access_chain->GetSingleWordInOperand(kSpvAccessChainBaseInIdx); + operands.push_back(Operand(SPV_OPERAND_TYPE_ID, {access_chain_base})); + operands.push_back( + dbg_declare_inst->GetInOperand(kSpvDebugDeclareVarInIdx + 1)); + + for (uint32_t i = kSpvAccessChainBaseInIdx + 1; + i < access_chain->NumInOperands(); ++i) { + operands.push_back(access_chain->GetInOperand(i)); + } + + for (uint32_t i = kSpvDebugDeclareVarInIdx + 2; + i < dbg_declare_inst->NumInOperands(); ++i) { + operands.push_back(dbg_declare_inst->GetInOperand(i)); + } + + dbg_declare_inst->SetInOperands(std::move(operands)); + } while (true); +} + } // namespace opt } // namespace spvtools diff --git a/source/opt/inline_pass.h b/source/opt/inline_pass.h index 1c9d60e32d..7bea31d1d5 100644 --- a/source/opt/inline_pass.h +++ b/source/opt/inline_pass.h @@ -150,6 +150,12 @@ class InlinePass : public Pass { // Initialize state for optimization of |module| void InitializeInline(); + // Fixes invalid debug declare functions in `func` that were caused by + // inlining. This function cannot be called while in the middle of inlining + // because it needs to be able to find the instructions that define an + // id. + void FixDebugDeclares(Function* func); + // Map from function's result id to function. std::unordered_map id2function_; @@ -241,6 +247,11 @@ class InlinePass : public Pass { // structural dominance. void UpdateSingleBlockLoopContinueTarget( uint32_t new_id, std::vector>* new_blocks); + + // Replaces the `var` operand of `dbg_declare_inst` and updates the indexes + // accordingly, if it is the id of an access chain in `access_chains`. + void FixDebugDeclare(Instruction* dbg_declare_inst, + const std::map& access_chains); }; } // namespace opt diff --git a/source/val/validate_function.cpp b/source/val/validate_function.cpp index 4879a7db5f..50ae0a2dbd 100644 --- a/source/val/validate_function.cpp +++ b/source/val/validate_function.cpp @@ -258,7 +258,8 @@ spv_result_t ValidateFunctionCall(ValidationState_t& _, _.HasCapability(spv::Capability::VariablePointers) && sc == spv::StorageClass::Workgroup; const bool uc_ptr = sc == spv::StorageClass::UniformConstant; - if (!ssbo_vptr && !wg_vptr && !uc_ptr) { + if (!_.options()->before_hlsl_legalization && !ssbo_vptr && + !wg_vptr && !uc_ptr) { return _.diag(SPV_ERROR_INVALID_ID, inst) << "Pointer operand " << _.getIdName(argument_id) << " must be a memory object declaration";