Skip to content

Commit

Permalink
fixup! [TIR] Add schedule primitive TransformBlockLayout
Browse files Browse the repository at this point in the history
Fix doc
  • Loading branch information
vinx13 committed May 27, 2022
1 parent 936b864 commit a82836d
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 13 deletions.
5 changes: 2 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -550,9 +550,8 @@ class ScheduleNode : public runtime::Object {
* \details The block iters and the block body are transformed by the given index_map.
* Outer loops corresponding to each new block iter are regenerated.
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param self The state of the schedule
* \param block_sref The block sref that refers to the block to be transformed
* \param affine_index_map The transformation to apply.
* \param block_rv The block to be transformed
* \param index_map The transformation to apply.
*/
virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,7 +2296,7 @@ def transform_block_layout(
Parameters
----------
block_rv : BlockRV
block : BlockRV
The block to be transformed
index_map : Union[IndexMap, Callable]
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
* \param analyzer The arithmetic analyzer
* \throw ScheduleError If the loop doesn't starts with zero.
*/
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer);
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer);

/******** Block-loop relation ********/

Expand Down
8 changes: 5 additions & 3 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,10 +688,12 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,

/******** Loop properties ********/

void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer) {
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref,
arith::Analyzer* analyzer) {
class LoopNotStartWithZeroError : public ScheduleError {
public:
explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}
public:
explicit LoopNotStartWithZeroError(IRModule mod, For loop)
: mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The primitive only supports loop starting with 0";
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param self The state of the schedule
* \param block_sref The block sref that refers to the block to be transformed
* \param affine_index_map The transformation to apply.
* \param index_map The transformation to apply.
*/
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);
Expand Down
9 changes: 5 additions & 4 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError {
IRModule mod_;
Block block_;
IndexMap index_map_;

};

class NotTrivialBindingError : public ScheduleError {
Expand Down Expand Up @@ -405,7 +404,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
Array<IterVar> new_block_iters; // new block iters
Array<PrimExpr> new_block_vars; // iter_var->var of new block iters
for (size_t i = 0; i < index_map->final_indices.size(); ++i) {
Var new_block_var{"bv" + std::to_string(i), DataType::Int(32)};
Var new_block_var{"v" + std::to_string(i), DataType::Int(32)};
new_block_vars.push_back(new_block_var);
IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
if (iter_type == kOpaque) {
Expand All @@ -419,7 +418,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
// in the body.

auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars);
// Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant zero.
// Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant
// zero.
for (const auto& iter_var : block_ptr->iter_vars) {
if (inverse_map.find(iter_var->var) == inverse_map.end()) {
ICHECK(is_one(iter_var->dom->extent));
Expand Down Expand Up @@ -447,7 +447,8 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
// Generate outer loops
Stmt body = GetRef<Stmt>(new_block_realize);
for (int i = static_cast<int>(new_loop_vars.size()) - 1; i >= 0; --i) {
body = For(Downcast<Var>(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial, std::move(body));
body = For(Downcast<Var>(new_loop_vars[i]), 0, new_block_iter_range[i], ForKind::kSerial,
std::move(body));
}

// Step 6: Do the actual replacement
Expand Down

0 comments on commit a82836d

Please sign in to comment.