diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index bfbaa7cddd4f..78c09e81b16f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -100,16 +100,7 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "PrimExpr"; static constexpr const uint32_t _type_child_slots = 38; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 4cd357d4180b..0a5bac182fd9 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -328,16 +328,7 @@ class IRModuleNode : public Object { */ TVM_DLL std::unordered_set Imports() const; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index e3f59fcc14a1..2a2d0bf3fb05 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -24,6 +24,7 @@ #define TVM_NODE_REPR_PRINTER_H_ #include +#include #include #include diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h new file mode 100644 index 000000000000..af50aae71a43 --- /dev/null +++ b/include/tvm/node/script_printer.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/repr_printer.h + * \brief Printer class to print repr string of each AST/IR nodes. + */ +#ifndef TVM_NODE_SCRIPT_PRINTER_H_ +#define TVM_NODE_SCRIPT_PRINTER_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { + +class PrinterConfigNode : public Object { + public: + /*! \brief The prefix of IR nodes */ + std::string ir_prefix = "I"; + /*! \brief The prefix of TIR nodes */ + std::string tir_prefix = "T"; + /*! \brief The prefix of Relax nodes */ + std::string relax_prefix = "R"; + /*! \brief Default data type of TIR buffer */ + DataType buffer_dtype = DataType::Float(32); + /*! \brief Default data type of integer literals */ + DataType int_dtype = DataType::Int(32); + /*! + * \brief Default data type of float literals. Right now we always print out the explicit type + * of floating point values, so setting it to Void means we do not print without the + * T.float32/T.float64 wrapper. + */ + DataType float_dtype = DataType::Void(); + /*! \brief Whether or not to verbose print expressions. */ + bool verbose_expr = false; + /* \brief Number of spaces used for indentation*/ + int indent_spaces = 4; + /* \brief Whether to print line numbers */ + bool print_line_numbers = false; + /* \brief Number of context lines to print around the underlined text */ + int num_context_lines = -1; + /* \brief Object path to be underlined */ + Optional path_to_underline = NullOpt; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("ir_prefix", &ir_prefix); + v->Visit("buffer_dtype", &buffer_dtype); + v->Visit("int_dtype", &int_dtype); + v->Visit("float_dtype", &float_dtype); + v->Visit("verbose_expr", &verbose_expr); + v->Visit("indent_spaces", &indent_spaces); + v->Visit("print_line_numbers", &print_line_numbers); + v->Visit("num_context_lines", &num_context_lines); + v->Visit("path_to_underline", &path_to_underline); + } + + static constexpr const char* _type_key = "node.PrinterConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); +}; + +class PrinterConfig : public ObjectRef { + public: + explicit PrinterConfig(Map config_dict = Map()); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, + PrinterConfigNode); +}; + +/*! \brief Legacy behavior of ReprPrinter. */ +class TVMScriptPrinter { + public: + /* Convert the object to TVMScript format */ + static std::string Script(const ObjectRef& node, const Optional& cfg); + // Allow registration to be printer. + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); +}; + +#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ + std::string Script(const Optional& config = NullOpt) const { \ + return TVMScriptPrinter::Script(GetRef(this), config.value_or(PrinterConfig())); \ + } + +} // namespace tvm +#endif // TVM_NODE_SCRIPT_PRINTER_H_ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 01f0fc1f4a91..6504e2c2843d 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -29,8 +29,16 @@ namespace tvm { namespace script { namespace printer { +// Forward declaration class Doc; +/*! + * \brief Convert Doc into Python script. + * \param doc Doc to be converted + * \param cfg The configuration of the printer + */ +String DocToPythonScript(Doc doc, const PrinterConfig& cfg); + /*! * \brief The base class of all Doc. * diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e0419b469505..67fa96ef8082 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -126,6 +126,8 @@ class IRDocsifierNode : public Object { /*! \brief The name of the variable */ Optional name; }; + /*! \brief The configuration of the printer */ + PrinterConfig cfg{nullptr}; /*! * \brief The stack of frames. * \sa FrameNode @@ -232,7 +234,7 @@ class IRDocsifier : public ObjectRef { public: using FType = IRDocsifierFunctor; /*! \brief Create a IRDocsifier. */ - IRDocsifier(); + explicit IRDocsifier(const PrinterConfig& cfg); /*! \brief The registration table for IRDocsifier. */ TVM_DLL static FType& vtable(); diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h deleted file mode 100644 index b373a2be73fb..000000000000 --- a/include/tvm/script/printer/printer.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_PRINTER_H_ - -#include -#include - -#include -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! \brief Default values in the TVMScript printer */ -struct Default { - /*! \brief The prefix of IR nodes */ - std::unordered_map ir_prefix = {{"ir", "I"}, {"tir", "T"}}; - /*! \brief Default data type of TIR buffer */ - DataType buffer_dtype = DataType::Float(32); - /*! \brief Default data type of integer literals */ - DataType int_dtype = DataType::Int(32); - /*! - * \brief Default data type of float literals. Right now we always print out the explicit type - * of floating point values, so setting it to Void means we do not print without the - * T.float32/T.float64 wrapper. - */ - DataType float_dtype = DataType::Void(); - /*! \brief Whether or not to verbose print expressions. */ - bool verbose_expr = false; - /*! \brief Returns a singleton of the configuration */ - static Default* Instance(); - static std::string& Prefix(const std::string& ir) { return Instance()->ir_prefix.at(ir); } - static DataType& BufferDType() { return Instance()->buffer_dtype; } - static DataType& IntDType() { return Instance()->int_dtype; } - static DataType& FloatDType() { return Instance()->float_dtype; } - static bool& VerboseExpr() { return Instance()->verbose_expr; } -}; - -/*! - * \brief Convert Doc into Python script. - * \param doc Doc to be converted - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ -String DocToPythonScript(Doc doc, // - int indent_spaces = 4, // - bool print_line_numbers = false, // - int num_context_lines = -1, // - Optional path_to_underline = NullOpt); - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_PRINTER_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 17e7de930260..e135c261990b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -132,16 +132,7 @@ class PrimFuncNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.PrimFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index e0b7bcc868b3..7a7ad2acedd7 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -46,16 +46,7 @@ class StmtNode : public Object { StmtNode() = default; explicit StmtNode(Span span) : span(span) {} - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 52af8407b7a0..3c3fefb6d6c6 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -17,7 +17,7 @@ """Common expressions data structures in the IR.""" import tvm._ffi -from ..runtime import const, convert +from ..runtime import Scriptable, const, convert from . import _ffi_api from .base import Node @@ -121,7 +121,7 @@ def astext(self, show_meta_data=True, annotate=None): @tvm._ffi.register_object -class Range(Node): +class Range(Node, Scriptable): """Represent a range in TVM. You do not need to create a Range explicitly. diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 51410049ec74..3daffb2640c5 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" -from typing import Optional - import tvm._ffi from tvm._ffi.base import string_types +from tvm.runtime import Scriptable from . import _ffi_api from . import expr as _expr @@ -27,7 +26,7 @@ @tvm._ffi.register_object("IRModule") -class IRModule(Node): +class IRModule(Node, Scriptable): """IRModule that holds functions and type definitions. IRModule is the basic unit for all IR transformations across the stack. @@ -314,78 +313,3 @@ def astext(self, show_meta_data=True, annotate=None): from tvm.relay import astext # pylint: disable=import-outside-toplevel return astext(self, show_meta_data, annotate) - - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.Module_Script( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index ea06aeda2030..c83cef3f6cea 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -19,12 +19,13 @@ import tvm import tvm._ffi +from tvm.runtime import Scriptable from . import _ffi_api from .base import Node -class Type(Node): +class Type(Node, Scriptable): """The base class of all types.""" def __eq__(self, other): diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 502de7372154..71f71e6c8427 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,6 +20,7 @@ from .packed_func import PackedFunc from .object import Object from .object_path import ObjectPath, ObjectPathPair +from .script_printer import Scriptable from .object_generic import ObjectGeneric, ObjectTypes from .ndarray import NDArray, DataType, DataTypeCode, Device from .module import Module, num_threads diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py new file mode 100644 index 000000000000..23144c47f1ee --- /dev/null +++ b/python/tvm/runtime/script_printer.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configuration of TVMScript printer""" +from typing import Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from . import _ffi_node_api +from .object_path import ObjectPath + + +@register_object("node.PrinterConfig") +class PrinterConfig(Object): + """Configuration of TVMScript printer""" + + ir_prefix: str + tir_prefix: str + relax_prefix: str + buffer_dtype: str + int_dtype: str + float_dtype: str + verbose_expr: bool + indent_spaces: int + print_line_numbers: bool + num_context_lines: int + path_to_underline: Optional[ObjectPath] + + def __init__( + self, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline: Optional[ObjectPath] = None, + ) -> None: + if num_context_lines is None: + num_context_lines = -1 + self.__init_handle_by_constructor__( + _ffi_node_api.PrinterConfig, # type: ignore # pylint: disable=no-member + { + "ir_prefix": ir_prefix, + "tir_prefix": tir_prefix, + "relax_prefix": relax_prefix, + "buffer_dtype": buffer_dtype, + "int_dtype": int_dtype, + "float_dtype": float_dtype, + "verbose_expr": verbose_expr, + "indent_spaces": indent_spaces, + "print_line_numbers": print_line_numbers, + "num_context_lines": num_context_lines, + "path_to_underline": path_to_underline, + }, + ) + + +def _script(obj: Object, config: PrinterConfig) -> str: + return _ffi_node_api.TVMScriptPrinterScript(obj, config) # type: ignore # pylint: disable=no-member + + +class Scriptable: + """A base class that enables the script() and show() method.""" + + def script( + self, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, + ) -> str: + """Print TVM IR into TVMScript text format + + Parameters + ---------- + ir_prefix : str = "I" + The prefix of AST nodes from tvm.ir + tir_prefix : str = "T" + The prefix of AST nodes from tvm.tir + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + buffer_dtype : str = "float32" + The default data type of buffer + int_dtype : str = "int32" + The default data type of integer + float_dtype : str = "void" + The default data type of float + verbose_expr : bool = False + Whether to print the detailed definition of each variable in the expression + indent_spaces : int = 4 + The number of spaces for indentation + print_line_numbers : bool = False + Whether to print line numbers + num_context_lines : int = -1 + The number of lines of context to print before and after the line to underline. + path_to_underline : Optional[ObjectPath] = None + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the given TVM IR + """ + return _script( + self, + PrinterConfig( + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + ) + + def show( + self, + style: Optional[str] = None, + black_format: bool = True, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + ir_prefix : str = "I" + The prefix of AST nodes from tvm.ir + tir_prefix : str = "T" + The prefix of AST nodes from tvm.tir + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + buffer_dtype : str = "float32" + The default data type of buffer + int_dtype : str = "int32" + The default data type of integer + float_dtype : str = "void" + The default data type of float + verbose_expr : bool = False + Whether to print the detailed definition of each variable in the expression + indent_spaces : int = 4 + The number of spaces for indentation + print_line_numbers : bool = False + Whether to print line numbers + num_context_lines : int = -1 + The number of lines of context to print before and after the line to underline. + path_to_underline : Optional[ObjectPath] = None + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index 01d89dacbf52..8d2f73bb2b8d 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -19,4 +19,3 @@ This package provides a set of APIs to print supported TVM IR into TVMScript in a roundtrippable way. """ -from . import default diff --git a/python/tvm/script/printer/default.py b/python/tvm/script/printer/default.py deleted file mode 100644 index 33ca693ebf32..000000000000 --- a/python/tvm/script/printer/default.py +++ /dev/null @@ -1,83 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The printer configuration""" -from typing_extensions import Literal - -from . import _ffi_api - - -def ir_prefix( # pylint: disable=invalid-name - ir: Literal["ir", "tir"], - prefix: str, -) -> None: - """Set the prefix for the IR. If not set, the prefix for "tvm.ir" is "I", and for "tir" is "T. - - Parameters - ---------- - ir : str - The IR type, either "ir" or "tir". - - prefix : str - The prefix to use. - """ - _ffi_api.DefaultIRPrefix(ir, prefix) # type: ignore # pylint: disable=no-member - - -def buffer_dtype(dtype: str) -> None: - """Set the default dtype for buffer. If not set, it is "float32". - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member - - -def int_dtype(dtype: str) -> None: - """Set the default dtype for integers. If not set, it is "int32". - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member - - -def float_dtype(dtype: str) -> None: - """Set the default dtype for buffer. If not set, there is no default, - which means every floating point numbers will be wrapped with its precise dtype. - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultFloatDtype(dtype) # type: ignore # pylint: disable=no-member - - -def verbose_expr(verbose: bool) -> None: - """Whether or not to verbose print expressions. If not, the definition of every variable in an - expression will be printed as separate statements. Otherwise, the result will be a one-liner. - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.VerboseExpr(verbose) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index 1791f46b00a2..137b71a77d9f 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -17,7 +17,10 @@ """Functions to print doc into text format""" from typing import Optional -from tvm.runtime.object_path import ObjectPath + +from tvm.runtime import ObjectPath +from tvm.runtime.script_printer import PrinterConfig + from . import _ffi_api from .doc import Doc @@ -49,8 +52,10 @@ def to_python_script( script : str The text representation of Doc in Python syntax """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.DocToPythonScript( # type: ignore - doc, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + cfg = PrinterConfig( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, ) + return _ffi_api.DocToPythonScript(doc, cfg) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index c2c158c77f78..11db28e20a1c 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -20,13 +20,13 @@ import tvm._ffi from tvm._ffi.base import string_types from tvm.ir import PointerType, PrimExpr, PrimType, Range -from tvm.runtime import Object, convert +from tvm.runtime import Object, Scriptable, convert from . import _ffi_api @tvm._ffi.register_object("tir.Buffer") -class Buffer(Object): +class Buffer(Object, Scriptable): """Symbolic data buffer in TVM. Buffer provide a way to represent data layout @@ -179,7 +179,11 @@ def offset_of(self, indices): def __getitem__(self, indices): from ..arith import Analyzer # pylint: disable=import-outside-toplevel - from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel + from .expr import ( # pylint: disable=import-outside-toplevel + BufferLoad, + Ramp, + const, + ) from .stmt import BufferRegion # pylint: disable=import-outside-toplevel if not isinstance(indices, (tuple, list)): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index dab7a175185d..cb4a892ac289 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -34,7 +34,7 @@ from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const +from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, Scriptable, const from . import _ffi_api from . import generic as _generic @@ -318,88 +318,13 @@ def asobject(self): return IntImm("int32", self.value, self.span) # type: ignore -class PrimExprWithOp(ExprOp, PrimExpr): +class PrimExprWithOp(ExprOp, PrimExpr, Scriptable): """Helper base class to inherit from PrimExpr.""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = PrimExpr.__hash__ - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.PrimExprScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - class ConstExpr(PrimExprWithOp): pass @@ -460,7 +385,7 @@ def __init__(self, name, dtype, span=None): @tvm._ffi.register_object("tir.IterVar") -class IterVar(Object, ExprOp): +class IterVar(Object, ExprOp, Scriptable): """Represent iteration variable. IterVar represents axis iterations in the computation. @@ -521,7 +446,7 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): @tvm._ffi.register_object("tir.CommReducer") -class CommReducer(Object): +class CommReducer(Object, Scriptable): """Commutative reduce operator Parameters diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index fb5a37c5dc17..f854e56ad11a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -24,7 +24,7 @@ import tvm._ffi import tvm.runtime from tvm.ir import BaseFunc, Range -from tvm.runtime import Object +from tvm.runtime import Object, Scriptable from ..runtime.ndarray import NDArray from . import _ffi_api @@ -33,7 +33,7 @@ @tvm._ffi.register_object("tir.PrimFunc") -class PrimFunc(BaseFunc): +class PrimFunc(BaseFunc, Scriptable): """A function declaration expression. Parameters @@ -170,81 +170,6 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: """ return _ffi_api.Specialize(self, param_map) # type: ignore - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.PrimFuncScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - @tvm._ffi.register_object("tir.TensorIntrin") class TensorIntrin(Object): diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 096c13653a94..d6cd06a1d915 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -31,91 +31,16 @@ import tvm._ffi from tvm.ir import PrimExpr, Range, Span -from tvm.runtime import Object, const +from tvm.runtime import Object, Scriptable, const from . import _ffi_api from .buffer import Buffer from .expr import IterVar -class Stmt(Object): +class Stmt(Object, Scriptable): """Base class of all the statements.""" - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.StmtScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - @tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): @@ -623,7 +548,7 @@ def __init__(self, buffer, bounds, span=None): @tvm._ffi.register_object("tir.BufferRegion") -class BufferRegion(Object): +class BufferRegion(Object, Scriptable): """BufferRegion node. Parameters @@ -643,7 +568,7 @@ def __init__(self, buffer: Buffer, region: List[Range]): @tvm._ffi.register_object("tir.MatchBufferRegion") -class MatchBufferRegion(Object): +class MatchBufferRegion(Object, Scriptable): """MatchBufferRegion node. Parameters diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc new file mode 100644 index 000000000000..605d5208462f --- /dev/null +++ b/src/node/script_printer.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace tvm { + +TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { + static FType inst; + return inst; +} + +std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { + return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); +} + +PrinterConfig::PrinterConfig(Map config_dict) { + runtime::ObjectPtr n = make_object(); + if (auto v = config_dict.Get("ir_prefix")) { + n->ir_prefix = Downcast(v); + } + if (auto v = config_dict.Get("tir_prefix")) { + n->tir_prefix = Downcast(v); + } + if (auto v = config_dict.Get("relax_prefix")) { + n->relax_prefix = Downcast(v); + } + if (auto v = config_dict.Get("buffer_dtype")) { + n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("int_dtype")) { + n->int_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("float_dtype")) { + n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("verbose_expr")) { + n->verbose_expr = Downcast(v)->value; + } + if (auto v = config_dict.Get("indent_spaces")) { + n->indent_spaces = Downcast(v)->value; + } + if (auto v = config_dict.Get("print_line_numbers")) { + n->print_line_numbers = Downcast(v)->value; + } + if (auto v = config_dict.Get("num_context_lines")) { + n->num_context_lines = Downcast(v)->value; + } + if (auto v = config_dict.Get("path_to_underline")) { + n->path_to_underline = Downcast(v); + } + this->data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrinterConfigNode); +TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { + return PrinterConfig(config_dict); +}); +TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); + +} // namespace tvm diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 38b8ef897740..a3a5c06ede0d 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -77,13 +77,13 @@ ByteSpan PopNextUnderline(UnderlineIter* next_underline, UnderlineIter end_under void PrintChunk(const std::pair& lines_range, const std::pair& underlines, const std::string& text, - const std::vector& line_starts, const DocPrinterOptions& options, + const std::vector& line_starts, const PrinterConfig& options, size_t line_number_width, std::string* out) { UnderlineIter next_underline = underlines.first; ByteSpan current_underline = PopNextUnderline(&next_underline, underlines.second); for (size_t line_idx = lines_range.first; line_idx < lines_range.second; ++line_idx) { - if (options.print_line_numbers) { + if (options->print_line_numbers) { std::string line_num_str = std::to_string(line_idx + 1); line_num_str.push_back(' '); for (size_t i = line_num_str.size(); i < line_number_width; ++i) { @@ -148,12 +148,12 @@ void PrintCut(size_t num_lines_skipped, std::string* out) { std::pair GetLinesForUnderline(const ByteSpan& underline, const std::vector& line_starts, - size_t num_lines, const DocPrinterOptions& options) { + size_t num_lines, const PrinterConfig& options) { size_t first_line_of_underline = GetLineIndex(underline.first, line_starts); - size_t first_line_of_chunk = MoveBack(first_line_of_underline, options.num_context_lines); + size_t first_line_of_chunk = MoveBack(first_line_of_underline, options->num_context_lines); size_t end_line_of_underline = GetLineIndex(underline.second - 1, line_starts) + 1; size_t end_line_of_chunk = - MoveForward(end_line_of_underline, options.num_context_lines, num_lines); + MoveForward(end_line_of_underline, options->num_context_lines, num_lines); return {first_line_of_chunk, end_line_of_chunk}; } @@ -181,8 +181,8 @@ size_t GetNumLines(const std::string& text, const std::vector& line_star } } -size_t GetLineNumberWidth(size_t num_lines, const DocPrinterOptions& options) { - if (options.print_line_numbers) { +size_t GetLineNumberWidth(size_t num_lines, const PrinterConfig& options) { + if (options->print_line_numbers) { return std::to_string(num_lines).size() + 1; } else { return 0; @@ -190,8 +190,7 @@ size_t GetLineNumberWidth(size_t num_lines, const DocPrinterOptions& options) { } std::string DecorateText(const std::string& text, const std::vector& line_starts, - const DocPrinterOptions& options, - const std::vector& underlines) { + const PrinterConfig& options, const std::vector& underlines) { size_t num_lines = GetNumLines(text, line_starts); size_t line_number_width = GetLineNumberWidth(num_lines, options); @@ -237,7 +236,7 @@ std::string DecorateText(const std::string& text, const std::vector& lin } // anonymous namespace -DocPrinter::DocPrinter(const DocPrinterOptions& options) : options_(options) { +DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { line_starts_.push_back(0); } diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index db1d733d96ad..7851ce061b0d 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -35,23 +35,6 @@ namespace printer { /*! \brief Range of byte offsets in a string */ using ByteSpan = std::pair; -/*! \brief Options to customize DocPrinter's output */ -struct DocPrinterOptions { - /*! \brief Number of spaces for one level of indentation */ - int indent_spaces = 4; - - /*! \brief Whether to print the line numbers */ - bool print_line_numbers = false; - - /*! - * \brief Number of context lines to print around the underlined text. - * - * If set to a non-default value `n`, only print `n` context lines before and after - * the underlined pieces of text. - */ - size_t num_context_lines = std::numeric_limits::max(); -}; - /*! * \brief DocPrinter is responsible for printing Doc tree into text format * \details This is the base class for translating Doc into string. @@ -67,7 +50,8 @@ class DocPrinter { * * \param options the option for printer */ - explicit DocPrinter(const DocPrinterOptions& options); + explicit DocPrinter(const PrinterConfig& options); + virtual ~DocPrinter() = default; /*! @@ -224,13 +208,13 @@ class DocPrinter { * \brief Increase the indent level of any content to be * printed after this call */ - void IncreaseIndent() { indent_ += options_.indent_spaces; } + void IncreaseIndent() { indent_ += options_->indent_spaces; } /*! * \brief Decrease the indent level of any content to be * printed after this call */ - void DecreaseIndent() { indent_ -= options_.indent_spaces; } + void DecreaseIndent() { indent_ -= options_->indent_spaces; } /*! * \brief Add a new line into the output stream @@ -258,7 +242,7 @@ class DocPrinter { void MarkSpan(const ByteSpan& span, const ObjectPath& path); /*! \brief Options to customize certain aspects of the output */ - DocPrinterOptions options_; + PrinterConfig options_; /*! \brief the current level of indent */ int indent_ = 0; diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 8634236df5c3..ce6b8e7f423c 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -142,7 +142,7 @@ ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { class PythonDocPrinter : public DocPrinter { public: - explicit PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} + explicit PythonDocPrinter(const PrinterConfig& options) : DocPrinter(options) {} protected: using DocPrinter::PrintDoc; @@ -642,17 +642,12 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { NewLineWithoutIndent(); } -String DocToPythonScript(Doc doc, int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) { - DocPrinterOptions options; - options.indent_spaces = indent_spaces; - options.print_line_numbers = print_line_numbers; - if (num_context_lines >= 0) { - options.num_context_lines = num_context_lines; +String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { + if (cfg->num_context_lines < 0) { + cfg->num_context_lines = std::numeric_limits::max(); } - - PythonDocPrinter printer(options); - printer.Append(doc, path_to_underline); + PythonDocPrinter printer(cfg); + printer.Append(doc, cfg->path_to_underline); std::string result = printer.GetString(); int last_space = result.size(); while (last_space > 0 && std::isspace(result[last_space - 1])) { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index e438919f4b1b..4a246e169276 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) BaseFunc func = kv.second; (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); } - return ClassDoc(IdDoc("Module"), {IR("ir_module")}, (*f)->stmts); + return ClassDoc(IdDoc("Module"), {IR(d, "ir_module")}, (*f)->stmts); } }); @@ -63,43 +63,44 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { - return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); + return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { - return IR("Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); + return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc { - return IR("TypeVar")->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // - LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); + return IR(d, "TypeVar") + ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // + LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc { - return IR("GlobalTypeVar") + return IR(d, "GlobalTypeVar") ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc { - return IR("RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); + return IR(d, "RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { - return IR("TensorType") + return IR(d, "TensorType") ->Call({d->AsDoc(type->shape, p->Attr("shape")), LiteralDoc::DataType(type->dtype, p->Attr("dtype"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { - return IR("FuncType") + return IR(d, "FuncType") ->Call({ d->AsDoc(func_type->type_params, p->Attr("type_params")), d->AsDoc(func_type->arg_types, p->Attr("arg_types")), @@ -109,19 +110,17 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { - return IR("IncompleteType")->Call({}); + return IR(d, "IncompleteType")->Call({}); }); -void ReprPrintIRModule(const ObjectRef& mod, ReprPrinter* p) { +std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) { if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule")) { if (Optional s = (*f)(mod)) { - p->stream << s.value(); - return; + return s.value(); } } - std::string res = - DocToPythonScript(IRDocsifier()->AsDoc(Downcast(mod), ObjectPath::Root())); - p->stream << res; + Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root()); + return DocToPythonScript(doc, cfg); } TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR); diff --git a/src/script/printer/ir/script_method.cc b/src/script/printer/ir/script_method.cc deleted file mode 100644 index 01d3ede7ea6c..000000000000 --- a/src/script/printer/ir/script_method.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "./utils.h" - -namespace tvm { - -std::string IRModuleNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), - indent_spaces, print_line_numbers, num_context_lines, path_to_underline); -} - -TVM_REGISTER_GLOBAL("ir.Module_Script").set_body_method(&IRModuleNode::Script); - -} // namespace tvm diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index 820fe13df3c6..d20756e6081a 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -23,9 +23,9 @@ #include #include #include -#include #include +#include #include #include "../utils.h" @@ -35,7 +35,9 @@ namespace script { namespace printer { /*! \brief Creates the IR common prefix, which is by default `I` */ -inline ExprDoc IR(const String& attr) { return IdDoc(Default::Prefix("ir"))->Attr(attr); } +inline ExprDoc IR(const IRDocsifier& d, const String& attr) { + return IdDoc(d->cfg->ir_prefix)->Attr(attr); +} class IRFrameNode : public FrameNode { public: @@ -57,15 +59,14 @@ class IRFrame : public Frame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); }; -inline void ReprPrintIR(const ObjectRef& obj, ReprPrinter* p) { - IRDocsifier d; +/*! \brief Redirected method for the ReprPrinter */ +inline std::string ReprPrintIR(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); With f(d); (*f)->AddDispatchToken(d, "ir"); - try { - p->stream << DocToPythonScript(Docsify(obj, d, *f)); - } catch (const Error& e) { - HandleUnsupportedFallback(e, obj, p); - } + std::ostringstream oss; + oss << Docsify(obj, d, *f, cfg); + return oss.str(); } } // namespace printer diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 4c52ce890c9d..5a0d2bd6bbe0 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -144,8 +144,9 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, this->common_prefix = std::move(visitor.common_prefix); } -IRDocsifier::IRDocsifier() { +IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { auto n = make_object(); + n->cfg = cfg; n->dispatch_tokens.push_back(""); data_ = std::move(n); } diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc deleted file mode 100644 index 878b380a3717..000000000000 --- a/src/script/printer/printer.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -Default* Default::Instance() { - static Default inst; - return &inst; -} - -TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix") - .set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; }); -TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType") - .set_body_typed([](runtime::DataType dtype) { Default::BufferDType() = dtype; }); -TVM_REGISTER_GLOBAL("script.printer.DefaultIntDType").set_body_typed([](runtime::DataType dtype) { - Default::IntDType() = dtype; -}); -TVM_REGISTER_GLOBAL("script.printer.DefaultFloatDType").set_body_typed([](runtime::DataType dtype) { - Default::FloatDType() = dtype; -}); -TVM_REGISTER_GLOBAL("script.printer.VerboseExpr").set_body_typed([](bool verbose_expr) { - Default::VerboseExpr() = verbose_expr; -}); - -} // namespace printer -} // namespace script -} // namespace tvm diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index f78e7037c3e0..a5b8d6609622 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -68,7 +68,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // auto print_single_iter_var = [&](int i) { tir::IterVar iter_var = block->iter_vars[i]; ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); - ExprDoc rhs = TIR("axis"); + ExprDoc rhs = TIR(d, "axis"); if (iter_var->iter_type == tir::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) { @@ -128,7 +128,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // binding_paths.push_back(iter_var_p->Attr("iter_type")); binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; } - ExprDoc rhs = TIR("axis")->Attr("remap"); + ExprDoc rhs = TIR(d, "axis")->Attr("remap"); ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt); binding_str->source_paths = std::move(binding_paths); rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)}); @@ -151,8 +151,9 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // if (realize) { ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); if (!tir::is_one(realize->predicate)) { - (*frame)->stmts.push_back(ExprStmtDoc(TIR("where")->Call( - {d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); + (*frame)->stmts.push_back(ExprStmtDoc( + TIR(d, "where") + ->Call({d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); } } // Step 3. Handle block read/write regions @@ -161,17 +162,17 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // for (int i = 0, n = block->reads.size(); i < n; ++i) { reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayIndex(i))); } - (*frame)->stmts.push_back(ExprStmtDoc(TIR("reads")->Call(reads))); + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads))); Array writes; for (int i = 0, n = block->writes.size(); i < n; ++i) { writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayIndex(i))); } - (*frame)->stmts.push_back(ExprStmtDoc(TIR("writes")->Call(writes))); + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes")->Call(writes))); } // Step 4. Handle block attributes if (!block->annotations.empty()) { (*frame)->stmts.push_back(ExprStmtDoc( - TIR("block_attr") + TIR(d, "block_attr") ->Call({d->AsDoc(block->annotations, block_p->Attr("annotations"))}))); } // Step 5. Handle `alloc_buffer` @@ -194,7 +195,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // tir::Stmt init = block->init.value(); With init_frame(d, init); AsDocBody(init, block_p->Attr("init"), init_frame->get(), d); - (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR("init")->Call({}), (*init_frame)->stmts)); + (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init")->Call({}), (*init_frame)->stmts)); } // Step 8. Handle block body AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); @@ -205,7 +206,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); } return ScopeDoc(NullOpt, - TIR("block") // + TIR(d, "block") // ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))}, kwargs_keys, kwargs_values), (*frame)->stmts); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index b947039b58de..b4429dc9afc9 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -55,7 +55,7 @@ Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, // Step 1. Handle `buffer.shape` array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape"); // Step 2. Handle `buffer.dtype` - if (buffer->dtype != Default::BufferDType()) { + if (buffer->dtype != d->cfg->buffer_dtype) { kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); } // Step 3. Handle `buffer.data` @@ -123,7 +123,7 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, const ObjectPath& p, const Frame& frame, const IRDocsifier& d) { - return BufferCall(/*prefix=*/TIR(method), + return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d), /*args=*/args); } @@ -134,7 +134,7 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& ExprDoc shape = attrs.Get("shape").value(); ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); - return TIR("Buffer")->Call({shape, dtype}, {}, {}); + return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } Array BufferIndices(const Array& indices, const ObjectPath& p, @@ -251,7 +251,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(stmt->producer->GetNameHint()); prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)]; - prefix = TIR("ProducerRealize") + prefix = TIR(d, "ProducerRealize") ->Call({prefix, d->AsDoc(stmt->condition, p->Attr("condition"))}); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 6e0cfd420262..ab91764b6a0b 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -34,7 +34,7 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) ExprDoc rhs = d->AsDoc(type, var_p->Attr("type_annotation")); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } else { - ExprDoc rhs = TIR("var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); + ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } } @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc { - return TIR("iter_var") + return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), d->AsDoc(var->dom, var_p->Attr("dom")), @@ -70,7 +70,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { - return TIR("Not")->Call({a}); + return TIR(d, "Not")->Call({a}); } return OperationDoc(OperationDocNode::Kind::kNot, {a}); }); @@ -84,21 +84,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); - return TIR("Cast")->Call({dtype, value}); + return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Select")->Call({ - d->AsDoc(select->condition, p->Attr("condition")), - d->AsDoc(select->true_value, p->Attr("true_value")), - d->AsDoc(select->false_value, p->Attr("false_value")), - }); + return TIR(d, "Select") + ->Call({ + d->AsDoc(select->condition, p->Attr("condition")), + d->AsDoc(select->true_value, p->Attr("true_value")), + d->AsDoc(select->false_value, p->Attr("false_value")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc { - return TIR("Ramp")->Call({ + return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")), @@ -107,7 +108,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc { - return TIR("Broadcast") + return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")), @@ -117,10 +118,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Shuffle")->Call({ - d->AsDoc(shuffle->vectors, p->Attr("vectors")), - d->AsDoc(shuffle->indices, p->Attr("indices")), - }); + return TIR(d, "Shuffle") + ->Call({ + d->AsDoc(shuffle->vectors, p->Attr("vectors")), + d->AsDoc(shuffle->indices, p->Attr("indices")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -152,12 +154,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); - return TIR("comm_reducer")->Call({lambda, id}); + return TIR(d, "comm_reducer")->Call({lambda, id}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("let")->Call({ + return TIR(d, "let")->Call({ d->AsDoc(let->var, p->Attr("var")), d->AsDoc(let->value, p->Attr("value")), d->AsDoc(let->body, p->Attr("body")), @@ -194,7 +196,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (op_names.count(GetRef(op)) == 0) { LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } - prefix = TIR(name); + prefix = TIR(d, name); } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); } else { @@ -217,7 +219,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Any")->Call({}); + return TIR(d, "Any")->Call({}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -228,8 +230,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc axis = d->AsDoc(r->axis, p->Attr("axis")); ExprDoc condition = d->AsDoc(r->condition, p->Attr("condition")); ExprDoc value_index = LiteralDoc::Int(r->value_index, p->Attr("value_index")); - return TIR("reduce")->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, - {source, init, axis, condition, value_index}); + return TIR(d, "reduce") + ->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, + {source, init, axis, condition, value_index}); LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; }); @@ -244,7 +247,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - return TIR(OpString)->Call({a, b}); \ + return TIR(d, OpString)->Call({a, b}); \ }); #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ @@ -254,7 +257,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ if (a->IsInstance() && b->IsInstance()) { \ - return TIR(OpString)->Call({a, b}); \ + return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 2a81c37061c6..7d21de27a1a2 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -59,7 +59,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) loop_p = loop_p->Attr("body"); } AsDocBody(grid.back()->body, loop_p, (*f).get(), d); - return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts); + return ForDoc(TupleDoc(lhs), TIR(d, "grid")->Call(rhs), (*f)->stmts); } // Step 3. If not `T.grid`, print loop kind accordingly ExprDoc lhs = DefineVar(loop->loop_var, *f, d); @@ -81,16 +81,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (loop->annotations.empty()) { prefix = IdDoc("range"); } else { - prefix = TIR("serial"); + prefix = TIR(d, "serial"); } } else if (loop->kind == tir::ForKind::kParallel) { - prefix = TIR("parallel"); + prefix = TIR(d, "parallel"); } else if (loop->kind == tir::ForKind::kUnrolled) { - prefix = TIR("unroll"); + prefix = TIR(d, "unroll"); } else if (loop->kind == tir::ForKind::kVectorized) { - prefix = TIR("vectorized"); + prefix = TIR(d, "vectorized"); } else if (loop->kind == tir::ForKind::kThreadBinding) { - prefix = TIR("thread_binding"); + prefix = TIR(d, "thread_binding"); thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, loop_p->Attr("thread_binding")); } else { diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index ea7d56e1656d..fbcc2fca3b4b 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -111,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 2. Handle `func->attrs` if (func->attrs.defined() && !func->attrs->dict.empty()) { (*frame)->stmts.push_back( - ExprStmtDoc(TIR("func_attr") // + ExprStmtDoc(TIR(d, "func_attr") // ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); } // Step 3. Handle `func->buffer_map` @@ -175,14 +175,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return FunctionDoc( /*name=*/IdDoc(FindFunctionName(d, func)), /*args=*/args, - /*decorators=*/{TIR("prim_func")}, + /*decorators=*/{TIR(d, "prim_func")}, /*return_type=*/ret_type, /*body=*/(*frame)->stmts); }); -void ReprPrintPrimFunc(const ObjectRef& obj, ReprPrinter* p) { - std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); - p->stream << res; +std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) { + Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root()); + return DocToPythonScript(doc, cfg); } TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintPrimFunc); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 1214f822610c..76d3680fec81 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -29,12 +29,12 @@ TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IntImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; - if (dtype == Default::IntDType()) { + if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); } else if (dtype == DataType::Bool()) { return LiteralDoc::Boolean(imm->value, imm_p->Attr("value")); } else { - return TIR(runtime::DLDataType2String(dtype)) // + return TIR(d, runtime::DLDataType2String(dtype)) // ->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))}); } }); @@ -42,26 +42,27 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](FloatImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; - if (dtype == Default::FloatDType()) { + if (dtype == d->cfg->float_dtype) { return LiteralDoc::Float(imm->value, imm_p->Attr("value")); } else { - return TIR(runtime::DLDataType2String(dtype)) // + return TIR(d, runtime::DLDataType2String(dtype)) // ->Call({LiteralDoc::Float(imm->value, imm_p->Attr("value"))}); } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Range")->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent, p->Attr("extent")), - }); + return TIR(d, "Range") + ->Call({ + d->AsDoc(range->min, p->Attr("min")), + d->AsDoc(range->extent, p->Attr("extent")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { std::string dtype = ty->dtype.is_void() ? "void" : runtime::DLDataType2String(ty->dtype); - return TIR(dtype); + return TIR(d, dtype); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -74,9 +75,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); } if (ty->storage_scope == "") { - return TIR("Ptr")->Call({element_type}); + return TIR(d, "Ptr")->Call({element_type}); } else { - return TIR("Ptr")->Call( + return TIR(d, "Ptr")->Call( {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); } }); @@ -86,13 +87,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (ty->fields.empty()) { return LiteralDoc::None(p); } - return TIR("Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); + return TIR(d, "Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { Map config = target->Export(); - return TIR("target")->Call({d->AsDoc(config, p)}); + return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR); diff --git a/src/script/printer/tir/script_method.cc b/src/script/printer/tir/script_method.cc deleted file mode 100644 index 5cda9a9626db..000000000000 --- a/src/script/printer/tir/script_method.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "./utils.h" - -namespace tvm { - -std::string PrimExprNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - IRDocsifier d; - ObjectRef obj = GetRef(this); - With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, - num_context_lines, path_to_underline); -} - -namespace tir { - -std::string StmtNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - IRDocsifier d; - ObjectRef obj = GetRef(this); - With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, - num_context_lines, path_to_underline); -} - -std::string PrimFuncNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), - indent_spaces, print_line_numbers, num_context_lines, path_to_underline); -} - -TVM_REGISTER_GLOBAL("tir.PrimFuncScript").set_body_method(&PrimFuncNode::Script); -TVM_REGISTER_GLOBAL("tir.StmtScript").set_body_method(&StmtNode::Script); -TVM_REGISTER_GLOBAL("tir.PrimExprScript").set_body_method(&PrimExprNode::Script); - -} // namespace tir -} // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index acdfd7da472b..2820f9ba6384 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (eval->value->IsInstance()) { return ExprStmtDoc(value); } - return ExprStmtDoc(TIR("evaluate")->Call({value})); + return ExprStmtDoc(TIR(d, "evaluate")->Call({value})); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); return StmtBlockDoc(*stmts); } else { - rhs = TIR("let")->Call({lhs, rhs}); + rhs = TIR(d, "let")->Call({lhs, rhs}); return ScopeDoc(NullOpt, rhs, *stmts); } }); @@ -93,7 +93,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssertDoc(cond, msg)); return StmtBlockDoc(*stmts); } - return ScopeDoc(NullOpt, TIR("Assert")->Call({cond, msg}), (*f)->stmts); + return ScopeDoc(NullOpt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -145,7 +145,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { - return ExprStmtDoc(TIR("prefetch") + return ExprStmtDoc(TIR(d, "prefetch") ->Call({ d->AsDoc(stmt->buffer, p->Attr("buffer")), d->AsDoc(stmt->bounds, p->Attr("bounds")), @@ -198,7 +198,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d); With f(d, stmt); - ExprDoc rhs = TIR("allocate")->Call(args, kwargs_keys, kwargs_values); + ExprDoc rhs = TIR(d, "allocate")->Call(args, kwargs_keys, kwargs_values); AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); @@ -277,7 +277,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) args.push_back(data_doc); args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); - ExprDoc rhs = TIR("allocate_const")->Call(args, kwargs_keys, kwargs_values); + ExprDoc rhs = TIR(d, "allocate_const")->Call(args, kwargs_keys, kwargs_values); With f(d, stmt); ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); @@ -310,7 +310,7 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalAsDoc(stmt->condition, p->Attr("condition"))); } - return TIR("realize")->Call(args, kwargs_keys, kwargs_values); + return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -351,12 +351,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) DefineVar(iter_var->var, f, d); f->stmts.push_back( AssignDoc(d->AsDoc(iter_var->var, iter_var_p->Attr("var")), - TIR("env_thread") + TIR(d, "env_thread") ->Call({LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"))}), // NullOpt)); } - rhs = TIR("launch_thread") + rhs = TIR(d, "launch_thread") ->Call({ d->AsDoc(iter_var->var, stmt_p->Attr("node")), d->AsDoc(stmt->value, stmt_p->Attr("value")), @@ -364,7 +364,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } if (!rhs.defined()) { - rhs = TIR("attr")->Call({ + rhs = TIR(d, "attr")->Call({ d->AsDoc(stmt->node, stmt_p->Attr("node")), LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")), d->AsDoc(stmt->value, stmt_p->Attr("value")), diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index e1ffe135229e..88094ee816ca 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -20,7 +20,6 @@ #define TVM_SCRIPT_PRINTER_TIR_UTILS_H_ #include -#include #include #include #include @@ -74,7 +73,9 @@ class TIRFrame : public Frame { }; /*! \brief Creates the TIR common prefix, which is by default `T` */ -inline ExprDoc TIR(const String& attr) { return IdDoc(Default::Prefix("tir"))->Attr(attr); } +inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { + return IdDoc(d->cfg->tir_prefix)->Attr(attr); +} /*! * \brief Defines a variable in the IRDocsifier at the given frame, @@ -187,14 +188,12 @@ inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root, } /*! \brief Redirected method for the ReprPrinter */ -inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { - IRDocsifier d; +inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - try { - p->stream << DocToPythonScript(Docsify(obj, d, *f)); - } catch (const tvm::Error& e) { - HandleUnsupportedFallback(e, obj, p); - } + std::ostringstream oss; + oss << Docsify(obj, d, *f, cfg); + return oss.str(); } /*! diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 9f9a7d8299c4..5161f1f9a268 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -20,13 +20,6 @@ #define TVM_SCRIPT_PRINTER_UTILS_H_ #include -#include -#include -#include -#include -#include -#include -#include #include #include @@ -37,13 +30,26 @@ namespace tvm { namespace script { namespace printer { -#define TVM_SCRIPT_REPR(ObjectType, Method) \ - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(Method); +#define TVM_SCRIPT_REPR(ObjectType, Method) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ + .set_dispatch(RedirectedReprPrinterMethod); \ + TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch(Method); -inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f) { +inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { + try { + p->stream << TVMScriptPrinter::Script(obj, NullOpt); + } catch (const tvm::Error& e) { + LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" + << e.what(); + p->stream << AsLegacyRepr(obj); + } +} + +inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f, + const PrinterConfig& cfg) { Doc doc = d->AsDoc(obj, ObjectPath::Root()); if (const auto* expr_doc = doc.as()) { - if (!Default::VerboseExpr()) { + if (!cfg->verbose_expr) { f->stmts.clear(); } f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); @@ -56,14 +62,7 @@ inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fr } else { LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); } - return StmtBlockDoc(f->stmts); -} - -inline void HandleUnsupportedFallback(const tvm::Error& error, const ObjectRef& obj, - ReprPrinter* p) { - LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" - << error.what(); - p->stream << AsLegacyRepr(obj); + return DocToPythonScript(StmtBlockDoc(f->stmts), cfg); } } // namespace printer diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index c73ae291930c..71da86bff763 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -15,31 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from contextlib import contextmanager - +import tvm.testing from tvm import ir, tir from tvm.ir import Range from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tir as T -from tvm.script.printer import default -import tvm.testing - - -@contextmanager -def verbose_expr(): - try: - default.verbose_expr(True) - yield - finally: - default.verbose_expr(False) def _assert_print(obj, expected): - with verbose_expr(): - if isinstance(obj, (tir.PrimFunc, tir.PrimExpr, tir.Stmt)): - assert obj.script().strip() == expected.strip() - assert str(obj).strip() == expected.strip() - assert repr(obj).strip() == expected.strip() + assert obj.script(verbose_expr=True).strip() == expected.strip() def test_prim_func():