From e50b100dfd36a11775d8aaf1463cb2d37ec8b095 Mon Sep 17 00:00:00 2001 From: Steven Perron Date: Thu, 8 Feb 2024 09:42:45 -0500 Subject: [PATCH] Fixes based on code review. Remove stale test cases. --- source/opt/const_folding_rules.cpp | 101 +++++++++++++++-------------- test/opt/fold_test.cpp | 22 +------ 2 files changed, 56 insertions(+), 67 deletions(-) diff --git a/source/opt/const_folding_rules.cpp b/source/opt/const_folding_rules.cpp index db4212fb17..7723f94ed2 100644 --- a/source/opt/const_folding_rules.cpp +++ b/source/opt/const_folding_rules.cpp @@ -21,8 +21,8 @@ namespace opt { namespace { constexpr uint32_t kExtractCompositeIdInIdx = 0; -// Returns the value obtained by setting clearing the `number_of_bits` most -// significant bits of `value`. +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and sign-extending it to 64-bits. uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { if (number_of_bits == 64) return value; @@ -38,12 +38,12 @@ uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) { return value; } -// Returns the value obtained from clearing the `number_of_bits` most -// significant bits of `value`. -uint64_t ClearUpperBits(uint64_t value, uint32_t number_of_bits) { - if (number_of_bits == 0) return value; +// Returns the value obtained by extracting the |number_of_bits| least +// significant bits from |value|, and zero-extending it to 64-bits. +uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) { + if (number_of_bits == 64) return value; - uint64_t mask_for_first_bit_to_clear = 1ull << (64 - number_of_bits); + uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits); uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1; value &= mask_for_bits_to_keep; return value; @@ -67,7 +67,7 @@ const analysis::Constant* GenerateIntegerConstant( if (integer_type->IsSigned()) { result = SignExtendValue(result, integer_type->width()); } else { - result = ClearUpperBits(result, 64 - integer_type->width()); + result = ZeroExtendValue(result, integer_type->width()); } words = {static_cast(result)}; } @@ -769,15 +769,18 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) { } // Returns a |ConstantFoldingRule| that folds binary scalar ops -// using |scalar_rule| and unary vectors ops by applying -// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule| -// that is returned assumes that |constants| contains 2 entries. If they are -// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector| -// whose element type is |Float| or |Integer|. +// using |scalar_rule| and binary vectors ops by applying +// |scalar_rule| to the elements of the vector. The folding rule assumes that op +// has two inputs. For regular instruction, those are in operands 0 and 1. For +// extended instruction, they are in operands 1 and 2. If an element in +// |constants| is not nullprt, then the constant's type is |Float|, |Integer|, +// or |Vector| whose element type is |Float| or |Integer|. ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { return [scalar_rule](IRContext* context, Instruction* inst, const std::vector& constants) -> const analysis::Constant* { + assert(constants.size() == inst->NumInOperands()); + assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2)); analysis::ConstantManager* const_mgr = context->get_constant_mgr(); analysis::TypeManager* type_mgr = context->get_type_mgr(); const analysis::Type* result_type = type_mgr->GetType(inst->type_id()); @@ -788,40 +791,38 @@ ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) { const analysis::Constant* arg2 = (inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1]; - if (arg1 == nullptr) { + if (arg1 == nullptr || arg2 == nullptr) { return nullptr; } - if (arg2 == nullptr) { - return nullptr; + + if (vector_type == nullptr) { + return scalar_rule(result_type, arg1, arg2, const_mgr); } - if (vector_type != nullptr) { - std::vector a_components; - std::vector b_components; - std::vector results_components; + std::vector a_components; + std::vector b_components; + std::vector results_components; - a_components = arg1->GetVectorComponents(const_mgr); - b_components = arg2->GetVectorComponents(const_mgr); + a_components = arg1->GetVectorComponents(const_mgr); + b_components = arg2->GetVectorComponents(const_mgr); + assert(a_components.size() == b_components.size()); - // Fold each component of the vector. - for (uint32_t i = 0; i < a_components.size(); ++i) { - results_components.push_back(scalar_rule(vector_type->element_type(), - a_components[i], - b_components[i], const_mgr)); - if (results_components[i] == nullptr) { - return nullptr; - } + // Fold each component of the vector. + for (uint32_t i = 0; i < a_components.size(); ++i) { + results_components.push_back(scalar_rule(vector_type->element_type(), + a_components[i], b_components[i], + const_mgr)); + if (results_components[i] == nullptr) { + return nullptr; } + } - // Build the constant object and return it. - std::vector ids; - for (const analysis::Constant* member : results_components) { - ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); - } - return const_mgr->GetConstant(vector_type, ids); - } else { - return scalar_rule(result_type, arg1, arg2, const_mgr); + // Build the constant object and return it. + std::vector ids; + for (const analysis::Constant* member : results_components) { + ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id()); } + return const_mgr->GetConstant(vector_type, ids); }; } @@ -1701,9 +1702,9 @@ enum Sign { Signed, Unsigned }; // Returns a BinaryScalarFoldingRule that applies `op` to the scalars. // The `signedness` is used to determine if the operands should be interpreted -// as signed or unsigned. If the operands are signed, the will be sign extended -// before the value is passed to `op`. Otherwise the values will be zero -// extended. +// as signed or unsigned. If the operands are signed, the value will be sign +// extended before the value is passed to `op`. Otherwise the values will be +// zero extended. template BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, uint64_t)) { @@ -1711,7 +1712,7 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, [op](const analysis::Type* result_type, const analysis::Constant* a, const analysis::Constant* b, analysis::ConstantManager* const_mgr) -> const analysis::Constant* { - assert(result_type != nullptr && a != nullptr); + assert(result_type != nullptr && a != nullptr && b != nullptr); const analysis::Integer* integer_type = a->type()->AsInteger(); assert(integer_type != nullptr); assert(integer_type == result_type->AsInteger()); @@ -1732,30 +1733,34 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t, }; } -// A scalar folding rule that foles OpSConvert. +// A scalar folding rule that folds OpSConvert. const analysis::Constant* FoldScalarSConvert( const analysis::Type* result_type, const analysis::Constant* a, analysis::ConstantManager* const_mgr) { - assert(a); + assert(result_type != nullptr); + assert(a != nullptr); + assert(const_mgr != nullptr); const analysis::Integer* integer_type = result_type->AsInteger(); assert(integer_type && "The result type of an SConvert"); int64_t value = a->GetSignExtendedValue(); return GenerateIntegerConstant(integer_type, value, const_mgr); } -// A scalar folding rule that foles OpSConvert. +// A scalar folding rule that folds OpUConvert. const analysis::Constant* FoldScalarUConvert( const analysis::Type* result_type, const analysis::Constant* a, analysis::ConstantManager* const_mgr) { - assert(a); + assert(result_type != nullptr); + assert(a != nullptr); + assert(const_mgr != nullptr); const analysis::Integer* integer_type = result_type->AsInteger(); - assert(integer_type && "The result type of an SConvert"); + assert(integer_type && "The result type of an UConvert"); uint64_t value = a->GetZeroExtendedValue(); // If the operand was an unsigned value with less than 32-bit, it would have // been sign extended earlier, and we need to clear those bits. auto* operand_type = a->type()->AsInteger(); - value = ClearUpperBits(value, 64 - operand_type->width()); + value = ZeroExtendValue(value, operand_type->width()); return GenerateIntegerConstant(integer_type, value, const_mgr); } } // namespace diff --git a/test/opt/fold_test.cpp b/test/opt/fold_test.cpp index 794d8c9797..e5f663f149 100644 --- a/test/opt/fold_test.cpp +++ b/test/opt/fold_test.cpp @@ -4118,23 +4118,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 0), - // Test case 38: Don't fold 2 + 3 (long), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %long %long_2 %long_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 39: Don't fold 2 + 3 (short), bad length - InstructionFoldingCase( - Header() + "%main = OpFunction %void None %void_func\n" + - "%main_lab = OpLabel\n" + - "%2 = OpIAdd %short %short_2 %short_3\n" + - "OpReturn\n" + - "OpFunctionEnd", - 2, 0), - // Test case 40: fold 1*n + // Test case 38: fold 1*n InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -4144,7 +4128,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 41: fold n*1 + // Test case 39: fold n*1 InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" + "%main_lab = OpLabel\n" + @@ -4154,7 +4138,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe "OpReturn\n" + "OpFunctionEnd", 2, 3), - // Test case 42: Don't fold comparisons of 64-bit types + // Test case 40: Don't fold comparisons of 64-bit types // (https://github.com/KhronosGroup/SPIRV-Tools/issues/3343). InstructionFoldingCase( Header() + "%main = OpFunction %void None %void_func\n" +