Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Refactor] Replace std::tie with structured bindings (apache#12610)
Browse files Browse the repository at this point in the history
* [Refactor] Replace std::tie with structured bindings

With C++17 enabled in apache#12337, use
structured bindings to replace cases where `std::tie` is used to
define local variables.

* Added missing header for <optional>

* Silenced unused variable warnings after structured bindings

gcc version 7 emits spurious unused variable warnings when only part
of a structured binding is used.  This is a bug in gcc version 7,
resolved in gcc 8.  While gcc version 7 is used for CI, we'll need to
silence these warnings explicitly.

More information: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent a9f2bf3 commit ea78fc0
Show file tree
Hide file tree
Showing 35 changed files with 105 additions and 185 deletions.
4 changes: 1 addition & 3 deletions src/auto_scheduler/auto_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions")

TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule")
.set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) {
te::Schedule sch;
Array<te::Tensor> return_tensors;
std::tie(sch, return_tensors) = AutoSchedule(search_policy, tuning_options);
auto [sch, return_tensors] = AutoSchedule(search_policy, tuning_options);
return Array<ObjectRef>{sch, return_tensors};
});
} // namespace auto_scheduler
Expand Down
17 changes: 6 additions & 11 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1325,10 +1325,9 @@ State ComputeDAG::InferBound(const State& state) const {

Array<te::Stage> stages;
StageToAxesMap stage_to_axes;
te::Schedule sch;
Array<te::Tensor> tensors;
// Replay steps to tvm::Schedule
std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
auto [sch, tensors] = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes);
(void)tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
sch = sch.normalize_for_feature_extraction();
// Get bound information from TVM schedule
Map<IterVar, Range> bounds = te::InferBound(sch);
Expand Down Expand Up @@ -1382,9 +1381,8 @@ Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
}

ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
te::Schedule sch;
Array<te::Tensor> old_tensors;
std::tie(sch, old_tensors) = ApplySteps(transform_steps);
auto [sch, old_tensors] = ApplySteps(transform_steps);
(void)old_tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767
return ComputeDAG(sch);
}

Expand Down Expand Up @@ -1481,11 +1479,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG")

TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState")
.set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) {
te::Schedule sch;
Array<te::Tensor> return_tensors;
std::tie(sch, return_tensors) =
dag.ApplySteps(state->transform_steps, nullptr, nullptr,
static_cast<LayoutRewriteOption>(layout_rewrite));
auto [sch, return_tensors] = dag.ApplySteps(state->transform_steps, nullptr, nullptr,
static_cast<LayoutRewriteOption>(layout_rewrite));
return Array<ObjectRef>{sch, return_tensors};
});

Expand Down
9 changes: 2 additions & 7 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
unique_lines = std::max(unique_lines, 1.0f);
}

ReuseType reuse_type;
float reuse_dis_iter, reuse_dis_bytes, reuse_ct;
std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) =
auto [reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct] =
ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_, ana_);

acc_feas.emplace_back();
Expand Down Expand Up @@ -1356,10 +1354,7 @@ void GetPerStoreFeatureName(int max_n_bufs, std::vector<std::string>* ret) {

void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs,
std::vector<float>* feature, std::atomic<int>* error_ct) {
te::Schedule sch;
Array<te::Tensor> tensors;

std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps);
auto [sch, tensors] = task->compute_dag.ApplySteps(state->transform_steps);

// When inlining, replace const matrices with const values.
// Produces wrong IR, but good enough for feature extraction, and
Expand Down
4 changes: 1 addition & 3 deletions src/auto_scheduler/search_policy/search_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks")

TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound")
.set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) {
Array<MeasureInput> inputs;
Array<MeasureResult> results;
std::tie(inputs, results) = policy->ContinueSearchOneRound(num_measure, measurer);
auto [inputs, results] = policy->ContinueSearchOneRound(num_measure, measurer);
return Array<ObjectRef>{inputs, results};
});

Expand Down
3 changes: 1 addition & 2 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,7 @@ SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition(
const auto& op = state->stages[stage_id]->op;
if (op->IsInstance<te::ComputeOpNode>()) {
// Compute the product of lengths of all space iters and all reduce iters
int cum_space_len, cum_reduce_len;
std::tie(cum_space_len, cum_reduce_len) =
auto [cum_space_len, cum_reduce_len] =
GetCumulativeSpaceAndReductionLength(state->stages[stage_id]);

if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) {
Expand Down
5 changes: 1 addition & 4 deletions src/ir/instrument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,7 @@ String RenderPassProfiles() {
os << std::fixed;

while (profiles.size() > 0) {
size_t depth;
PassProfile::Duration parent_duration;
PassProfile* profile;
std::tie(depth, parent_duration, profile) = profiles.top();
auto [depth, parent_duration, profile] = profiles.top();
profiles.pop();

// indent depth
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/database/json_database.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ class JSONDatabaseNode : public DatabaseNode {

Workload CommitWorkload(const IRModule& mod) {
// Try to insert `mod` into `workloads_`
decltype(this->workloads2idx_)::iterator it;
bool inserted = false;
std::tie(it, inserted) =
auto [it, inserted] =
this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1);
Workload workload = it->first;
// If `mod` is new in `workloads2idx_`, append it to the workload file
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/mutator/mutate_compute_location.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@ std::vector<MutateComputeLocationNode::Candidate> MutateComputeLocationNode::Fin
int old_decision = Downcast<Integer>(decision)->value;

// Step 2. Collect all the compute_at locations.
Array<tir::StmtSRef> location_srefs;
std::vector<int> location_indices;
std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref);
auto [location_srefs, location_indices] = CollectComputeLocation(sch->state(), block_sref);
// Step 3. Remove the old decision.
auto it = std::find(location_indices.begin(), location_indices.end(), old_decision);
if (it != location_indices.end()) {
Expand Down
6 changes: 1 addition & 5 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,11 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
// Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the
// block to its consumers. We want to fuse as much as possible because it results in
// significantly faster schedule.
bool fusible = false;
// `target_loop` is the loop position where the input block will be computed at.
tir::LoopRV target_loop{nullptr};
// `target_block` is the consumer block that we want to compute-at the input block to.
tir::BlockRV target_block{nullptr};
// `tgt_block_innermost_loop` is the innermost loop outside the target block.
tir::LoopRV tgt_block_innermost_loop{nullptr};

std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) =
auto [fusible, target_loop, target_block, tgt_block_innermost_loop] =
GetComputeTargetLoopAndBlock(tmp_sch, block_rv);

// Step 3. Try block fusion.
Expand Down
4 changes: 1 addition & 3 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
result.clear();
while (!stack.empty()) {
// get the stack.top()
tir::Schedule sch;
Array<tir::BlockRV> blocks;
std::tie(sch, blocks) = stack.back();
auto [sch, blocks] = stack.back();
stack.pop_back();
// if all blocks are visited
if (blocks.empty()) {
Expand Down
12 changes: 3 additions & 9 deletions src/relay/collage/partition_rule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ std::vector<CandidatePartition> DFPatternPartitionRuleNode::AllCandidates(
continue;
}
IndexSet inside = MatcherToIndexSet(matcher);
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_;
CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec);
Expand Down Expand Up @@ -256,9 +254,7 @@ std::vector<CandidatePartition> OpCallByKindPartitionRuleNode::AllCandidates(
auto node = dataflow_graph.index_to_node(index);
Expr sub_expr = node->ref();
if (sub_expr->IsInstance<CallNode>()) {
OpPatternKind kind;
String label;
std::tie(kind, label) = SubExprKindAndLabel(sub_expr);
auto [kind, label] = SubExprKindAndLabel(sub_expr);
if (kind <= kOutEWiseFusable) {
IndexSet inside(dataflow_graph.size(), {index});
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label));
Expand Down Expand Up @@ -404,9 +400,7 @@ std::vector<CandidatePartition> HostPartitionRuleNode::AllCandidates(
continue;
}
IndexSet inside(dataflow_graph.size(), {index});
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label);
String rule_name = NestLabels(rule_name_, sub_graph->label_);
// We'll use a zero cost for the candidate since we'll never want to actually estimate the cost
Expand Down
8 changes: 2 additions & 6 deletions src/relay/collage/sub_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,7 @@ std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph&
bool first = true;
OpPatternKind max_kind = kElemWise;
for (PostDfsIndex index : inside) {
OpPatternKind sub_kind;
std::string sub_label;
std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
auto [sub_kind, sub_label] = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref());
if (!sub_label.empty()) {
if (first) {
first = false;
Expand Down Expand Up @@ -995,9 +993,7 @@ transform::Pass PartitionForTesting(Integer max_exits, Bool allow_taps, String c
// Build the overall sub-graph, which will include any "Composite" functions as
// well as any nodes without a label.
IndexSet inside(dataflow_graph.size(), node_indexes);
OpPatternKind kind;
String label;
std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside);
auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside);
SubGraph sub_graph(dataflow_graph, inside, kind, label, std::move(nested_sub_graphs));

// Push the overall sub-graph into the final "Compiler" function.
Expand Down
4 changes: 2 additions & 2 deletions src/relay/qnn/op/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,9 +722,9 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
<< "qnn.conv2d supports only OIHW/HWIO/HWOI/OHWI kernel data layout.";
ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified.";

int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
auto [batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier] =
GetWorkload(arg_types, param);
(void)batch_size; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767

// zero points are allowed to be non-scalar. Let's check if that's the case.
bool dynamic_zp = false;
Expand Down
6 changes: 2 additions & 4 deletions src/relay/qnn/op/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,11 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
output_zero_point, input_shape);

// alpha * Q_i'
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(alpha);
auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift);

// (1 - alpha) * zp_o
int32_t fixed_point_multiplier_z, shift_z;
std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha);
auto [fixed_point_multiplier_z, shift_z] = GetFixedPointMultiplierShift(1 - alpha);
auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z);

// alpha * Q_i' + (1 - alpha) * zp_o
Expand Down
3 changes: 1 addition & 2 deletions src/relay/qnn/op/requantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,7 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale,
static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
// Skip if input and output scales are same.
if (!IsEqualScalar(input_scale, output_scale)) {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(double_multiplier);

const bool is_upward_rounding = (param->rounding == "UPWARD");

Expand Down
6 changes: 2 additions & 4 deletions src/relay/qnn/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
tensor = Cast(tensor, hp_dtype);

// 1) Calculating the integer multiplier and integer shift
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
int left_shift = shift > 0 ? shift : 0;
int right_shift = shift > 0 ? 0 : -shift;

Expand Down Expand Up @@ -128,8 +127,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
bool is_lshift_required = false;
for (auto multiplier : multipliers) {
int32_t fixed_pt_multiplier, shift;
std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier);
auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
int lshift = shift > 0 ? shift : 0;
int rshift = shift > 0 ? 0 : -shift;
fixed_pt_multipliers.push_back(fixed_pt_multiplier);
Expand Down
6 changes: 2 additions & 4 deletions src/relay/quantize/realize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype,
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor);
auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape);
Expand Down Expand Up @@ -135,8 +134,7 @@ Expr QuantizeRealize(const Call& ref_call, const Array<Expr>& new_args, const Ob
} else {
data = Cast(data, DataType::Int(64));
if (cfg->rounding == "UPWARD") {
int32_t fixed_point_multiplier, shift;
std::tie(fixed_point_multiplier, shift) =
auto [fixed_point_multiplier, shift] =
qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm);
data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift);
} else {
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
Call MakeCombinedOp(const Group& branches) {
const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
std::tie(new_weight, new_channels) = TransformWeight(branches);
auto [new_weight, new_channels] = TransformWeight(branches);

const CallNode* group_root = branches[0][0];
const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
Expand Down
4 changes: 1 addition & 3 deletions src/relay/transforms/combine_parallel_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,8 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
Call MakeCombinedOp(const Group& branches) {
const Op& dense_op = Op::Get("nn.dense");
Expr input = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_output_dims;
// concat all weights into one
std::tie(new_weight, new_output_dims) = TransformWeight(branches);
auto [new_weight, new_output_dims] = TransformWeight(branches);
const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
ICHECK(origin_attrs);
const auto dense_attrs = make_object<DenseAttrs>();
Expand Down
4 changes: 1 addition & 3 deletions src/runtime/graph_executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,9 +674,7 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name,
});
} else if (name == "get_input_info") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
GraphExecutor::ShapeInfo shape_info;
GraphExecutor::DtypeInfo dtype_info;
std::tie(shape_info, dtype_info) = this->GetInputInfo();
auto [shape_info, dtype_info] = this->GetInputInfo();
Map<String, ObjectRef> input_info;
input_info.Set("shape", shape_info);
input_info.Set("dtype", dtype_info);
Expand Down
12 changes: 4 additions & 8 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,8 +403,7 @@ class Replacer {
}
std::string rewrite(std::string str) {
for (auto&& rule : _rules) {
std::string pattern, replacement;
std::tie(pattern, replacement) = rule;
auto [pattern, replacement] = rule;
size_t len = pattern.size();
size_t new_len = replacement.size();
size_t pos = str.find(pattern);
Expand Down Expand Up @@ -532,8 +531,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
dtype_c = ptx::DTypeFromString(C_dtype);
ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout),
layout_b = ptx::LayoutTypeFromString(B_layout);
int m, n, k;
std::tie(m, n, k) = ptx::ParseMMAShape(shape);
auto [m, n, k] = ptx::ParseMMAShape(shape);
CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse,
saturate);
std::string asm_code = R"(
Expand All @@ -545,8 +543,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo
: {inputs});
}
)";
std::string templates_str, inputs_str, outputs_str;
std::tie(templates_str, inputs_str, outputs_str) =
auto [templates_str, inputs_str, outputs_str] =
GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse);

// replace patterns
Expand Down Expand Up @@ -622,8 +619,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type
);
}
)";
std::string templates_str, outputs_str;
std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);
auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset);

Replacer replacer;
replacer.register_rule("{.shape}", ".m8n8");
Expand Down
10 changes: 4 additions & 6 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1183,21 +1183,19 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A
return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges);
}

PrimExpr new_outer_cond, new_reduce_cond;
Array<PrimExpr> new_source = red->source;

// Partially lift conditions from the reduce condition
std::tie(new_outer_cond, new_reduce_cond) =
auto [new_outer_cond, new_reduce_cond] =
LiftConditionsThroughReduction(red->condition, red->axis, axis);

// If it's not sum then we haven't yet lifted nonzeroness cond from the source
if (!is_sum) {
PrimExpr outer_nz_cond, nz_cond, nz_source;
auto nz = NonzeronessCondition(red->source[red->value_index]);
// Append conditions from the reduction
nz_cond = new_reduce_cond && nz.cond;
nz_source = nz.value;
std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis);
PrimExpr nz_source = nz.value;
auto [outer_nz_cond, nz_cond] =
LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis);
new_outer_cond = new_outer_cond && outer_nz_cond;
new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype())));
}
Expand Down
Loading

0 comments on commit ea78fc0

Please sign in to comment.