From 96d55ec83fd2e6a5f8baa9727188e9f2b18a86e3 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 28 Dec 2022 06:11:06 -0500 Subject: [PATCH] [TIR] Create Layout with specified axis dtype (#13663) --- include/tvm/tir/data_layout.h | 4 ++- python/tvm/tir/data_layout.py | 8 ++++-- src/tir/ir/data_layout.cc | 15 ++++++----- tests/python/unittest/test_tir_data_layout.py | 27 ++++++++++++++++++- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 81c3e98e663d..7aefef6e485b 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -137,8 +137,10 @@ class Layout : public ObjectRef { * the corresponding lower case with factor size * indicates the split dimension. * return undefined layout if "__undef__" is passed. + * \param dtype The dtype of generated axes vars in the returned layout. + * It is required to be integer type. */ - TVM_DLL Layout(const std::string& name); // NOLINT(*) + TVM_DLL Layout(const std::string& name, DataType dtype = DataType::Int(32)); // NOLINT(*) /*! * \brief access the internal node container diff --git a/python/tvm/tir/data_layout.py b/python/tvm/tir/data_layout.py index f46a154612e1..71cc404ee23b 100644 --- a/python/tvm/tir/data_layout.py +++ b/python/tvm/tir/data_layout.py @@ -163,7 +163,7 @@ def backward_shape(self, shape): return _ffi_api.BijectiveLayoutBackwardShape(self, shape) # type: ignore -def layout(layout_str: str) -> Layout: +def layout(layout_str: str, dtype: str = "int32") -> Layout: """Create a layout node from a string. Parameters @@ -177,12 +177,16 @@ def layout(layout_str: str) -> Layout: Here subordinate axis channel_block=16 is the factor size of the primal axis C (channel). + dtype : str + The dtype of generated axes vars in the returned layout. + It is required to be integer type. + Returns ------- layout : Layout The created layout """ - return _ffi_api.Layout(layout_str) # type: ignore + return _ffi_api.Layout(layout_str, dtype) # type: ignore def bijective_layout( diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 3b22ffc71173..3bcb6e8d53fc 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -90,7 +90,8 @@ Layout::Layout(const Array& axes) { data_ = std::move(node); } -Layout::Layout(const std::string& name) { // NOLINT(*) +Layout::Layout(const std::string& name, DataType dtype) { // NOLINT(*) + CHECK(dtype.is_int()) << "TypeError: The input dtype should be integer type"; if (name == "__undef__") return; auto node = make_object(); @@ -106,14 +107,14 @@ Layout::Layout(const std::string& name) { // NOLINT(*) << " before dimension " << c; std::string shape_name("_shape"); shape_name.insert(0, 1, c); - IterVar axis = - IterVar(Range(PrimExpr(0), Var(shape_name)), Var(std::string(1, c)), tir::kDataPar); + IterVar axis(Range(IntImm(dtype, 0), Var(shape_name, dtype)), Var(std::string(1, c), dtype), + tir::kDataPar); node->axes.push_back(axis); } else if (c >= 'a' && c <= 'z') { ICHECK_GT(factor, 0) << "Invalid layout " << name << ": invalid factor size " << factor << " for dimension " << c; - IterVar axis = - IterVar(Range(PrimExpr(0), PrimExpr(factor)), Var(std::string(1, c)), tir::kDataPar); + IterVar axis(Range(IntImm(dtype, 0), IntImm(dtype, factor)), Var(std::string(1, c), dtype), + tir::kDataPar); node->axes.push_back(axis); factor = 0; } else if (c >= '0' && c <= '9') { @@ -426,7 +427,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); -TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name) { return Layout(name); }); +TVM_REGISTER_GLOBAL("tir.Layout").set_body_typed([](std::string name, DataType dtype) { + return Layout(name, dtype); +}); TVM_REGISTER_GLOBAL("tir.LayoutIndexOf").set_body_typed([](Layout layout, std::string axis) -> int { return layout.IndexOf(LayoutAxis::Get(axis)); diff --git a/tests/python/unittest/test_tir_data_layout.py b/tests/python/unittest/test_tir_data_layout.py index 5c2eb8febd9b..a76cb50da3bd 100644 --- a/tests/python/unittest/test_tir_data_layout.py +++ b/tests/python/unittest/test_tir_data_layout.py @@ -16,8 +16,9 @@ # under the License. """Test layout and bijective-layout node""" +import pytest import tvm -from tvm import te +import tvm.error from tvm.topi.utils import get_const_tuple @@ -52,6 +53,29 @@ def test_layout(): assert layout[-1] == "c" +def test_layout_dtype(): + layout_i32 = tvm.tir.layout("NCHW") + assert layout_i32.axes[0].var.dtype == "int32" + assert layout_i32.axes[0].dom.min.dtype == "int32" + assert layout_i32.axes[0].dom.extent.dtype == "int32" + assert layout_i32.axes[1].var.dtype == "int32" + assert layout_i32.axes[1].dom.min.dtype == "int32" + assert layout_i32.axes[1].dom.extent.dtype == "int32" + + layout_i64 = tvm.tir.layout("NCHW", dtype="int64") + assert layout_i64.axes[2].var.dtype == "int64" + assert layout_i64.axes[2].dom.min.dtype == "int64" + assert layout_i64.axes[2].dom.extent.dtype == "int64" + assert layout_i64.axes[3].var.dtype == "int64" + assert layout_i64.axes[3].dom.min.dtype == "int64" + assert layout_i64.axes[3].dom.extent.dtype == "int64" + + with pytest.raises(TypeError): + tvm.tir.layout("NCHW", dtype="float32") + with pytest.raises(TypeError): + tvm.tir.layout("NCHW", dtype=None) + + def test_bilayout_convertible(): # not convertible assert tvm.tir.bijective_layout("NCHW", "ABCD") is None @@ -88,6 +112,7 @@ def test_bilayout_index(): if __name__ == "__main__": test_layout() + test_layout_dtype() test_bilayout_convertible() test_bilayout_shape() test_bilayout_index()