Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Create Layout with specified axis dtype #13663

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/tir/data_layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,10 @@ class Layout : public ObjectRef {
* the corresponding lower case with factor size
* indicates the split dimension.
* return undefined layout if "__undef__" is passed.
* \param dtype The dtype of generated axes vars in the returned layout.
* It is required to be integer type.
*/
TVM_DLL Layout(const std::string& name); // NOLINT(*)
TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*)

/*!
* \brief access the internal node container
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/tir/data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def backward_shape(self, shape):
return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore


def layout(layout_str: str) -> Layout:
def layout(layout_str: str, dtype: str = "int32") -> Layout:
"""Create a layout node from a string.

Parameters
Expand All @@ -177,12 +177,16 @@ def layout(layout_str: str) -> Layout:
Here subordinate axis channel_block=16 is the factor size of
the primal axis C (channel).

dtype : str
The dtype of generated axes vars in the returned layout.
It is required to be integer type.

Returns
-------
layout : Layout
The created layout
"""
return _ffi_api.Layout(layout_str) # type: ignore
return _ffi_api.Layout(layout_str, dtype) # type: ignore


def bijective_layout(
Expand Down
15 changes: 9 additions & 6 deletions src/tir/ir/data_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ Layout::Layout(const Array<IterVar>& axes) {
data_ = std::move(node);
}

Layout::Layout(const std::string& name) { // NOLINT(*)
Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*)
CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type";
if (name == "__undef__") return;

auto node = make_object<LayoutNode>();
Expand All @@ -106,14 +107,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*)
<< " before dimension " << c;
std::string shape_name("_shape");
shape_name.insert(0, 1, c);
IterVar axis =
IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar);
IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype),
tir::kDataPar);
node->axes.push_back(axis);
} else if (c >= 'a' && c <= 'z') {
ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor
<< " for dimension " << c;
IterVar axis =
IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar);
IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype),
tir::kDataPar);
node->axes.push_back(axis);
factor = 0;
} else if (c >= '0' && c <= '9') {
Expand Down Expand Up @@ -426,7 +427,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ")";
});

TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); });
TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) {
return Layout(name, dtype);
});

TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::Get(axis));
Expand Down
27 changes: 26 additions & 1 deletion tests/python/unittest/test_tir_data_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
# under the License.
"""Test layout and bijective-layout node"""

import pytest
import tvm
from tvm import te
import tvm.error
from tvm.topi.utils import get_const_tuple


Expand Down Expand Up @@ -52,6 +53,29 @@ def test_layout():
assert layout[-1] == "c"


def test_layout_dtype():
layout_i32 = tvm.tir.layout("NCHW")
assert layout_i32.axes[0].var.dtype == "int32"
assert layout_i32.axes[0].dom.min.dtype == "int32"
assert layout_i32.axes[0].dom.extent.dtype == "int32"
assert layout_i32.axes[1].var.dtype == "int32"
assert layout_i32.axes[1].dom.min.dtype == "int32"
assert layout_i32.axes[1].dom.extent.dtype == "int32"

layout_i64 = tvm.tir.layout("NCHW", dtype="int64")
assert layout_i64.axes[2].var.dtype == "int64"
assert layout_i64.axes[2].dom.min.dtype == "int64"
assert layout_i64.axes[2].dom.extent.dtype == "int64"
assert layout_i64.axes[3].var.dtype == "int64"
assert layout_i64.axes[3].dom.min.dtype == "int64"
assert layout_i64.axes[3].dom.extent.dtype == "int64"

with pytest.raises(TypeError):
tvm.tir.layout("NCHW", dtype="float32")
with pytest.raises(TypeError):
tvm.tir.layout("NCHW", dtype=None)


def test_bilayout_convertible():
# not convertible
assert tvm.tir.bijective_layout("NCHW", "ABCD") is None
Expand Down Expand Up @@ -88,6 +112,7 @@ def test_bilayout_index():

if __name__ == "__main__":
test_layout()
test_layout_dtype()
test_bilayout_convertible()
test_bilayout_shape()
test_bilayout_index()