Skip to content

Commit

Permalink
[TIR] Create Layout with specified axis dtype (apache#13663)
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 authored and fzi-peccia committed Mar 27, 2023
1 parent 9c16365 commit 96d55ec
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 10 deletions.
4 changes: 3 additions & 1 deletion include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ class Layout : public ObjectRef {
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
* \param dtype The dtype of generated axes vars in the returned layout.
* It is required to be integer type.
*/
TVM_DLL Layout(const std::string& name); // NOLINT(*)
TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)

/*!
* \brief access the internal node container
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def backward_shape(self, shape):
return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore


def layout(layout_str: str) -> Layout:
def layout(layout_str: str, dtype: str = "int32") -> Layout:
"""Create a layout node from a string.
Parameters
Expand All @@ -177,12 +177,16 @@ def layout(layout_str: str) -> Layout:
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).
dtype : str
The dtype of generated axes vars in the returned layout.
It is required to be integer type.
Returns
-------
layout : Layout
The created layout
"""
return _ffi_api.Layout(layout_str) # type: ignore
return _ffi_api.Layout(layout_str, dtype) # type: ignore


def bijective_layout(
Expand Down
15 changes: 9 additions & 6 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ Layout::Layout(const Array<IterVar>& axes) {
data_ = std::move(node);
}

Layout::Layout(const std::string& name) { // NOLINT(*)
Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*)
CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type";
if (name == "__undef__") return;

auto node = make_object<LayoutNode>();
Expand All @@ -106,14 +107,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
IterVar axis =
IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar);
IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype),
tir::kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " for dimension " << c;
IterVar axis =
IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar);
IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype),
tir::kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
Expand Down Expand Up @@ -426,7 +427,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});

TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) {
return Layout(name, dtype);
});

TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::Get(axis));
Expand Down
27 changes: 26 additions & 1 deletion tests/python/unittest/test_tir_data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Test layout and bijective-layout node"""

import pytest
import tvm
from tvm import te
import tvm.error
from tvm.topi.utils import get_const_tuple


Expand Down Expand Up @@ -52,6 +53,29 @@ def test_layout():
assert layout[-1] == "c"


def test_layout_dtype():
layout_i32 = tvm.tir.layout("NCHW")
assert layout_i32.axes[0].var.dtype == "int32"
assert layout_i32.axes[0].dom.min.dtype == "int32"
assert layout_i32.axes[0].dom.extent.dtype == "int32"
assert layout_i32.axes[1].var.dtype == "int32"
assert layout_i32.axes[1].dom.min.dtype == "int32"
assert layout_i32.axes[1].dom.extent.dtype == "int32"

layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
assert layout_i64.axes[2].var.dtype == "int64"
assert layout_i64.axes[2].dom.min.dtype == "int64"
assert layout_i64.axes[2].dom.extent.dtype == "int64"
assert layout_i64.axes[3].var.dtype == "int64"
assert layout_i64.axes[3].dom.min.dtype == "int64"
assert layout_i64.axes[3].dom.extent.dtype == "int64"

with pytest.raises(TypeError):
tvm.tir.layout("NCHW", dtype="float32")
with pytest.raises(TypeError):
tvm.tir.layout("NCHW", dtype=None)


def test_bilayout_convertible():
# not convertible
assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
Expand Down Expand Up @@ -88,6 +112,7 @@ def test_bilayout_index():

if __name__ == "__main__":
test_layout()
test_layout_dtype()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()

0 comments on commit 96d55ec

Please sign in to comment.