Skip to content

Commit

Permalink
Fixes based on code review. Remove stale test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
s-perron committed Feb 8, 2024
1 parent b5f23bf commit e50b100
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 67 deletions.
101 changes: 53 additions & 48 deletions source/opt/const_folding_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace opt {
namespace {
constexpr uint32_t kExtractCompositeIdInIdx = 0;

// Returns the value obtained by setting clearing the `number_of_bits` most
// significant bits of `value`.
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and sign-extending it to 64-bits.
uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
if (number_of_bits == 64) return value;

Expand All @@ -38,12 +38,12 @@ uint64_t SignExtendValue(uint64_t value, uint32_t number_of_bits) {
return value;
}

// Returns the value obtained from clearing the `number_of_bits` most
// significant bits of `value`.
uint64_t ClearUpperBits(uint64_t value, uint32_t number_of_bits) {
if (number_of_bits == 0) return value;
// Returns the value obtained by extracting the |number_of_bits| least
// significant bits from |value|, and zero-extending it to 64-bits.
uint64_t ZeroExtendValue(uint64_t value, uint32_t number_of_bits) {
if (number_of_bits == 64) return value;

uint64_t mask_for_first_bit_to_clear = 1ull << (64 - number_of_bits);
uint64_t mask_for_first_bit_to_clear = 1ull << (number_of_bits);
uint64_t mask_for_bits_to_keep = mask_for_first_bit_to_clear - 1;
value &= mask_for_bits_to_keep;
return value;
Expand All @@ -67,7 +67,7 @@ const analysis::Constant* GenerateIntegerConstant(
if (integer_type->IsSigned()) {
result = SignExtendValue(result, integer_type->width());
} else {
result = ClearUpperBits(result, 64 - integer_type->width());
result = ZeroExtendValue(result, integer_type->width());
}
words = {static_cast<uint32_t>(result)};
}
Expand Down Expand Up @@ -769,15 +769,18 @@ ConstantFoldingRule FoldUnaryOp(UnaryScalarFoldingRule scalar_rule) {
}

// Returns a |ConstantFoldingRule| that folds binary scalar ops
// using |scalar_rule| and unary vectors ops by applying
// |scalar_rule| to the elements of the vector. The |ConstantFoldingRule|
// that is returned assumes that |constants| contains 2 entries. If they are
// not |nullptr|, then their type is either |Float| or |Integer| or a |Vector|
// whose element type is |Float| or |Integer|.
// using |scalar_rule| and binary vectors ops by applying
// |scalar_rule| to the elements of the vector. The folding rule assumes that op
// has two inputs. For regular instruction, those are in operands 0 and 1. For
// extended instruction, they are in operands 1 and 2. If an element in
// |constants| is not nullprt, then the constant's type is |Float|, |Integer|,
// or |Vector| whose element type is |Float| or |Integer|.
ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
return [scalar_rule](IRContext* context, Instruction* inst,
const std::vector<const analysis::Constant*>& constants)
-> const analysis::Constant* {
assert(constants.size() == inst->NumInOperands());
assert(constants.size() == (inst->opcode() == spv::Op::OpExtInst ? 3 : 2));
analysis::ConstantManager* const_mgr = context->get_constant_mgr();
analysis::TypeManager* type_mgr = context->get_type_mgr();
const analysis::Type* result_type = type_mgr->GetType(inst->type_id());
Expand All @@ -788,40 +791,38 @@ ConstantFoldingRule FoldBinaryOp(BinaryScalarFoldingRule scalar_rule) {
const analysis::Constant* arg2 =
(inst->opcode() == spv::Op::OpExtInst) ? constants[2] : constants[1];

if (arg1 == nullptr) {
if (arg1 == nullptr || arg2 == nullptr) {
return nullptr;
}
if (arg2 == nullptr) {
return nullptr;

if (vector_type == nullptr) {
return scalar_rule(result_type, arg1, arg2, const_mgr);
}

if (vector_type != nullptr) {
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> b_components;
std::vector<const analysis::Constant*> results_components;
std::vector<const analysis::Constant*> a_components;
std::vector<const analysis::Constant*> b_components;
std::vector<const analysis::Constant*> results_components;

a_components = arg1->GetVectorComponents(const_mgr);
b_components = arg2->GetVectorComponents(const_mgr);
a_components = arg1->GetVectorComponents(const_mgr);
b_components = arg2->GetVectorComponents(const_mgr);
assert(a_components.size() == b_components.size());

// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
results_components.push_back(scalar_rule(vector_type->element_type(),
a_components[i],
b_components[i], const_mgr));
if (results_components[i] == nullptr) {
return nullptr;
}
// Fold each component of the vector.
for (uint32_t i = 0; i < a_components.size(); ++i) {
results_components.push_back(scalar_rule(vector_type->element_type(),
a_components[i], b_components[i],
const_mgr));
if (results_components[i] == nullptr) {
return nullptr;
}
}

// Build the constant object and return it.
std::vector<uint32_t> ids;
for (const analysis::Constant* member : results_components) {
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
}
return const_mgr->GetConstant(vector_type, ids);
} else {
return scalar_rule(result_type, arg1, arg2, const_mgr);
// Build the constant object and return it.
std::vector<uint32_t> ids;
for (const analysis::Constant* member : results_components) {
ids.push_back(const_mgr->GetDefiningInstruction(member)->result_id());
}
return const_mgr->GetConstant(vector_type, ids);
};
}

Expand Down Expand Up @@ -1701,17 +1702,17 @@ enum Sign { Signed, Unsigned };

// Returns a BinaryScalarFoldingRule that applies `op` to the scalars.
// The `signedness` is used to determine if the operands should be interpreted
// as signed or unsigned. If the operands are signed, the will be sign extended
// before the value is passed to `op`. Otherwise the values will be zero
// extended.
// as signed or unsigned. If the operands are signed, the value will be sign
// extended before the value is passed to `op`. Otherwise the values will be
// zero extended.
template <Sign signedness>
BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
uint64_t)) {
return
[op](const analysis::Type* result_type, const analysis::Constant* a,
const analysis::Constant* b,
analysis::ConstantManager* const_mgr) -> const analysis::Constant* {
assert(result_type != nullptr && a != nullptr);
assert(result_type != nullptr && a != nullptr && b != nullptr);
const analysis::Integer* integer_type = a->type()->AsInteger();
assert(integer_type != nullptr);
assert(integer_type == result_type->AsInteger());
Expand All @@ -1732,30 +1733,34 @@ BinaryScalarFoldingRule FoldBinaryIntegerOperation(uint64_t (*op)(uint64_t,
};
}

// A scalar folding rule that foles OpSConvert.
// A scalar folding rule that folds OpSConvert.
const analysis::Constant* FoldScalarSConvert(
const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager* const_mgr) {
assert(a);
assert(result_type != nullptr);
assert(a != nullptr);
assert(const_mgr != nullptr);
const analysis::Integer* integer_type = result_type->AsInteger();
assert(integer_type && "The result type of an SConvert");
int64_t value = a->GetSignExtendedValue();
return GenerateIntegerConstant(integer_type, value, const_mgr);
}

// A scalar folding rule that foles OpSConvert.
// A scalar folding rule that folds OpUConvert.
const analysis::Constant* FoldScalarUConvert(
const analysis::Type* result_type, const analysis::Constant* a,
analysis::ConstantManager* const_mgr) {
assert(a);
assert(result_type != nullptr);
assert(a != nullptr);
assert(const_mgr != nullptr);
const analysis::Integer* integer_type = result_type->AsInteger();
assert(integer_type && "The result type of an SConvert");
assert(integer_type && "The result type of an UConvert");
uint64_t value = a->GetZeroExtendedValue();

// If the operand was an unsigned value with less than 32-bit, it would have
// been sign extended earlier, and we need to clear those bits.
auto* operand_type = a->type()->AsInteger();
value = ClearUpperBits(value, 64 - operand_type->width());
value = ZeroExtendValue(value, operand_type->width());
return GenerateIntegerConstant(integer_type, value, const_mgr);
}
} // namespace
Expand Down
22 changes: 3 additions & 19 deletions test/opt/fold_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4118,23 +4118,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 38: Don't fold 2 + 3 (long), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %long %long_2 %long_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 39: Don't fold 2 + 3 (short), bad length
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
"%2 = OpIAdd %short %short_2 %short_3\n" +
"OpReturn\n" +
"OpFunctionEnd",
2, 0),
// Test case 40: fold 1*n
// Test case 38: fold 1*n
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
Expand All @@ -4144,7 +4128,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 41: fold n*1
// Test case 39: fold n*1
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
"%main_lab = OpLabel\n" +
Expand All @@ -4154,7 +4138,7 @@ INSTANTIATE_TEST_SUITE_P(IntegerArithmeticTestCases, GeneralInstructionFoldingTe
"OpReturn\n" +
"OpFunctionEnd",
2, 3),
// Test case 42: Don't fold comparisons of 64-bit types
// Test case 40: Don't fold comparisons of 64-bit types
// (https://github.com/KhronosGroup/SPIRV-Tools/issues/3343).
InstructionFoldingCase<uint32_t>(
Header() + "%main = OpFunction %void None %void_func\n" +
Expand Down

0 comments on commit e50b100

Please sign in to comment.