diff --git a/src/auto_scheduler/auto_schedule.cc b/src/auto_scheduler/auto_schedule.cc index 747aa01cfa05..41aa49c77193 100755 --- a/src/auto_scheduler/auto_schedule.cc +++ b/src/auto_scheduler/auto_schedule.cc @@ -78,9 +78,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.TuningOptions") TVM_REGISTER_GLOBAL("auto_scheduler.AutoSchedule") .set_body_typed([](SearchPolicy search_policy, TuningOptions tuning_options) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = AutoSchedule(search_policy, tuning_options); + auto [sch, return_tensors] = AutoSchedule(search_policy, tuning_options); return Array{sch, return_tensors}; }); } // namespace auto_scheduler diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index dad55db0303f..5500707fb9af 100644 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -1325,10 +1325,9 @@ State ComputeDAG::InferBound(const State& state) const { Array stages; StageToAxesMap stage_to_axes; - te::Schedule sch; - Array tensors; // Replay steps to tvm::Schedule - std::tie(sch, tensors) = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); + auto [sch, tensors] = ApplySteps(pstate->transform_steps, &stages, &stage_to_axes); + (void)tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 sch = sch.normalize_for_feature_extraction(); // Get bound information from TVM schedule Map bounds = te::InferBound(sch); @@ -1382,9 +1381,8 @@ Array ComputeDAG::InferBound(const Array& states) const { } ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array& transform_steps) const { - te::Schedule sch; - Array old_tensors; - std::tie(sch, old_tensors) = ApplySteps(transform_steps); + auto [sch, old_tensors] = ApplySteps(transform_steps); + (void)old_tensors; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 return ComputeDAG(sch); } @@ -1481,11 +1479,8 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAG") TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGApplyStepsFromState") .set_body_typed([](const ComputeDAG& dag, const State& state, int layout_rewrite) { - te::Schedule sch; - Array return_tensors; - std::tie(sch, return_tensors) = - dag.ApplySteps(state->transform_steps, nullptr, nullptr, - static_cast(layout_rewrite)); + auto [sch, return_tensors] = dag.ApplySteps(state->transform_steps, nullptr, nullptr, + static_cast(layout_rewrite)); return Array{sch, return_tensors}; }); diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index c930bf0c4e73..e079018151a7 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -952,9 +952,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor { unique_lines = std::max(unique_lines, 1.0f); } - ReuseType reuse_type; - float reuse_dis_iter, reuse_dis_bytes, reuse_ct; - std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) = + auto [reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct] = ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_, ana_); acc_feas.emplace_back(); @@ -1356,10 +1354,7 @@ void GetPerStoreFeatureName(int max_n_bufs, std::vector* ret) { void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, int max_n_bufs, std::vector* feature, std::atomic* error_ct) { - te::Schedule sch; - Array tensors; - - std::tie(sch, tensors) = task->compute_dag.ApplySteps(state->transform_steps); + auto [sch, tensors] = task->compute_dag.ApplySteps(state->transform_steps); // When inlining, replace const matrices with const values. // Produces wrong IR, but good enough for feature extraction, and diff --git a/src/auto_scheduler/search_policy/search_policy.cc b/src/auto_scheduler/search_policy/search_policy.cc index 702eec087668..196bee8ff0e2 100644 --- a/src/auto_scheduler/search_policy/search_policy.cc +++ b/src/auto_scheduler/search_policy/search_policy.cc @@ -106,9 +106,7 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyRunCallbacks") TVM_REGISTER_GLOBAL("auto_scheduler.SearchPolicyContinueSearchOneRound") .set_body_typed([](SearchPolicy policy, int num_measure, ProgramMeasurer measurer) { - Array inputs; - Array results; - std::tie(inputs, results) = policy->ContinueSearchOneRound(num_measure, measurer); + auto [inputs, results] = policy->ContinueSearchOneRound(num_measure, measurer); return Array{inputs, results}; }); diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 8df69fc7ce3b..862e593c9dd3 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -343,8 +343,7 @@ SketchGenerationRule::ConditionKind RuleCrossThreadReduction::MeetCondition( const auto& op = state->stages[stage_id]->op; if (op->IsInstance()) { // Compute the product of lengths of all space iters and all reduce iters - int cum_space_len, cum_reduce_len; - std::tie(cum_space_len, cum_reduce_len) = + auto [cum_space_len, cum_reduce_len] = GetCumulativeSpaceAndReductionLength(state->stages[stage_id]); if (NeedsMultilevelTiling(policy.search_task, state, stage_id)) { diff --git a/src/ir/instrument.cc b/src/ir/instrument.cc index 795e5b8cb542..6701308fbfb7 100644 --- a/src/ir/instrument.cc +++ b/src/ir/instrument.cc @@ -288,10 +288,7 @@ String RenderPassProfiles() { os << std::fixed; while (profiles.size() > 0) { - size_t depth; - PassProfile::Duration parent_duration; - PassProfile* profile; - std::tie(depth, parent_duration, profile) = profiles.top(); + auto [depth, parent_duration, profile] = profiles.top(); profiles.pop(); // indent depth diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index f8fb64e92407..2e4f85260835 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -115,9 +115,7 @@ class JSONDatabaseNode : public DatabaseNode { Workload CommitWorkload(const IRModule& mod) { // Try to insert `mod` into `workloads_` - decltype(this->workloads2idx_)::iterator it; - bool inserted = false; - std::tie(it, inserted) = + auto [it, inserted] = this->workloads2idx_.emplace(Workload(mod, tvm::StructuralHash()(mod)), -1); Workload workload = it->first; // If `mod` is new in `workloads2idx_`, append it to the workload file diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 3ed56df1b381..9d6d69ba355f 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -86,9 +86,7 @@ std::vector MutateComputeLocationNode::Fin int old_decision = Downcast(decision)->value; // Step 2. Collect all the compute_at locations. - Array location_srefs; - std::vector location_indices; - std::tie(location_srefs, location_indices) = CollectComputeLocation(sch->state(), block_sref); + auto [location_srefs, location_indices] = CollectComputeLocation(sch->state(), block_sref); // Step 3. Remove the old decision. auto it = std::find(location_indices.begin(), location_indices.end(), old_decision); if (it != location_indices.end()) { diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 242f1aea89c5..0f0ab99e7259 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -64,15 +64,11 @@ class CrossThreadReductionNode : public ScheduleRuleNode { // Step 2. Check the opportunity for block fusion. We say "fusible", if we can compute-at the // block to its consumers. We want to fuse as much as possible because it results in // significantly faster schedule. - bool fusible = false; // `target_loop` is the loop position where the input block will be computed at. - tir::LoopRV target_loop{nullptr}; // `target_block` is the consumer block that we want to compute-at the input block to. - tir::BlockRV target_block{nullptr}; // `tgt_block_innermost_loop` is the innermost loop outside the target block. - tir::LoopRV tgt_block_innermost_loop{nullptr}; - std::tie(fusible, target_loop, target_block, tgt_block_innermost_loop) = + auto [fusible, target_loop, target_block, tgt_block_innermost_loop] = GetComputeTargetLoopAndBlock(tmp_sch, block_rv); // Step 3. Try block fusion. diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index eab084f8978f..9be89e2d9c70 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -140,9 +140,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { result.clear(); while (!stack.empty()) { // get the stack.top() - tir::Schedule sch; - Array blocks; - std::tie(sch, blocks) = stack.back(); + auto [sch, blocks] = stack.back(); stack.pop_back(); // if all blocks are visited if (blocks.empty()) { diff --git a/src/relay/collage/partition_rule.cc b/src/relay/collage/partition_rule.cc index e11f740acfe9..1d8c5e9723ee 100644 --- a/src/relay/collage/partition_rule.cc +++ b/src/relay/collage/partition_rule.cc @@ -92,9 +92,7 @@ std::vector DFPatternPartitionRuleNode::AllCandidates( continue; } IndexSet inside = MatcherToIndexSet(matcher); - OpPatternKind kind; - String label; - std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside); SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); String rule_name = rule_name_.empty() ? sub_graph->label_ : rule_name_; CandidatePartition candidate(std::move(rule_name), std::move(sub_graph), spec); @@ -256,9 +254,7 @@ std::vector OpCallByKindPartitionRuleNode::AllCandidates( auto node = dataflow_graph.index_to_node(index); Expr sub_expr = node->ref(); if (sub_expr->IsInstance()) { - OpPatternKind kind; - String label; - std::tie(kind, label) = SubExprKindAndLabel(sub_expr); + auto [kind, label] = SubExprKindAndLabel(sub_expr); if (kind <= kOutEWiseFusable) { IndexSet inside(dataflow_graph.size(), {index}); SubGraph sub_graph(dataflow_graph, std::move(inside), kind, std::move(label)); @@ -404,9 +400,7 @@ std::vector HostPartitionRuleNode::AllCandidates( continue; } IndexSet inside(dataflow_graph.size(), {index}); - OpPatternKind kind; - String label; - std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside); SubGraph sub_graph(dataflow_graph, std::move(inside), kind, label); String rule_name = NestLabels(rule_name_, sub_graph->label_); // We'll a zero cost for the candidate since we'll never want to actually estimate the cost diff --git a/src/relay/collage/sub_graph.cc b/src/relay/collage/sub_graph.cc index 63edc8c079fb..dee72093fd2f 100644 --- a/src/relay/collage/sub_graph.cc +++ b/src/relay/collage/sub_graph.cc @@ -439,9 +439,7 @@ std::pair SubGraphKindAndLabel(const DataflowGraph& bool first = true; OpPatternKind max_kind = kElemWise; for (PostDfsIndex index : inside) { - OpPatternKind sub_kind; - std::string sub_label; - std::tie(sub_kind, sub_label) = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref()); + auto [sub_kind, sub_label] = SubExprKindAndLabel(dataflow_graph.index_to_node(index)->ref()); if (!sub_label.empty()) { if (first) { first = false; @@ -995,9 +993,7 @@ transform::Pass PartitionForTesting(Integer max_exits, Bool allow_taps, String c // Build the overall sub-graph, which will include any "Composite" functions as // well as any nodes without a label. IndexSet inside(dataflow_graph.size(), node_indexes); - OpPatternKind kind; - String label; - std::tie(kind, label) = SubGraphKindAndLabel(dataflow_graph, inside); + auto [kind, label] = SubGraphKindAndLabel(dataflow_graph, inside); SubGraph sub_graph(dataflow_graph, inside, kind, label, std::move(nested_sub_graphs)); // Push the overall sub-graph into the final "Compiler" function. diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc index 42e4540f0f2c..64a5a02e6e25 100644 --- a/src/relay/qnn/op/convolution.cc +++ b/src/relay/qnn/op/convolution.cc @@ -722,9 +722,9 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array& new_args, << "qnn.conv2d supports only OIHW/HWIO/HWOI/OHWI kernel data layout."; ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified."; - int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier; - std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) = + auto [batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier] = GetWorkload(arg_types, param); + (void)batch_size; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 // zero points are allowed to be non-scalar. Let's check if that's the case. bool dynamic_zp = false; diff --git a/src/relay/qnn/op/leaky_relu.cc b/src/relay/qnn/op/leaky_relu.cc index 75bfabb7db85..458fde0d8a08 100644 --- a/src/relay/qnn/op/leaky_relu.cc +++ b/src/relay/qnn/op/leaky_relu.cc @@ -125,13 +125,11 @@ Expr QnnLeakyReluCanonicalize(const Attrs& attrs, const Array& new_args, output_zero_point, input_shape); // alpha * Q_i' - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(alpha); + auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(alpha); auto prod = FixedPointMultiply(requantized_expr, fixed_point_multiplier, shift); // (1 - alpha) * zp_o - int32_t fixed_point_multiplier_z, shift_z; - std::tie(fixed_point_multiplier_z, shift_z) = GetFixedPointMultiplierShift(1 - alpha); + auto [fixed_point_multiplier_z, shift_z] = GetFixedPointMultiplierShift(1 - alpha); auto scaled_z = FixedPointMultiply(output_zero_point, fixed_point_multiplier_z, shift_z); // alpha * Q_i' + (1 - alpha) * zp_o diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc index 5bf53a95edda..ae321b459788 100644 --- a/src/relay/qnn/op/requantize.cc +++ b/src/relay/qnn/op/requantize.cc @@ -223,8 +223,7 @@ Expr RequantizeLowerInt(const Expr& input_tensor, const Expr& input_scale, static_cast(input_scale_float) / static_cast(output_scale_float); // Skip if input and output scales are same. if (!IsEqualScalar(input_scale, output_scale)) { - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(double_multiplier); + auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(double_multiplier); const bool is_upward_rounding = (param->rounding == "UPWARD"); diff --git a/src/relay/qnn/utils.cc b/src/relay/qnn/utils.cc index 7dfd788d96c6..ed7a415cf6af 100644 --- a/src/relay/qnn/utils.cc +++ b/src/relay/qnn/utils.cc @@ -64,8 +64,7 @@ Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier, tensor = Cast(tensor, hp_dtype); // 1) Calculating the integer multiplier and integer shift - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); + auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(multiplier); int left_shift = shift > 0 ? shift : 0; int right_shift = shift > 0 ? 0 : -shift; @@ -128,8 +127,7 @@ Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector multipliers, std::vector fixed_pt_multipliers, lshifts, rshifts; bool is_lshift_required = false; for (auto multiplier : multipliers) { - int32_t fixed_pt_multiplier, shift; - std::tie(fixed_pt_multiplier, shift) = GetFixedPointMultiplierShift(multiplier); + auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier); int lshift = shift > 0 ? shift : 0; int rshift = shift > 0 ? 0 : -shift; fixed_pt_multipliers.push_back(fixed_pt_multiplier); diff --git a/src/relay/quantize/realize.cc b/src/relay/quantize/realize.cc index 5766c62eaa43..720ef25cd33d 100644 --- a/src/relay/quantize/realize.cc +++ b/src/relay/quantize/realize.cc @@ -77,8 +77,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype, return Multiply(data, MakeConstantScalar(dtype, factor)); } else { if (cfg->rounding == "UPWARD") { - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = qnn::GetFixedPointMultiplierShift(factor); + auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(factor); data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); } else { data = qnn::FixedPointMultiplyToNearest(data, factor, data_shape); @@ -135,8 +134,7 @@ Expr QuantizeRealize(const Call& ref_call, const Array& new_args, const Ob } else { data = Cast(data, DataType::Int(64)); if (cfg->rounding == "UPWARD") { - int32_t fixed_point_multiplier, shift; - std::tie(fixed_point_multiplier, shift) = + auto [fixed_point_multiplier, shift] = qnn::GetFixedPointMultiplierShift(idom_scale_imm / odom_scale_imm); data = relay::FixedPointMultiply(data, fixed_point_multiplier, shift); } else { diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc index 20b206e0423c..9c7bcc27ec82 100644 --- a/src/relay/transforms/combine_parallel_conv2d.cc +++ b/src/relay/transforms/combine_parallel_conv2d.cc @@ -83,9 +83,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner { Call MakeCombinedOp(const Group& branches) { const Op& conv2d = Op::Get("nn.conv2d"); Expr data = branches[0][0]->args[0]; - Expr new_weight; - IndexExpr new_channels; - std::tie(new_weight, new_channels) = TransformWeight(branches); + auto [new_weight, new_channels] = TransformWeight(branches); const CallNode* group_root = branches[0][0]; const auto* attrs = group_root->attrs.as(); diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc index d5404ba30f90..7cf102b5bcab 100644 --- a/src/relay/transforms/combine_parallel_dense.cc +++ b/src/relay/transforms/combine_parallel_dense.cc @@ -116,10 +116,8 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner { Call MakeCombinedOp(const Group& branches) { const Op& dense_op = Op::Get("nn.dense"); Expr input = branches[0][0]->args[0]; - Expr new_weight; - IndexExpr new_output_dims; // concat all weights into one - std::tie(new_weight, new_output_dims) = TransformWeight(branches); + auto [new_weight, new_output_dims] = TransformWeight(branches); const auto* origin_attrs = branches[0][0]->attrs.as(); ICHECK(origin_attrs); const auto dense_attrs = make_object(); diff --git a/src/runtime/graph_executor/graph_executor.cc b/src/runtime/graph_executor/graph_executor.cc index e3113dbfe54c..fc7e82bed4e2 100644 --- a/src/runtime/graph_executor/graph_executor.cc +++ b/src/runtime/graph_executor/graph_executor.cc @@ -674,9 +674,7 @@ PackedFunc GraphExecutor::GetFunction(const std::string& name, }); } else if (name == "get_input_info") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - GraphExecutor::ShapeInfo shape_info; - GraphExecutor::DtypeInfo dtype_info; - std::tie(shape_info, dtype_info) = this->GetInputInfo(); + auto [shape_info, dtype_info] = this->GetInputInfo(); Map input_info; input_info.Set("shape", shape_info); input_info.Set("dtype", dtype_info); diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc index c5e3bf98ec2d..881c425e7742 100644 --- a/src/target/source/ptx.cc +++ b/src/target/source/ptx.cc @@ -403,8 +403,7 @@ class Replacer { } std::string rewrite(std::string str) { for (auto&& rule : _rules) { - std::string pattern, replacement; - std::tie(pattern, replacement) = rule; + auto [pattern, replacement] = rule; size_t len = pattern.size(); size_t new_len = replacement.size(); size_t pos = str.find(pattern); @@ -532,8 +531,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo dtype_c = ptx::DTypeFromString(C_dtype); ptx::LayoutType layout_a = ptx::LayoutTypeFromString(A_layout), layout_b = ptx::LayoutTypeFromString(B_layout); - int m, n, k; - std::tie(m, n, k) = ptx::ParseMMAShape(shape); + auto [m, n, k] = ptx::ParseMMAShape(shape); CheckMMAConfigValidity(m, n, k, layout_a, layout_b, dtype_a, dtype_b, dtype_c, bit_op, sparse, saturate); std::string asm_code = R"( @@ -545,8 +543,7 @@ std::string PrintMMAAssembly(const std::string& shape, const std::string& A_layo : {inputs}); } )"; - std::string templates_str, inputs_str, outputs_str; - std::tie(templates_str, inputs_str, outputs_str) = + auto [templates_str, inputs_str, outputs_str] = GetMMAOperands(m, n, k, dtype_a, dtype_b, dtype_c, sparse); // replace patterns @@ -622,8 +619,7 @@ std::string PrintLoadMatrixAssembly(bool trans, int num, const std::string& type ); } )"; - std::string templates_str, outputs_str; - std::tie(templates_str, outputs_str) = GetLoadMatrixOperands(num, local_ptr, local_elem_offset); + auto [templates_str, outputs_str] = GetLoadMatrixOperands(num, local_ptr, local_elem_offset); Replacer replacer; replacer.register_rule("{.shape}", ".m8n8"); diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 28f57c77da70..26047e879e9b 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -1183,21 +1183,19 @@ PrimExpr RemoveJacobianAndLiftNonzeroCondImpl(const PrimExpr& expr_orig, const A return RemoveJacobianAndLiftNonzeroCondImpl(new_red, axis, vranges); } - PrimExpr new_outer_cond, new_reduce_cond; Array new_source = red->source; // Partially lift conditions from the reduce condition - std::tie(new_outer_cond, new_reduce_cond) = + auto [new_outer_cond, new_reduce_cond] = LiftConditionsThroughReduction(red->condition, red->axis, axis); // If it's not sum then we haven't yet lifted nonzeroness cond from the source if (!is_sum) { - PrimExpr outer_nz_cond, nz_cond, nz_source; auto nz = NonzeronessCondition(red->source[red->value_index]); // Append conditions from the reduction - nz_cond = new_reduce_cond && nz.cond; - nz_source = nz.value; - std::tie(outer_nz_cond, nz_cond) = LiftConditionsThroughReduction(nz_cond, red->axis, axis); + PrimExpr nz_source = nz.value; + auto [outer_nz_cond, nz_cond] = + LiftConditionsThroughReduction(new_reduce_cond && nz.cond, red->axis, axis); new_outer_cond = new_outer_cond && outer_nz_cond; new_source.Set(red->value_index, Select(nz_cond, nz_source, make_zero(nz_source.dtype()))); } diff --git a/src/te/autodiff/ad_utils.cc b/src/te/autodiff/ad_utils.cc index 268abab9cacb..0d1e4927cdfe 100644 --- a/src/te/autodiff/ad_utils.cc +++ b/src/te/autodiff/ad_utils.cc @@ -47,9 +47,7 @@ std::pair, Map> CloneIterVars(const Array PrimExpr CloneReduction(const PrimExpr& expr) { if (const ReduceNode* red = expr.as()) { - Array new_axis; - Map vmap; - std::tie(new_axis, vmap) = CloneIterVars(red->axis); + auto [new_axis, vmap] = CloneIterVars(red->axis); Array src_with_newaxis; for (const auto& src : red->source) { @@ -71,9 +69,7 @@ Operation ComputeOpFromExprs(const Array& exprs, const Array& const std::string& name, const std::string& tag, const Map& attrs, bool clone_axis) { if (clone_axis) { - Array new_axis = axis; - Map vmap; - std::tie(new_axis, vmap) = CloneIterVars(axis); + auto [new_axis, vmap] = CloneIterVars(axis); Array new_exprs; for (const PrimExpr& e : exprs) { new_exprs.push_back(Substitute(CloneReduction(e), vmap)); diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 7104424957af..e61a590c409d 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -317,9 +317,7 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { // We have to clone the iteration axes because otherwise the original expression // cannot be used together with the derivative (it will lead to errors during lowering) - Array new_axis; - Map vmap; - std::tie(new_axis, vmap) = te::CloneIterVars(op->axis); + auto [new_axis, vmap] = te::CloneIterVars(op->axis); Array input_indices; size_t i = 0; diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index b9e99257f37c..fb09a3480a3a 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -558,9 +558,13 @@ bool IsWriteCache(const StmtSRef& block_sref) { } const BufferRegion& write_region = block->writes[0]; for (const BufferRegion& read_region : block->reads) { - bool exists, surjective, injective, ordered, no_const_read, no_shift_read; - std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + auto [exists, surjective, injective, ordered, no_const_read, no_shift_read] = AnalyzeReadWritePattern(read_region, write_region); + // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 + (void)exists; + (void)surjective; + (void)no_const_read; + (void)no_shift_read; if (!(injective && ordered)) { return false; } @@ -2118,8 +2122,7 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } // Cond 6. Can successfully calculating the cumulative loop length. - int64_t cum_space_len, cum_reduce_len; - std::tie(cum_space_len, cum_reduce_len) = GetCumulativeSpaceAndReductionLength(self, block_sref); + auto [cum_space_len, cum_reduce_len] = GetCumulativeSpaceAndReductionLength(self, block_sref); if (cum_space_len == -1 || cum_reduce_len == -1) { return false; } diff --git a/src/tir/schedule/primitive/block_annotate.cc b/src/tir/schedule/primitive/block_annotate.cc index 31c938313fed..0912e36836e3 100644 --- a/src/tir/schedule/primitive/block_annotate.cc +++ b/src/tir/schedule/primitive/block_annotate.cc @@ -82,9 +82,7 @@ class NonAllocatedBufferError : public ScheduleError { static StmtSRef CheckAndGetBufferAllocationSite(const IRModule& mod, const StmtSRef& block_sref, const Buffer& buffer) { - Optional defining_site_sref; - bool is_alloc; - std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, buffer); + auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, buffer); if (!defining_site_sref.defined() || !is_alloc) { throw NonAllocatedBufferError(mod, buffer); } diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index b4e40fa120fe..8e2643db0103 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -137,9 +137,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); - Optional defining_site_sref; - bool is_alloc; - std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer); + auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); } @@ -155,9 +153,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ Buffer new_buffer{new_buffer_node}; // Step 2: Rewrite access indices and regions of the buffer - Stmt new_stmt; - Map block_sref_reuse; - std::tie(new_stmt, block_sref_reuse) = TransformLayoutRewriter::Rewrite( + auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite( GetRef(scope_block), old_buffer, new_buffer, index_map); Block new_scope_block = Downcast(new_stmt); @@ -492,9 +488,7 @@ void SetAxisSeparator(ScheduleState self, const StmtSRef& block_sref, int buffer const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); - Optional defining_site_sref; - bool is_alloc; - std::tie(defining_site_sref, is_alloc) = GetBufferDefiningSite(block_sref, old_buffer); + auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); } diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 2db3eb902aba..992817e87e2d 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -704,9 +704,7 @@ void Reorder(ScheduleState self, const Array& ordered_loop_srefs) { // the input array // - the bottom of the reorder range is the last loop in the input array which is not visited in // the previous traversals - const StmtSRefNode* top = nullptr; - const StmtSRefNode* bottom = nullptr; - std::tie(top, bottom) = GetBoundaryOfReorderRange(self, loop_srefs); + auto [top, bottom] = GetBoundaryOfReorderRange(self, loop_srefs); // Step 3. Collect all loops in the chain and check the loops are single-branch std::vector chain = GetLoopsInReorderRange(self, top, bottom); // Step 4. Check the block below has all its block_var to be data-parallel or reduction, diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 7a4ace736e48..1198e67d710a 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -278,9 +278,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, body = Substitute(body, loop_var_map); // Step 6. Mutate IR const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(scope_root_sref); - Block new_scope_root{nullptr}; - Block new_reduction_block{nullptr}; - std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace( + auto [new_scope_root, new_reduction_block] = DecomposeReductionBlockReplacer::Replace( GetRef(old_scope_root), GetRef(loop), body, GetRef(block)); self->Replace(scope_root_sref, new_scope_root, {{GetRef(old_scope_root), new_scope_root}, @@ -1042,12 +1040,8 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs // will be used when constructing the rfactor block. - BufferStore init; - BufferStore update; - CommReducer reducer; - PrimExpr combiner_lhs, combiner_rhs; - std::tie(init, update) = GetBufferStoresFromReductionBlock(self, block); - std::tie(reducer, combiner_lhs, combiner_rhs) = + auto [init, update] = GetBufferStoresFromReductionBlock(self, block); + auto [reducer, combiner_lhs, combiner_rhs] = GetReducerAndCombinerLhsRhs(self, init->value, update); // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 52b5add2bc9e..b1001a7f9455 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -348,9 +348,7 @@ tir::StmtSRef SampleComputeLocation(tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, const StmtSRef& block_sref, Optional* decision) { // Step 1. Collect all possible compute-at locations. - Array location_srefs; - std::vector location_indices; - std::tie(location_srefs, location_indices) = CollectComputeLocation(self, block_sref); + auto [location_srefs, location_indices] = CollectComputeLocation(self, block_sref); ICHECK_EQ(location_srefs.size(), location_indices.size()); // Step 2. If there was a previous decision, keep the decision unchanged if it exists in the diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index 677506889e57..6ecc6459b904 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -29,6 +29,7 @@ #include #include +#include #include #include @@ -553,25 +554,39 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim if (finder.partitions.empty()) return Stmt(); arith::IntervalSet for_interval(min, max); - bool cond_value; - IntSet middle_interval; - ExpressionSet cond_set; - // find an interval in which all conditions on var are true - std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_); - if (middle_interval.IsNothing()) { - // if such interval doesn't exist, find an interval in which all - // conditions on var are false - std::tie(middle_interval, cond_set) = - GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_); - if (middle_interval.IsNothing()) - // we couldn't find an interval in which the conditions are provably true or false - // Therefore, we can't partition the loop based on those conds - return Stmt(); - cond_value = false; - } else { - cond_value = true; + + auto [middle_interval, cond_set, + opt_cond_value] = [&]() -> std::tuple> { + { + // find an interval in which all conditions on var are true + auto [middle_interval, cond_set] = + GetIntervalAndCondset(finder.partitions, for_interval, true, has_partition_hint_); + if (!middle_interval.IsNothing()) { + return {middle_interval, cond_set, true}; + } + } + + { + // if such interval doesn't exist, find an interval in which all + // conditions on var are false + auto [middle_interval, cond_set] = + GetIntervalAndCondset(finder.partitions, for_interval, false, has_partition_hint_); + + if (!middle_interval.IsNothing()) { + return {middle_interval, cond_set, false}; + } + } + + // we couldn't find an interval in which the conditions are + // provably true or false. Therefore, we can't partition the loop + // based on those conds + return {{}, {}, std::nullopt}; + }(); + + if (!opt_cond_value.has_value()) { + return Stmt(); } + bool cond_value = opt_cond_value.value(); IntervalSet middle_interval_i = Downcast(middle_interval); // middle_interval is the subrange of the loop variable range for which a diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index df8bf69e7468..04b025b5f9ae 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -497,14 +497,10 @@ class CrossThreadReductionTransformer : public StmtMutator { // both be BufferStores with the same buffer and indices; // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity // and the reduction combiner. - BufferStore init{nullptr}; - BufferStore update{nullptr}; - CommReducer reducer{nullptr}; - PrimExpr combiner_lhs{nullptr}; - PrimExpr combiner_rhs{nullptr}; - std::tie(init, update) = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); - std::tie(reducer, combiner_lhs, combiner_rhs) = + auto [init, update] = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); + auto [reducer, combiner_lhs, combiner_rhs] = GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); + (void)combiner_lhs; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 // Condition 5. The block should be the last block under the first reduction-related loop. bool visit = false; @@ -577,10 +573,7 @@ class CrossThreadReductionTransformer : public StmtMutator { ++reduction_id_; // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on // which condition the block violates. - int n_bound_reduction_loops = 0; - CommReducer reducer{nullptr}; - PrimExpr combiner_rhs{nullptr}; - std::tie(n_bound_reduction_loops, reducer, combiner_rhs) = + auto [n_bound_reduction_loops, reducer, combiner_rhs] = CheckCanApplyCrossThreadReduction(block, reduction_loops); // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when // - not all the reduction-related loops are bound to thread axes, or diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 43f7a103db7f..bd6b5185eb4a 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -301,9 +301,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // sort according to dim_index std::sort(block_threads.begin(), block_threads.end()); for (auto&& thr_attr : block_threads) { - int dim_index, extent; - bool is_reduce; - std::tie(dim_index, extent, is_reduce) = thr_attr; + auto [dim_index, extent, is_reduce] = thr_attr; + (void)dim_index; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 if (is_reduce) { contiguous_reduce_extent *= extent; } else { diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 408cdbd04ec7..e12e2772ab22 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -311,8 +311,8 @@ class WarpAccessRewriter : protected StmtExprMutator { << "Has StorageFlatten (TE-based schedule) or " << "FlattenBuffer (TIR-based schedules) been run?"; - PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(store->indices[0]); + auto [local_index, group] = SplitIndexByGroup(store->indices[0]); + (void)group; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 auto writer = store.CopyOnWrite(); writer->indices = {local_index}; @@ -332,8 +332,7 @@ class WarpAccessRewriter : protected StmtExprMutator { << "Has StorageFlatten (TE-based schedule) or " << "FlattenBuffer (TIR-based schedules) been run?"; - PrimExpr local_index, group; - std::tie(local_index, group) = SplitIndexByGroup(op->indices[0]); + auto [local_index, group] = SplitIndexByGroup(op->indices[0]); // invariance: local index must do not contain warp id ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->indices[0] @@ -357,12 +356,10 @@ class WarpAccessRewriter : protected StmtExprMutator { // in this access pattern. std::pair SplitIndexByGroup(const PrimExpr& index) { if (index.dtype().lanes() != 1) { - PrimExpr local_index, group; - arith::PVar base; ICHECK(arith::ramp(base, 1, index.dtype().lanes()).Match(index)); - std::tie(local_index, group) = SplitIndexByGroup(base.Eval()); + auto [local_index, group] = SplitIndexByGroup(base.Eval()); local_index = Ramp(local_index, make_const(local_index.dtype(), 1), index.dtype().lanes()); return std::make_pair(local_index, group); } diff --git a/src/tir/transforms/manifest_shared_memory_local_stage.cc b/src/tir/transforms/manifest_shared_memory_local_stage.cc index 16c85642d1e5..0f56c8b8b7c9 100644 --- a/src/tir/transforms/manifest_shared_memory_local_stage.cc +++ b/src/tir/transforms/manifest_shared_memory_local_stage.cc @@ -61,9 +61,7 @@ class IntermediateStageRewriter { std::vector relaxed_loops = CollectRelaxedOuterLoops(block, target_buffer); // Step 1: Create buffer for the local stage - Buffer new_buffer{nullptr}; - Array buffer_indices; - std::tie(new_buffer, buffer_indices) = CreateIntermediateBuffer(relaxed_loops, target_buffer); + auto [new_buffer, buffer_indices] = CreateIntermediateBuffer(relaxed_loops, target_buffer); // Step 2: Create the local stage block Stmt local_stage = MakeLocalStage(block, new_buffer, buffer_indices, relaxed_loops, store); @@ -190,12 +188,8 @@ class SharedMemoryLocalStageInserter : public StmtMutator { // The annotated block must be a leaf block (will be checked during rewriting). No need to // visit its body recursively. - Buffer target_buffer{nullptr}; - Buffer new_buffer{nullptr}; - Block new_block{nullptr}; - Stmt local_stage{nullptr}; IntermediateStageRewriter rewriter(ancestor_loop_or_blocks_); - std::tie(target_buffer, new_buffer, new_block, local_stage) = rewriter.Rewrite(op); + auto [target_buffer, new_buffer, new_block, local_stage] = rewriter.Rewrite(op); buffer_remap_.Set(target_buffer, new_buffer); new_block.CopyOnWrite()->annotations.erase(attr::manifest_shared_memory_local_stage);