Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add schedule primitive TransformBlockLayout #11485

Merged
merged 2 commits into from
May 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,17 @@ class ScheduleNode : public runtime::Object {
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;

/*!
* \brief Apply a transformation represented by IndexMap to block
* \details The block iters and the block body are transformed by the given index_map.
* Outer loops corresponding to each new block iter are regenerated.
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param self The state of the schedule
* \param block_sref The block sref that refers to the block to be transformed
vinx13 marked this conversation as resolved.
Show resolved Hide resolved
* \param affine_index_map The transformation to apply.
*/
virtual void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) = 0;

/*!
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
* or write index
Expand Down
61 changes: 61 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,67 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
self, block, buffer_index, buffer_index_type_enum, axis_separators
)

@type_checked
def transform_block_layout(
self,
block: BlockRV,
index_map: Union[IndexMap, Callable],
) -> None:
"""Apply a transformation represented by IndexMap to block

Parameters
----------
block_rv : BlockRV
vinx13 marked this conversation as resolved.
Show resolved Hide resolved
The block to be transformed

index_map : Union[IndexMap, Callable]
The transformation to apply.

Examples
--------

Before transform_block_layout, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_transform_block_layout(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"]
) -> None:
for i, j in T.grid(16, 16):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do transform_block_layout:

.. code-block:: python

sch = tir.Schedule(before_transform_block_layout)
sch.transform_block_layout(sch.get_block("B"), lambda i, j: (i * 16 + j,))
print(sch.mod["main"].script())

After applying transform_block_layout, the IR becomes:

.. code-block:: python

@T.prim_func
def after_transform_block_layout(
A: T.Buffer[(16, 16), "float32"],
B: T.Buffer[(16, 16), "float32"]
) -> None:
for i in range(256):
with T.block("B"):
vi, = T.axis.remap("S", [i])
B[vi // 16, vi % 16] = A[vi // 16, vi % 16] * 2.0
"""
if callable(index_map):
index_map = IndexMap.from_func(index_map)
_ffi_api.ScheduleTransformBlockLayout( # type: ignore # pylint: disable=no-member
self, block, index_map
)

@type_checked
def set_axis_separator(
self,
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,16 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
std::unordered_set<const VarNode*>* data_par_vars,
std::unordered_set<const VarNode*>* reduce_vars);

/******** Loop properties ********/
/*!
* \brief Check the loop starts with zero.
* \param self The schedule state
* \param loop_sref The StmtSRef that points to the loop to be checked
* \param analyzer The arithmetic analyzer
* \throw ScheduleError If the loop doesn't starts with zero.
*/
void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer);

/******** Block-loop relation ********/

/*!
Expand Down
27 changes: 27 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,33 @@ bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize,
return has_block_vars_of_other_types;
}

/******** Loop properties ********/

void CheckLoopStartsWithZero(const ScheduleState& self, const StmtSRef& loop_sref, arith::Analyzer* analyzer) {
class LoopNotStartWithZeroError : public ScheduleError {
public:
explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The primitive only supports loop starting with 0";
}

String DetailRenderTemplate() const final {
return "The loop {0} does not start with 0, which is not supported";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
if (!analyzer->CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
}

/******** Block-loop relation ********/

Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self,
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,14 @@ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_i
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}

void ConcreteScheduleNode::TransformBlockLayout(const BlockRV& block_rv,
const IndexMap& index_map) {
TVM_TIR_SCHEDULE_BEGIN();
tir::TransformBlockLayout(state_, this->GetSRef(block_rv), index_map);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_block_layout", this->error_render_level_);
}

void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
const IndexMap& index_map) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const Array<IntImm>& axis_separators) override;
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,18 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map);

/*!
* \brief Apply a transformation represented by IndexMap to block
* \details The block iters and the block body are transformed by the given index_map.
* Outer loops corresponding to each new block iter are regenerated.
* The index_map is required to be bijective affine since we need its inverse mapping.
* \param self The state of the schedule
* \param block_sref The block sref that refers to the block to be transformed
vinx13 marked this conversation as resolved.
Show resolved Hide resolved
* \param affine_index_map The transformation to apply.
*/
TVM_DLL void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
const IndexMap& index_map);

/******** Schedule: Misc ********/

} // namespace tir
Expand Down
Loading