From 7693a17e7550a665ef3b16848e032b5b29f90475 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 1 Feb 2022 14:24:11 -0800 Subject: [PATCH 1/3] Reenable warning about unscheduled update definitions and fix associated issues in the tests and apps. This is an old warning that stopped triggering because it wasn't tested. We should either remove it, fix the trigger conditions, or perhaps make it an error. This PR fixes the trigger conditions and fixes all instances of the warning in our tests and apps. The warning triggers if you schedule some but not all of the update definitions of a Func. It's to protect against the common error of only scheduling the pure definition of something like a summation. The warning can be suppressed by inserting a call to func.update(idx). --- apps/HelloMatlab/iir_blur.cpp | 2 ++ apps/fft/fft.cpp | 4 +++ .../linear_algebra/src/blas_l1_generators.cpp | 1 + src/Derivative.cpp | 2 +- src/Func.cpp | 26 ++++++++++++++++++- src/Func.h | 1 - src/ScheduleFunctions.cpp | 4 +-- .../adams2019/cost_model_generator.cpp | 2 +- test/correctness/atomics.cpp | 13 +++++----- test/correctness/compute_with.cpp | 8 ++++++ test/correctness/extern_bounds_inference.cpp | 1 + test/correctness/named_updates.cpp | 2 ++ test/correctness/tuple_reduction.cpp | 11 +++----- test/warning/CMakeLists.txt | 1 + test/warning/unscheduled_update_def.cpp | 17 ++++++++++++ 15 files changed, 76 insertions(+), 19 deletions(-) create mode 100644 test/warning/unscheduled_update_def.cpp diff --git a/apps/HelloMatlab/iir_blur.cpp b/apps/HelloMatlab/iir_blur.cpp index 3507ad38c2b5..bfcd31fbfa9f 100644 --- a/apps/HelloMatlab/iir_blur.cpp +++ b/apps/HelloMatlab/iir_blur.cpp @@ -45,6 +45,8 @@ Func blur_cols_transpose(Func input, Expr height, Expr alpha) { blur.compute_at(transpose, yo); // Vectorize computations within the strips. 
+ blur.update(0) + .vectorize(x); blur.update(1) .reorder(x, ry) .vectorize(x); diff --git a/apps/fft/fft.cpp b/apps/fft/fft.cpp index 24c6ab04f8e0..993612d08c2d 100644 --- a/apps/fft/fft.cpp +++ b/apps/fft/fft.cpp @@ -871,6 +871,10 @@ ComplexFunc fft2d_r2c(Func r, dft.update(4).allow_race_conditions().vectorize(n0z1, vector_size); dft.update(5).allow_race_conditions().vectorize(n0z2, vector_size); + // Intentionally serial + dft.update(0); + dft.update(3); + // Our result is undefined outside these bounds. dft.bound(n0, 0, N0); dft.bound(n1, 0, (N1 + 1) / 2 + 1); diff --git a/apps/linear_algebra/src/blas_l1_generators.cpp b/apps/linear_algebra/src/blas_l1_generators.cpp index eb59b74d56e8..2ce48e8ecc2c 100644 --- a/apps/linear_algebra/src/blas_l1_generators.cpp +++ b/apps/linear_algebra/src/blas_l1_generators.cpp @@ -60,6 +60,7 @@ class AXPYGenerator : public Generator> { Var ii("ii"); result_.update().vectorize(vecs, vec_size); } + result_.update(1); // Leave the tail unvectorized result_.bound(i, 0, x_.width()); result_.dim(0).set_bounds(0, x_.width()); diff --git a/src/Derivative.cpp b/src/Derivative.cpp index fb69522de580..c5ba26253367 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -1808,7 +1808,7 @@ void ReverseAccumulationVisitor::propagate_halide_function_call( // If previous update has a different set of reduction variables, // don't merge const vector &rvars = - func_to_update.update(update_id).get_schedule().rvars(); + func_to_update.function().update(update_id).schedule().rvars(); if (!merged_r.defined()) { return rvars.empty(); } diff --git a/src/Func.cpp b/src/Func.cpp index 35e89f57cc71..79ed61671e80 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -344,6 +344,7 @@ bool is_const_assignment(const string &func_name, const vector &args, cons } // namespace void Stage::set_dim_type(const VarOrRVar &var, ForType t) { + definition.schedule().touched() = true; bool found = false; vector &dims = definition.schedule().dims(); for (auto &dim : 
dims) { @@ -407,6 +408,7 @@ void Stage::set_dim_type(const VarOrRVar &var, ForType t) { } void Stage::set_dim_device_api(const VarOrRVar &var, DeviceAPI device_api) { + definition.schedule().touched() = true; bool found = false; vector &dims = definition.schedule().dims(); for (auto &dim : dims) { @@ -662,12 +664,15 @@ bool apply_split_directive(const Split &s, vector &rvars, } // anonymous namespace Func Stage::rfactor(const RVar &r, const Var &v) { + definition.schedule().touched() = true; return rfactor({{r, v}}); } Func Stage::rfactor(vector> preserved) { user_assert(!definition.is_init()) << "rfactor() must be called on an update definition\n"; + definition.schedule().touched() = true; + const string &func_name = function.name(); vector &args = definition.args(); vector &values = definition.values(); @@ -969,6 +974,8 @@ void Stage::split(const string &old, const string &outer, const string &inner, c << outer << " and " << inner << " with factor of " << factor << "\n"; vector &dims = definition.schedule().dims(); + definition.schedule().touched() = true; + // Check that the new names aren't already in the dims list. 
for (auto &dim : dims) { string new_names[2] = {inner, outer}; @@ -1116,6 +1123,7 @@ void Stage::split(const string &old, const string &outer, const string &inner, c } Stage &Stage::split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVar &inner, const Expr &factor, TailStrategy tail) { + definition.schedule().touched() = true; if (old.is_rvar) { user_assert(outer.is_rvar) << "Can't split RVar " << old.name() << " into Var " << outer.name() << "\n"; user_assert(inner.is_rvar) << "Can't split RVar " << old.name() << " into Var " << inner.name() << "\n"; @@ -1128,6 +1136,7 @@ Stage &Stage::split(const VarOrRVar &old, const VarOrRVar &outer, const VarOrRVa } Stage &Stage::fuse(const VarOrRVar &inner, const VarOrRVar &outer, const VarOrRVar &fused) { + definition.schedule().touched() = true; if (!fused.is_rvar) { user_assert(!outer.is_rvar) << "Can't fuse Var " << fused.name() << " from RVar " << outer.name() << "\n"; @@ -1211,6 +1220,8 @@ class CheckForFreeVars : public IRGraphVisitor { Stage Stage::specialize(const Expr &condition) { user_assert(condition.type().is_bool()) << "Argument passed to specialize must be of type bool\n"; + definition.schedule().touched() = true; + // The condition may not depend on Vars or RVars Internal::CheckForFreeVars check; condition.accept(&check); @@ -1242,6 +1253,9 @@ void Stage::specialize_fail(const std::string &message) { const vector &specializations = definition.specializations(); user_assert(specializations.empty() || specializations.back().failure_message.empty()) << "Only one specialize_fail() may be defined per Stage."; + + definition.schedule().touched() = true; + (void)definition.add_specialization(const_true()); Specialization &s = definition.specializations().back(); s.failure_message = message; @@ -1383,6 +1397,8 @@ void Stage::remove(const string &var) { } Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { + definition.schedule().touched() = true; + if (old_var.is_rvar) { 
user_assert(new_var.is_rvar) << "In schedule for " << name() @@ -1472,11 +1488,13 @@ Stage &Stage::rename(const VarOrRVar &old_var, const VarOrRVar &new_var) { } Stage &Stage::allow_race_conditions() { + definition.schedule().touched() = true; definition.schedule().allow_race_conditions() = true; return *this; } Stage &Stage::atomic(bool override_associativity_test) { + definition.schedule().touched() = true; definition.schedule().atomic() = true; definition.schedule().override_atomic_associativity_test() = override_associativity_test; return *this; @@ -1600,6 +1618,7 @@ Stage &Stage::tile(const std::vector &previous, } Stage &Stage::reorder(const std::vector &vars) { + definition.schedule().touched() = true; const string &func_name = function.name(); vector &args = definition.args(); vector &values = definition.values(); @@ -1839,18 +1858,21 @@ Stage &Stage::hexagon(const VarOrRVar &x) { } Stage &Stage::prefetch(const Func &f, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) { + definition.schedule().touched() = true; PrefetchDirective prefetch = {f.name(), at.name(), from.name(), std::move(offset), strategy, Parameter()}; definition.schedule().prefetches().push_back(prefetch); return *this; } Stage &Stage::prefetch(const Internal::Parameter &param, const VarOrRVar &at, const VarOrRVar &from, Expr offset, PrefetchBoundStrategy strategy) { + definition.schedule().touched() = true; PrefetchDirective prefetch = {param.name(), at.name(), from.name(), std::move(offset), strategy, param}; definition.schedule().prefetches().push_back(prefetch); return *this; } Stage &Stage::compute_with(LoopLevel loop_level, const map &align) { + definition.schedule().touched() = true; loop_level.lock(); user_assert(!loop_level.is_inlined() && !loop_level.is_root()) << "Undefined loop level to compute with\n"; @@ -2738,7 +2760,9 @@ void Func::debug_to_file(const string &filename) { Stage Func::update(int idx) { user_assert(idx < 
num_update_definitions()) << "Call to update with index larger than last defined update stage for Func \"" << name() << "\".\n"; invalidate_cache(); - return Stage(func, func.update(idx), idx + 1); + Definition d = func.update(idx); + d.schedule().touched() = true; + return Stage(func, d, idx + 1); } Func::operator Stage() const { diff --git a/src/Func.h b/src/Func.h index bbfffeab92d5..bc853a4aee0b 100644 --- a/src/Func.h +++ b/src/Func.h @@ -94,7 +94,6 @@ class Stage { Stage(Internal::Function f, Internal::Definition d, size_t stage_index) : function(std::move(f)), definition(std::move(d)), stage_index(stage_index) { internal_assert(definition.defined()); - definition.schedule().touched() = true; dim_vars.reserve(function.args().size()); for (const auto &arg : function.args()) { diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index 5f0285a921f2..ba87dec47786 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -2078,10 +2078,10 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ const Definition &r = f.update((int)i); if (!r.schedule().touched()) { user_warning - << "Warning: Update step " << i + << "Update definition " << i << " of function " << f.name() << " has not been scheduled, even though some other" - << " steps have been. You may have forgotten to" + << " definitions have been. You may have forgotten to" << " schedule it. 
If this was intentional, call " << f.name() << ".update(" << i << ") to suppress" << " this warning.\n"; diff --git a/src/autoschedulers/adams2019/cost_model_generator.cpp b/src/autoschedulers/adams2019/cost_model_generator.cpp index ab864f588e51..dfca665505b1 100644 --- a/src/autoschedulers/adams2019/cost_model_generator.cpp +++ b/src/autoschedulers/adams2019/cost_model_generator.cpp @@ -533,7 +533,7 @@ class CostModel : public Generator> { }; // Pipeline features processing - conv1_stage1.compute_root().vectorize(c); + conv1_stage1.compute_root().vectorize(c).update().vectorize(c); squashed_head1_filter.compute_root().vectorize(c); // Schedule features processing. The number of schedule diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 3832a1f480df..205aacf0984a 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -325,14 +325,15 @@ void test_predicated_hist(const Backend &backend) { hist(im(r2)) = min(hist(im(r2)) + cast(1), cast(100)); // cas loop hist.compute_root(); - for (int update_id = 0; update_id < 3; update_id++) { + for (int update_id = 0; update_id < hist.num_update_definitions(); update_id++) { + RVar rv = update_id < 3 ? r : r2; switch (backend) { case Backend::CPU: { // Can't prove associativity. // Set override_associativity_test to true to remove the check. hist.update(update_id) .atomic(true /*override_associativity_test*/) - .parallel(r); + .parallel(rv); } break; case Backend::CPUVectorize: { // Doesn't support predicated store yet. 
@@ -344,7 +345,7 @@ void test_predicated_hist(const Backend &backend) { RVar ro, ri; hist.update(update_id) .atomic(true /*override_associativity_test*/) - .split(r, ro, ri, 32) + .split(rv, ro, ri, 32) .gpu_blocks(ro, DeviceAPI::OpenCL) .gpu_threads(ri, DeviceAPI::OpenCL); } break; @@ -354,7 +355,7 @@ void test_predicated_hist(const Backend &backend) { RVar ro, ri; hist.update(update_id) .atomic(true /*override_associativity_test*/) - .split(r, ro, ri, 32) + .split(rv, ro, ri, 32) .gpu_blocks(ro, DeviceAPI::CUDA) .gpu_threads(ri, DeviceAPI::CUDA); } break; @@ -363,7 +364,7 @@ void test_predicated_hist(const Backend &backend) { RVar rio, rii; hist.update(update_id) .atomic(true /*override_assciativity_test*/) - .split(r, ro, ri, 32) + .split(rv, ro, ri, 32) .split(ri, rio, rii, 4) .gpu_blocks(ro, DeviceAPI::CUDA) .gpu_threads(rio, DeviceAPI::CUDA) @@ -531,7 +532,7 @@ void test_nested_atomics(const Backend &backend) { Expr new_max = max(im(r), old_max); arg_max() = {new_index, new_max}; - im.compute_inline().atomic(); + im.compute_inline().atomic().update().atomic(); arg_max.compute_root(); switch (backend) { case Backend::CPU: { diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index b978d39b4f71..c5dadfbaf727 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -242,6 +242,9 @@ int multiple_fuse_group_test() { p.fuse(x, y, t).parallel(t); h.fuse(x, y, t).parallel(t); h.compute_with(p, t); + h.update(0); // unfused + h.update(1); // unfused + h.update(2); // unfused f.update(0).compute_with(g, y, LoopAlignStrategy::AlignEnd); f.compute_with(g, x); @@ -1278,6 +1281,8 @@ int update_stage_test() { f.compute_root(); f.update(1).compute_with(g.update(0), y); + f.update(0); // unfused + g.update(1); // unfused g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); @@ -1351,6 +1356,7 @@ int update_stage2_test() { f.update(0).compute_with(g.update(0), y); 
f.update(1).compute_with(g.update(0), y); + g.update(1); // unfused g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); @@ -1659,6 +1665,8 @@ int update_stage_diagonal_test() { f.update(1).compute_with(g.update(0), y); g.update(0).compute_with(h, y); + f.update(0); + g.update(1); g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); diff --git a/test/correctness/extern_bounds_inference.cpp b/test/correctness/extern_bounds_inference.cpp index 5dd537529e35..79c1cf5b5675 100644 --- a/test/correctness/extern_bounds_inference.cpp +++ b/test/correctness/extern_bounds_inference.cpp @@ -118,6 +118,7 @@ int main(int argc, char **argv) { f1.compute_at(g, y); f2.compute_at(g, x); g.reorder(y, x).vectorize(y, 4); + g.update(); g.infer_input_bounds({W, H}); diff --git a/test/correctness/named_updates.cpp b/test/correctness/named_updates.cpp index 6f50ff0cd514..55e366a3d097 100644 --- a/test/correctness/named_updates.cpp +++ b/test/correctness/named_updates.cpp @@ -41,6 +41,8 @@ int main(int argc, char **argv) { more_updates.a.vectorize(r, 4); more_updates.b.vectorize(r, 4); more_updates.c.vectorize(r, 4); + + f.update(); // fix_first isn't scheduled } // Define the same thing without all the weird syntax and without diff --git a/test/correctness/tuple_reduction.cpp b/test/correctness/tuple_reduction.cpp index 894ac98241ba..8120f45f85a6 100644 --- a/test/correctness/tuple_reduction.cpp +++ b/test/correctness/tuple_reduction.cpp @@ -61,14 +61,13 @@ int main(int argc, char **argv) { f.hexagon(y).vectorize(x, 32); } for (int i = 0; i < 10; i++) { + f.update(i); if (i & 1) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); } else if (target.has_feature(Target::HVX)) { f.update(i).hexagon(y).vectorize(x, 32); } - } else { - f.update(i); } } @@ -103,14 +102,13 @@ int main(int argc, char **argv) { // Schedule the even update steps on the gpu for (int i = 0; i < 10; i++) { + f.update(i); 
if (i & 1) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); } else if (target.has_feature(Target::HVX)) { f.update(i).hexagon(y).vectorize(x, 32); } - } else { - f.update(i); } } @@ -146,9 +144,8 @@ int main(int argc, char **argv) { // Schedule the even update steps on the gpu for (int i = 0; i < 10; i++) { - if (i & 1) { - f.update(i); - } else { + f.update(i); + if ((i & 1) == 0) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); } else if (target.has_feature(Target::HVX)) { diff --git a/test/warning/CMakeLists.txt b/test/warning/CMakeLists.txt index 3106da1d1cbe..a4d60cd4ffa0 100644 --- a/test/warning/CMakeLists.txt +++ b/test/warning/CMakeLists.txt @@ -3,6 +3,7 @@ tests(GROUPS warning hidden_pure_definition.cpp require_const_false.cpp sliding_vectors.cpp + unscheduled_update_def.cpp ) # Don't look for "Success!" in warning tests, look for "Warning:" instead. diff --git a/test/warning/unscheduled_update_def.cpp b/test/warning/unscheduled_update_def.cpp new file mode 100644 index 000000000000..fdefccff1308 --- /dev/null +++ b/test/warning/unscheduled_update_def.cpp @@ -0,0 +1,17 @@ +#include "Halide.h" + +using namespace Halide; + +int main(int argc, char **argv) { + Func f; + Var x; + + f(x) = 0; + f(x) += 5; + + f.vectorize(x, 8); + + f.realize({1024}); + + return 0; +} From 8e8ea9b0e0cc4c961863820abb57db2a7236bebb Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 2 Feb 2022 10:36:57 -0800 Subject: [PATCH 2/3] Fix atomics test --- test/correctness/atomics.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 205aacf0984a..21050ee3d6d1 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -320,20 +320,21 @@ void test_predicated_hist(const Backend &backend) { hist(im(r)) = min(hist(im(r)) + cast(1), cast(100)); // cas loop RDom r2(0, img_size); + // This predicate 
means that the update definitions below can't actually be + // atomic, because the if isn't included in the atomic block. r2.where(hist(im(r2)) > cast(0) && hist(im(r2)) < cast(90)); - hist(im(r2)) -= cast(1); // atomic add - hist(im(r2)) = min(hist(im(r2)) + cast(1), cast(100)); // cas loop + hist(im(r2)) -= cast(1); + hist(im(r2)) = min(hist(im(r2)) + cast(1), cast(100)); hist.compute_root(); - for (int update_id = 0; update_id < hist.num_update_definitions(); update_id++) { - RVar rv = update_id < 3 ? r : r2; + for (int update_id = 0; update_id < 3; update_id++) { switch (backend) { case Backend::CPU: { // Can't prove associativity. // Set override_associativity_test to true to remove the check. hist.update(update_id) .atomic(true /*override_associativity_test*/) - .parallel(rv); + .parallel(r); } break; case Backend::CPUVectorize: { // Doesn't support predicated store yet. @@ -345,7 +346,7 @@ void test_predicated_hist(const Backend &backend) { RVar ro, ri; hist.update(update_id) .atomic(true /*override_associativity_test*/) - .split(rv, ro, ri, 32) + .split(r, ro, ri, 32) .gpu_blocks(ro, DeviceAPI::OpenCL) .gpu_threads(ri, DeviceAPI::OpenCL); } break; @@ -355,7 +356,7 @@ void test_predicated_hist(const Backend &backend) { RVar ro, ri; hist.update(update_id) .atomic(true /*override_associativity_test*/) - .split(rv, ro, ri, 32) + .split(r, ro, ri, 32) .gpu_blocks(ro, DeviceAPI::CUDA) .gpu_threads(ri, DeviceAPI::CUDA); } break; @@ -364,7 +365,7 @@ void test_predicated_hist(const Backend &backend) { RVar rio, rii; hist.update(update_id) .atomic(true /*override_assciativity_test*/) - .split(rv, ro, ri, 32) + .split(r, ro, ri, 32) .split(ri, rio, rii, 4) .gpu_blocks(ro, DeviceAPI::CUDA) .gpu_threads(rio, DeviceAPI::CUDA) From d82a456e69324061bdb8e3815495b8b708adbac2 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Thu, 17 Feb 2022 13:44:35 -0800 Subject: [PATCH 3/3] Add Stage::unscheduled() --- apps/fft/fft.cpp | 4 ++-- 
apps/linear_algebra/src/blas_l1_generators.cpp | 2 +- src/Func.cpp | 9 ++++++--- src/Func.h | 6 ++++++ src/ScheduleFunctions.cpp | 2 +- test/correctness/atomics.cpp | 3 +++ test/correctness/compute_with.cpp | 14 ++++++-------- test/correctness/extern_bounds_inference.cpp | 2 +- test/correctness/named_updates.cpp | 2 +- test/correctness/parallel_reductions.cpp | 2 +- test/correctness/sliding_reduction.cpp | 4 ++-- test/correctness/tuple_reduction.cpp | 6 +++--- test/correctness/vectorized_initialization.cpp | 2 +- 13 files changed, 34 insertions(+), 24 deletions(-) diff --git a/apps/fft/fft.cpp b/apps/fft/fft.cpp index 993612d08c2d..79382129c763 100644 --- a/apps/fft/fft.cpp +++ b/apps/fft/fft.cpp @@ -872,8 +872,8 @@ ComplexFunc fft2d_r2c(Func r, dft.update(5).allow_race_conditions().vectorize(n0z2, vector_size); // Intentionally serial - dft.update(0); - dft.update(3); + dft.update(0).unscheduled(); + dft.update(3).unscheduled(); // Our result is undefined outside these bounds. dft.bound(n0, 0, N0); diff --git a/apps/linear_algebra/src/blas_l1_generators.cpp b/apps/linear_algebra/src/blas_l1_generators.cpp index 64864de3e43e..c25025cde85a 100644 --- a/apps/linear_algebra/src/blas_l1_generators.cpp +++ b/apps/linear_algebra/src/blas_l1_generators.cpp @@ -60,7 +60,7 @@ class AXPYGenerator : public Generator> { Var ii("ii"); result_.update().vectorize(vecs, vec_size); } - result_.update(1); // Leave the tail unvectorized + result_.update(1).unscheduled(); // Leave the tail unvectorized result_.bound(i, 0, x_.width()); result_.dim(0).set_bounds(0, x_.width()); diff --git a/src/Func.cpp b/src/Func.cpp index 2103422f6915..bb5842dd0b4d 100644 --- a/src/Func.cpp +++ b/src/Func.cpp @@ -1928,6 +1928,11 @@ std::string Stage::source_location() const { return definition.source_location(); } +void Stage::unscheduled() { + user_assert(!definition.schedule().touched()) << "Stage::unscheduled called on an update definition with a schedule\n"; + definition.schedule().touched() = 
true; +} + void Func::invalidate_cache() { if (pipeline_.defined()) { pipeline_.invalidate_cache(); @@ -2760,9 +2765,7 @@ void Func::debug_to_file(const string &filename) { Stage Func::update(int idx) { user_assert(idx < num_update_definitions()) << "Call to update with index larger than last defined update stage for Func \"" << name() << "\".\n"; invalidate_cache(); - Definition d = func.update(idx); - d.schedule().touched() = true; - return Stage(func, d, idx + 1); + return Stage(func, func.update(idx), idx + 1); } Func::operator Stage() const { diff --git a/src/Func.h b/src/Func.h index 37c38bb628d5..589e18709586 100644 --- a/src/Func.h +++ b/src/Func.h @@ -473,6 +473,12 @@ class Stage { * empty string if no debug symbols were found or the debug * symbols were not understood. Works on OS X and Linux only. */ std::string source_location() const; + + /** Assert that this stage has intentionally been given no schedule, and + * suppress the warning about unscheduled update definitions that would + * otherwise fire. This counts as a schedule, so calling this twice on the + * same Stage will fail the assertion. */ + void unscheduled(); }; // For backwards compatibility, keep the ScheduleHandle name. diff --git a/src/ScheduleFunctions.cpp b/src/ScheduleFunctions.cpp index ba87dec47786..52724e309ba6 100644 --- a/src/ScheduleFunctions.cpp +++ b/src/ScheduleFunctions.cpp @@ -2083,7 +2083,7 @@ bool validate_schedule(Function f, const Stmt &s, const Target &target, bool is_ << " has not been scheduled, even though some other" << " definitions have been. You may have forgotten to" << " schedule it. 
If this was intentional, call " - << f.name() << ".update(" << i << ") to suppress" + << f.name() << ".update(" << i << ").unscheduled() to suppress" << " this warning.\n"; } } diff --git a/test/correctness/atomics.cpp b/test/correctness/atomics.cpp index 21050ee3d6d1..9c1547d1be78 100644 --- a/test/correctness/atomics.cpp +++ b/test/correctness/atomics.cpp @@ -326,6 +326,9 @@ void test_predicated_hist(const Backend &backend) { hist(im(r2)) -= cast(1); hist(im(r2)) = min(hist(im(r2)) + cast(1), cast(100)); + hist.update(3).unscheduled(); + hist.update(4).unscheduled(); + hist.compute_root(); for (int update_id = 0; update_id < 3; update_id++) { switch (backend) { diff --git a/test/correctness/compute_with.cpp b/test/correctness/compute_with.cpp index c5dadfbaf727..8de8f660928d 100644 --- a/test/correctness/compute_with.cpp +++ b/test/correctness/compute_with.cpp @@ -242,9 +242,9 @@ int multiple_fuse_group_test() { p.fuse(x, y, t).parallel(t); h.fuse(x, y, t).parallel(t); h.compute_with(p, t); - h.update(0); // unfused - h.update(1); // unfused - h.update(2); // unfused + h.update(0).unscheduled(); + h.update(1).unscheduled(); + h.update(2).unscheduled(); f.update(0).compute_with(g, y, LoopAlignStrategy::AlignEnd); f.compute_with(g, x); @@ -1280,9 +1280,8 @@ int update_stage_test() { g.compute_root(); f.compute_root(); + f.update(0).unscheduled(); f.update(1).compute_with(g.update(0), y); - f.update(0); // unfused - g.update(1); // unfused g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); @@ -1356,7 +1355,6 @@ int update_stage2_test() { f.update(0).compute_with(g.update(0), y); f.update(1).compute_with(g.update(0), y); - g.update(1); // unfused g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); @@ -1665,8 +1663,8 @@ int update_stage_diagonal_test() { f.update(1).compute_with(g.update(0), y); g.update(0).compute_with(h, y); - f.update(0); - g.update(1); + f.update(0).unscheduled(); + 
g.update(1).unscheduled(); g.bound(x, 0, g_size).bound(y, 0, g_size); f.bound(x, 0, f_size).bound(y, 0, f_size); diff --git a/test/correctness/extern_bounds_inference.cpp b/test/correctness/extern_bounds_inference.cpp index 79c1cf5b5675..dab90168256f 100644 --- a/test/correctness/extern_bounds_inference.cpp +++ b/test/correctness/extern_bounds_inference.cpp @@ -118,7 +118,7 @@ int main(int argc, char **argv) { f1.compute_at(g, y); f2.compute_at(g, x); g.reorder(y, x).vectorize(y, 4); - g.update(); + g.update().unscheduled(); g.infer_input_bounds({W, H}); diff --git a/test/correctness/named_updates.cpp b/test/correctness/named_updates.cpp index 55e366a3d097..92df862fdf43 100644 --- a/test/correctness/named_updates.cpp +++ b/test/correctness/named_updates.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) { more_updates.b.vectorize(r, 4); more_updates.c.vectorize(r, 4); - f.update(); // fix_first isn't scheduled + f.update().unscheduled(); // fix_first isn't scheduled } // Define the same thing without all the weird syntax and without diff --git a/test/correctness/parallel_reductions.cpp b/test/correctness/parallel_reductions.cpp index 0c1e388b4a51..62960447a8b6 100644 --- a/test/correctness/parallel_reductions.cpp +++ b/test/correctness/parallel_reductions.cpp @@ -68,7 +68,7 @@ int main(int argc, char **argv) { sum_rows.compute_root().vectorize(i, 4).parallel(j); sum_rows.update().parallel(j); sum_cols.compute_root().vectorize(j, 4); - sum_cols.update(); + sum_cols.update().unscheduled(); out.output_buffer().dim(0).set_bounds(0, 256); Buffer result = out.realize({256}); diff --git a/test/correctness/sliding_reduction.cpp b/test/correctness/sliding_reduction.cpp index 3ce75056a09b..40e95252bca7 100644 --- a/test/correctness/sliding_reduction.cpp +++ b/test/correctness/sliding_reduction.cpp @@ -95,8 +95,8 @@ int main(int argc, char **argv) { f(x, y) = call_count(f(x, y)); f.unroll(y, 2); - f.update(0); - f.update(1); + f.update(0).unscheduled(); + 
f.update(1).unscheduled(); Func g("g"); g(x, y) = f(x, y) + f(x, y - 1) + f(x, y - 2); diff --git a/test/correctness/tuple_reduction.cpp b/test/correctness/tuple_reduction.cpp index 8120f45f85a6..576df9fb77e1 100644 --- a/test/correctness/tuple_reduction.cpp +++ b/test/correctness/tuple_reduction.cpp @@ -61,7 +61,7 @@ int main(int argc, char **argv) { f.hexagon(y).vectorize(x, 32); } for (int i = 0; i < 10; i++) { - f.update(i); + f.update(i).unscheduled(); if (i & 1) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); @@ -102,7 +102,7 @@ int main(int argc, char **argv) { // Schedule the even update steps on the gpu for (int i = 0; i < 10; i++) { - f.update(i); + f.update(i).unscheduled(); if (i & 1) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); @@ -144,7 +144,7 @@ int main(int argc, char **argv) { // Schedule the even update steps on the gpu for (int i = 0; i < 10; i++) { - f.update(i); + f.update(i).unscheduled(); if ((i & 1) == 0) { if (target.has_gpu_feature()) { f.update(i).gpu_tile(x, y, xo, yo, xi, yi, 16, 16); diff --git a/test/correctness/vectorized_initialization.cpp b/test/correctness/vectorized_initialization.cpp index f38fe5cdc100..ca01b6fb0fb8 100644 --- a/test/correctness/vectorized_initialization.cpp +++ b/test/correctness/vectorized_initialization.cpp @@ -19,7 +19,7 @@ int main(int argc, char **argv) { f(x) = x; f(r) = f(r - 1) + f(r + 1); f.compute_root().vectorize(x, 4); - f.update(); + f.update().unscheduled(); g(x) = f(x); Buffer result = g.realize({4});