From 8b5517890907fa26591358eafda5288d70660548 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 22 Jan 2023 23:15:16 -0800 Subject: [PATCH] [TVMScript] Introduce `PrinterConfig` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces `PrinterConfig`, a systematic way to configure TVMScript printer without having to set global flags. This PR enables more customization of printer behavior. More specifically, now any TVM’s object in python, as long as it inherits from `Scriptable`, it automatically gains two methods: - `.script(tir_prefix=...)` - `.show(...)` --- include/tvm/ir/expr.h | 11 +- include/tvm/ir/module.h | 11 +- include/tvm/node/repr_printer.h | 1 + include/tvm/node/script_printer.h | 105 +++++++++ include/tvm/script/printer/doc.h | 8 + include/tvm/script/printer/ir_docsifier.h | 4 +- include/tvm/script/printer/printer.h | 76 ------ include/tvm/tir/function.h | 11 +- include/tvm/tir/stmt.h | 11 +- python/tvm/ir/expr.py | 4 +- python/tvm/ir/module.py | 80 +------ python/tvm/ir/type.py | 3 +- python/tvm/runtime/__init__.py | 1 + python/tvm/runtime/script_printer.py | 218 ++++++++++++++++++ python/tvm/script/printer/__init__.py | 1 - python/tvm/script/printer/default.py | 83 ------- python/tvm/script/printer/doc_printer.py | 15 +- python/tvm/tir/buffer.py | 10 +- python/tvm/tir/expr.py | 83 +------ python/tvm/tir/function.py | 79 +------ python/tvm/tir/stmt.py | 83 +------ src/node/script_printer.cc | 79 +++++++ .../printer/doc_printer/base_doc_printer.cc | 19 +- .../printer/doc_printer/base_doc_printer.h | 26 +-- .../printer/doc_printer/python_doc_printer.cc | 17 +- src/script/printer/ir/ir.cc | 31 ++- src/script/printer/ir/script_method.cc | 34 --- src/script/printer/ir/utils.h | 19 +- src/script/printer/ir_docsifier.cc | 3 +- src/script/printer/printer.cc | 47 ---- src/script/printer/tir/block.cc | 19 +- src/script/printer/tir/buffer.cc | 8 +- src/script/printer/tir/expr.cc | 49 ++-- src/script/printer/tir/for_loop.cc | 12 +- src/script/printer/tir/function.cc | 10 +- src/script/printer/tir/ir.cc | 27 +-- src/script/printer/tir/script_method.cc | 59 ----- src/script/printer/tir/stmt.cc | 20 +- src/script/printer/tir/utils.h | 17 +- src/script/printer/utils.h | 37 ++- .../unittest/test_tvmscript_printer_tir.py | 20 +- 41 files changed, 602 insertions(+), 849 deletions(-) create mode 100644 include/tvm/node/script_printer.h delete mode 100644 include/tvm/script/printer/printer.h create mode 100644 python/tvm/runtime/script_printer.py delete mode 100644 python/tvm/script/printer/default.py create mode 100644 src/node/script_printer.cc delete mode 100644 src/script/printer/ir/script_method.cc delete mode 100644 src/script/printer/printer.cc delete mode 100644 src/script/printer/tir/script_method.cc diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index bfbaa7cddd4f..78c09e81b16f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -100,16 +100,7 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "PrimExpr"; static constexpr const uint32_t _type_child_slots = 38; diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 4cd357d4180b..0a5bac182fd9 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -328,16 +328,7 @@ class IRModuleNode : public Object { */ TVM_DLL std::unordered_set Imports() const; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - TVM_DLL std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index e3f59fcc14a1..2a2d0bf3fb05 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -24,6 +24,7 @@ #define TVM_NODE_REPR_PRINTER_H_ #include +#include #include #include diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h new file mode 100644 index 000000000000..af50aae71a43 --- /dev/null +++ b/include/tvm/node/script_printer.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/node/repr_printer.h + * \brief Printer class to print repr string of each AST/IR nodes. + */ +#ifndef TVM_NODE_SCRIPT_PRINTER_H_ +#define TVM_NODE_SCRIPT_PRINTER_H_ + +#include +#include +#include +#include + +#include +#include + +namespace tvm { + +class PrinterConfigNode : public Object { + public: + /*! \brief The prefix of IR nodes */ + std::string ir_prefix = "I"; + /*! \brief The prefix of TIR nodes */ + std::string tir_prefix = "T"; + /*! \brief The prefix of Relax nodes */ + std::string relax_prefix = "R"; + /*! \brief Default data type of TIR buffer */ + DataType buffer_dtype = DataType::Float(32); + /*! \brief Default data type of integer literals */ + DataType int_dtype = DataType::Int(32); + /*! + * \brief Default data type of float literals. Right now we always print out the explicit type + * of floating point values, so setting it to Void means we do not print without the + * T.float32/T.float64 wrapper. + */ + DataType float_dtype = DataType::Void(); + /*! \brief Whether or not to verbose print expressions. */ + bool verbose_expr = false; + /* \brief Number of spaces used for indentation*/ + int indent_spaces = 4; + /* \brief Whether to print line numbers */ + bool print_line_numbers = false; + /* \brief Number of context lines to print around the underlined text */ + int num_context_lines = -1; + /* \brief Object path to be underlined */ + Optional path_to_underline = NullOpt; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("ir_prefix", &ir_prefix); + v->Visit("buffer_dtype", &buffer_dtype); + v->Visit("int_dtype", &int_dtype); + v->Visit("float_dtype", &float_dtype); + v->Visit("verbose_expr", &verbose_expr); + v->Visit("indent_spaces", &indent_spaces); + v->Visit("print_line_numbers", &print_line_numbers); + v->Visit("num_context_lines", &num_context_lines); + v->Visit("path_to_underline", &path_to_underline); + } + + static constexpr const char* _type_key = "node.PrinterConfig"; + TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); +}; + +class PrinterConfig : public ObjectRef { + public: + explicit PrinterConfig(Map config_dict = Map()); + + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(PrinterConfig, runtime::ObjectRef, + PrinterConfigNode); +}; + +/*! \brief Legacy behavior of ReprPrinter. */ +class TVMScriptPrinter { + public: + /* Convert the object to TVMScript format */ + static std::string Script(const ObjectRef& node, const Optional& cfg); + // Allow registration to be printer. + using FType = NodeFunctor; + TVM_DLL static FType& vtable(); +}; + +#define TVM_OBJECT_ENABLE_SCRIPT_PRINTER() \ + std::string Script(const Optional& config = NullOpt) const { \ + return TVMScriptPrinter::Script(GetRef(this), config.value_or(PrinterConfig())); \ + } + +} // namespace tvm +#endif // TVM_NODE_SCRIPT_PRINTER_H_ diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 01f0fc1f4a91..6504e2c2843d 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -29,8 +29,16 @@ namespace tvm { namespace script { namespace printer { +// Forward declaration class Doc; +/*! + * \brief Convert Doc into Python script. + * \param doc Doc to be converted + * \param cfg The configuration of the printer + */ +String DocToPythonScript(Doc doc, const PrinterConfig& cfg); + /*! * \brief The base class of all Doc. * diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index e0419b469505..67fa96ef8082 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -126,6 +126,8 @@ class IRDocsifierNode : public Object { /*! \brief The name of the variable */ Optional name; }; + /*! \brief The configuration of the printer */ + PrinterConfig cfg{nullptr}; /*! * \brief The stack of frames. * \sa FrameNode @@ -232,7 +234,7 @@ class IRDocsifier : public ObjectRef { public: using FType = IRDocsifierFunctor; /*! \brief Create a IRDocsifier. */ - IRDocsifier(); + explicit IRDocsifier(const PrinterConfig& cfg); /*! \brief The registration table for IRDocsifier. */ TVM_DLL static FType& vtable(); diff --git a/include/tvm/script/printer/printer.h b/include/tvm/script/printer/printer.h deleted file mode 100644 index b373a2be73fb..000000000000 --- a/include/tvm/script/printer/printer.h +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#ifndef TVM_SCRIPT_PRINTER_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_PRINTER_H_ - -#include -#include - -#include -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -/*! \brief Default values in the TVMScript printer */ -struct Default { - /*! \brief The prefix of IR nodes */ - std::unordered_map ir_prefix = {{"ir", "I"}, {"tir", "T"}}; - /*! \brief Default data type of TIR buffer */ - DataType buffer_dtype = DataType::Float(32); - /*! \brief Default data type of integer literals */ - DataType int_dtype = DataType::Int(32); - /*! - * \brief Default data type of float literals. Right now we always print out the explicit type - * of floating point values, so setting it to Void means we do not print without the - * T.float32/T.float64 wrapper. - */ - DataType float_dtype = DataType::Void(); - /*! \brief Whether or not to verbose print expressions. */ - bool verbose_expr = false; - /*! \brief Returns a singleton of the configuration */ - static Default* Instance(); - static std::string& Prefix(const std::string& ir) { return Instance()->ir_prefix.at(ir); } - static DataType& BufferDType() { return Instance()->buffer_dtype; } - static DataType& IntDType() { return Instance()->int_dtype; } - static DataType& FloatDType() { return Instance()->float_dtype; } - static bool& VerboseExpr() { return Instance()->verbose_expr; } -}; - -/*! - * \brief Convert Doc into Python script. - * \param doc Doc to be converted - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ -String DocToPythonScript(Doc doc, // - int indent_spaces = 4, // - bool print_line_numbers = false, // - int num_context_lines = -1, // - Optional path_to_underline = NullOpt); - -} // namespace printer -} // namespace script -} // namespace tvm - -#endif // TVM_SCRIPT_PRINTER_PRINTER_H_ diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 17e7de930260..e135c261990b 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -132,16 +132,7 @@ class PrimFuncNode : public BaseFuncNode { */ TVM_DLL FuncType func_type_annotation() const; - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.PrimFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncNode, BaseFuncNode); diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index e0b7bcc868b3..7a7ad2acedd7 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -46,16 +46,7 @@ class StmtNode : public Object { StmtNode() = default; explicit StmtNode(Span span) : span(span) {} - /*! - * \brief Returns the TVMScript format - * \param indent_spaces Number of spaces used for indentation - * \param print_line_numbers Whether to print line numbers - * \param num_context_lines Number of context lines to print around the underlined text - * \param path_to_underline Object path to be underlined - */ - std::string Script(int indent_spaces = 4, bool print_line_numbers = false, - int num_context_lines = -1, - Optional path_to_underline = NullOpt) const; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.Stmt"; static constexpr const bool _type_has_method_sequal_reduce = true; diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 52af8407b7a0..3c3fefb6d6c6 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -17,7 +17,7 @@ """Common expressions data structures in the IR.""" import tvm._ffi -from ..runtime import const, convert +from ..runtime import Scriptable, const, convert from . import _ffi_api from .base import Node @@ -121,7 +121,7 @@ def astext(self, show_meta_data=True, annotate=None): @tvm._ffi.register_object -class Range(Node): +class Range(Node, Scriptable): """Represent a range in TVM. You do not need to create a Range explicitly. diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 51410049ec74..3daffb2640c5 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -15,10 +15,9 @@ # specific language governing permissions and limitations # under the License. """IRModule that holds the functions and type definitions.""" -from typing import Optional - import tvm._ffi from tvm._ffi.base import string_types +from tvm.runtime import Scriptable from . import _ffi_api from . import expr as _expr @@ -27,7 +26,7 @@ @tvm._ffi.register_object("IRModule") -class IRModule(Node): +class IRModule(Node, Scriptable): """IRModule that holds functions and type definitions. IRModule is the basic unit for all IR transformations across the stack. @@ -314,78 +313,3 @@ def astext(self, show_meta_data=True, annotate=None): from tvm.relay import astext # pylint: disable=import-outside-toplevel return astext(self, show_meta_data, annotate) - - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.Module_Script( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index ea06aeda2030..c83cef3f6cea 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -19,12 +19,13 @@ import tvm import tvm._ffi +from tvm.runtime import Scriptable from . import _ffi_api from .base import Node -class Type(Node): +class Type(Node, Scriptable): """The base class of all types.""" def __eq__(self, other): diff --git a/python/tvm/runtime/__init__.py b/python/tvm/runtime/__init__.py index 502de7372154..71f71e6c8427 100644 --- a/python/tvm/runtime/__init__.py +++ b/python/tvm/runtime/__init__.py @@ -20,6 +20,7 @@ from .packed_func import PackedFunc from .object import Object from .object_path import ObjectPath, ObjectPathPair +from .script_printer import Scriptable from .object_generic import ObjectGeneric, ObjectTypes from .ndarray import NDArray, DataType, DataTypeCode, Device from .module import Module, num_threads diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py new file mode 100644 index 000000000000..23144c47f1ee --- /dev/null +++ b/python/tvm/runtime/script_printer.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Configuration of TVMScript printer""" +from typing import Optional + +from tvm._ffi import register_object +from tvm.runtime import Object + +from . import _ffi_node_api +from .object_path import ObjectPath + + +@register_object("node.PrinterConfig") +class PrinterConfig(Object): + """Configuration of TVMScript printer""" + + ir_prefix: str + tir_prefix: str + relax_prefix: str + buffer_dtype: str + int_dtype: str + float_dtype: str + verbose_expr: bool + indent_spaces: int + print_line_numbers: bool + num_context_lines: int + path_to_underline: Optional[ObjectPath] + + def __init__( + self, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: Optional[int] = None, + path_to_underline: Optional[ObjectPath] = None, + ) -> None: + if num_context_lines is None: + num_context_lines = -1 + self.__init_handle_by_constructor__( + _ffi_node_api.PrinterConfig, # type: ignore # pylint: disable=no-member + { + "ir_prefix": ir_prefix, + "tir_prefix": tir_prefix, + "relax_prefix": relax_prefix, + "buffer_dtype": buffer_dtype, + "int_dtype": int_dtype, + "float_dtype": float_dtype, + "verbose_expr": verbose_expr, + "indent_spaces": indent_spaces, + "print_line_numbers": print_line_numbers, + "num_context_lines": num_context_lines, + "path_to_underline": path_to_underline, + }, + ) + + +def _script(obj: Object, config: PrinterConfig) -> str: + return _ffi_node_api.TVMScriptPrinterScript(obj, config) # type: ignore # pylint: disable=no-member + + +class Scriptable: + """A base class that enables the script() and show() method.""" + + def script( + self, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, + ) -> str: + """Print TVM IR into TVMScript text format + + Parameters + ---------- + ir_prefix : str = "I" + The prefix of AST nodes from tvm.ir + tir_prefix : str = "T" + The prefix of AST nodes from tvm.tir + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + buffer_dtype : str = "float32" + The default data type of buffer + int_dtype : str = "int32" + The default data type of integer + float_dtype : str = "void" + The default data type of float + verbose_expr : bool = False + Whether to print the detailed definition of each variable in the expression + indent_spaces : int = 4 + The number of spaces for indentation + print_line_numbers : bool = False + Whether to print line numbers + num_context_lines : int = -1 + The number of lines of context to print before and after the line to underline. + path_to_underline : Optional[ObjectPath] = None + Object path to be underlined + + Returns + ------- + script : str + The TVM Script of the given TVM IR + """ + return _script( + self, + PrinterConfig( + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + ) + + def show( + self, + style: Optional[str] = None, + black_format: bool = True, + *, + ir_prefix: str = "I", + tir_prefix: str = "T", + relax_prefix: str = "R", + buffer_dtype: str = "float32", + int_dtype: str = "int32", + float_dtype: str = "void", + verbose_expr: bool = False, + indent_spaces: int = 4, + print_line_numbers: bool = False, + num_context_lines: int = -1, + path_to_underline: Optional[ObjectPath] = None, + ) -> None: + """A sugar for print highlighted TVM script. + + Parameters + ---------- + style : str, optional + Pygmentize printing style, auto-detected if None. See + `tvm.script.highlight.cprint` for more details. + black_format: bool + If true (default), use the formatter Black to format the TVMScript + ir_prefix : str = "I" + The prefix of AST nodes from tvm.ir + tir_prefix : str = "T" + The prefix of AST nodes from tvm.tir + relax_prefix : str = "R" + The prefix of AST nodes from tvm.relax + buffer_dtype : str = "float32" + The default data type of buffer + int_dtype : str = "int32" + The default data type of integer + float_dtype : str = "void" + The default data type of float + verbose_expr : bool = False + Whether to print the detailed definition of each variable in the expression + indent_spaces : int = 4 + The number of spaces for indentation + print_line_numbers : bool = False + Whether to print line numbers + num_context_lines : int = -1 + The number of lines of context to print before and after the line to underline. + path_to_underline : Optional[ObjectPath] = None + Object path to be underlined + """ + from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel + cprint, + ) + + cprint( + self.script( + ir_prefix=ir_prefix, + tir_prefix=tir_prefix, + relax_prefix=relax_prefix, + buffer_dtype=buffer_dtype, + int_dtype=int_dtype, + float_dtype=float_dtype, + verbose_expr=verbose_expr, + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, + ), + style=style, + black_format=black_format, + ) diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index 01d89dacbf52..8d2f73bb2b8d 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -19,4 +19,3 @@ This package provides a set of APIs to print supported TVM IR into TVMScript in a roundtrippable way. """ -from . import default diff --git a/python/tvm/script/printer/default.py b/python/tvm/script/printer/default.py deleted file mode 100644 index 33ca693ebf32..000000000000 --- a/python/tvm/script/printer/default.py +++ /dev/null @@ -1,83 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""The printer configuration""" -from typing_extensions import Literal - -from . import _ffi_api - - -def ir_prefix( # pylint: disable=invalid-name - ir: Literal["ir", "tir"], - prefix: str, -) -> None: - """Set the prefix for the IR. If not set, the prefix for "tvm.ir" is "I", and for "tir" is "T. - - Parameters - ---------- - ir : str - The IR type, either "ir" or "tir". - - prefix : str - The prefix to use. - """ - _ffi_api.DefaultIRPrefix(ir, prefix) # type: ignore # pylint: disable=no-member - - -def buffer_dtype(dtype: str) -> None: - """Set the default dtype for buffer. If not set, it is "float32". - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member - - -def int_dtype(dtype: str) -> None: - """Set the default dtype for integers. If not set, it is "int32". - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultBufferDtype(dtype) # type: ignore # pylint: disable=no-member - - -def float_dtype(dtype: str) -> None: - """Set the default dtype for buffer. If not set, there is no default, - which means every floating point numbers will be wrapped with its precise dtype. - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.DefaultFloatDtype(dtype) # type: ignore # pylint: disable=no-member - - -def verbose_expr(verbose: bool) -> None: - """Whether or not to verbose print expressions. If not, the definition of every variable in an - expression will be printed as separate statements. Otherwise, the result will be a one-liner. - - Parameters - ---------- - dtype : str - The default dtype for buffer. - """ - _ffi_api.VerboseExpr(verbose) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index 1791f46b00a2..137b71a77d9f 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -17,7 +17,10 @@ """Functions to print doc into text format""" from typing import Optional -from tvm.runtime.object_path import ObjectPath + +from tvm.runtime import ObjectPath +from tvm.runtime.script_printer import PrinterConfig + from . import _ffi_api from .doc import Doc @@ -49,8 +52,10 @@ def to_python_script( script : str The text representation of Doc in Python syntax """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.DocToPythonScript( # type: ignore - doc, indent_spaces, print_line_numbers, num_context_lines, path_to_underline + cfg = PrinterConfig( + indent_spaces=indent_spaces, + print_line_numbers=print_line_numbers, + num_context_lines=num_context_lines, + path_to_underline=path_to_underline, ) + return _ffi_api.DocToPythonScript(doc, cfg) # type: ignore # pylint: disable=no-member diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index c2c158c77f78..11db28e20a1c 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -20,13 +20,13 @@ import tvm._ffi from tvm._ffi.base import string_types from tvm.ir import PointerType, PrimExpr, PrimType, Range -from tvm.runtime import Object, convert +from tvm.runtime import Object, Scriptable, convert from . import _ffi_api @tvm._ffi.register_object("tir.Buffer") -class Buffer(Object): +class Buffer(Object, Scriptable): """Symbolic data buffer in TVM. Buffer provide a way to represent data layout @@ -179,7 +179,11 @@ def offset_of(self, indices): def __getitem__(self, indices): from ..arith import Analyzer # pylint: disable=import-outside-toplevel - from .expr import BufferLoad, Ramp, const # pylint: disable=import-outside-toplevel + from .expr import ( # pylint: disable=import-outside-toplevel + BufferLoad, + Ramp, + const, + ) from .stmt import BufferRegion # pylint: disable=import-outside-toplevel if not isinstance(indices, (tuple, list)): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index dab7a175185d..cb4a892ac289 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -34,7 +34,7 @@ from tvm import ir from tvm.ir import Op, PrimExpr from tvm.ir.base import Span -from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, const +from tvm.runtime import DataType, DataTypeCode, Object, ObjectGeneric, Scriptable, const from . import _ffi_api from . import generic as _generic @@ -318,88 +318,13 @@ def asobject(self): return IntImm("int32", self.value, self.span) # type: ignore -class PrimExprWithOp(ExprOp, PrimExpr): +class PrimExprWithOp(ExprOp, PrimExpr, Scriptable): """Helper base class to inherit from PrimExpr.""" # In Python3, We have to explicitly tell interpreter to retain __hash__ if we overide __eq__ # https://docs.python.org/3.1/reference/datamodel.html#object.__hash__ __hash__ = PrimExpr.__hash__ - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.PrimExprScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - class ConstExpr(PrimExprWithOp): pass @@ -460,7 +385,7 @@ def __init__(self, name, dtype, span=None): @tvm._ffi.register_object("tir.IterVar") -class IterVar(Object, ExprOp): +class IterVar(Object, ExprOp, Scriptable): """Represent iteration variable. IterVar represents axis iterations in the computation. @@ -521,7 +446,7 @@ def __init__(self, dom, var, iter_type, thread_tag="", span=None): @tvm._ffi.register_object("tir.CommReducer") -class CommReducer(Object): +class CommReducer(Object, Scriptable): """Commutative reduce operator Parameters diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index fb5a37c5dc17..f854e56ad11a 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -24,7 +24,7 @@ import tvm._ffi import tvm.runtime from tvm.ir import BaseFunc, Range -from tvm.runtime import Object +from tvm.runtime import Object, Scriptable from ..runtime.ndarray import NDArray from . import _ffi_api @@ -33,7 +33,7 @@ @tvm._ffi.register_object("tir.PrimFunc") -class PrimFunc(BaseFunc): +class PrimFunc(BaseFunc, Scriptable): """A function declaration expression. Parameters @@ -170,81 +170,6 @@ def mem_copy_16_16(a: T.handle, b: T.handle) -> None: """ return _ffi_api.Specialize(self, param_map) # type: ignore - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.PrimFuncScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - @tvm._ffi.register_object("tir.TensorIntrin") class TensorIntrin(Object): diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 096c13653a94..d6cd06a1d915 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -31,91 +31,16 @@ import tvm._ffi from tvm.ir import PrimExpr, Range, Span -from tvm.runtime import Object, const +from tvm.runtime import Object, Scriptable, const from . import _ffi_api from .buffer import Buffer from .expr import IterVar -class Stmt(Object): +class Stmt(Object, Scriptable): """Base class of all the statements.""" - def script( - self, - *, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> str: - """Print IRModule into TVMScript - - Parameters - ---------- - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - - Returns - ------- - script : str - The TVM Script of the IRModule - """ - if num_context_lines is None: - num_context_lines = -1 - return _ffi_api.StmtScript( # type: ignore # pylint: disable=no-member - self, indent_spaces, print_line_numbers, num_context_lines, path_to_underline - ) - - def show( - self, - *, - style: Optional[str] = None, - black_format: bool = True, - indent_spaces: int = 4, - print_line_numbers: bool = False, - num_context_lines: Optional[int] = None, - path_to_underline=None, - ) -> None: - """A sugar for print highlighted TVM script. - - Parameters - ---------- - style : str, optional - Pygmentize printing style, auto-detected if None. See - `tvm.script.highlight.cprint` for more details. - black_format: bool - If true (default), use the formatter Black to format the TVMScript - indent_spaces : int - The number of indent spaces to use in the output - print_line_numbers: bool - Whether to print line numbers - num_context_lines : Optional[int] - Number of context lines to print around the underlined text - path_to_underline : Optional[ObjectPath] - Object path to be underlined - """ - from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel - cprint, - ) - - cprint( - self.script( - indent_spaces=indent_spaces, - print_line_numbers=print_line_numbers, - num_context_lines=num_context_lines, - path_to_underline=path_to_underline, - ), - style=style, - black_format=black_format, - ) - @tvm._ffi.register_object("tir.LetStmt") class LetStmt(Stmt): @@ -623,7 +548,7 @@ def __init__(self, buffer, bounds, span=None): @tvm._ffi.register_object("tir.BufferRegion") -class BufferRegion(Object): +class BufferRegion(Object, Scriptable): """BufferRegion node. Parameters @@ -643,7 +568,7 @@ def __init__(self, buffer: Buffer, region: List[Range]): @tvm._ffi.register_object("tir.MatchBufferRegion") -class MatchBufferRegion(Object): +class MatchBufferRegion(Object, Scriptable): """MatchBufferRegion node. Parameters diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc new file mode 100644 index 000000000000..605d5208462f --- /dev/null +++ b/src/node/script_printer.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include + +namespace tvm { + +TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { + static FType inst; + return inst; +} + +std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { + return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); +} + +PrinterConfig::PrinterConfig(Map config_dict) { + runtime::ObjectPtr n = make_object(); + if (auto v = config_dict.Get("ir_prefix")) { + n->ir_prefix = Downcast(v); + } + if (auto v = config_dict.Get("tir_prefix")) { + n->tir_prefix = Downcast(v); + } + if (auto v = config_dict.Get("relax_prefix")) { + n->relax_prefix = Downcast(v); + } + if (auto v = config_dict.Get("buffer_dtype")) { + n->buffer_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("int_dtype")) { + n->int_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("float_dtype")) { + n->float_dtype = DataType(runtime::String2DLDataType(Downcast(v))); + } + if (auto v = config_dict.Get("verbose_expr")) { + n->verbose_expr = Downcast(v)->value; + } + if (auto v = config_dict.Get("indent_spaces")) { + n->indent_spaces = Downcast(v)->value; + } + if (auto v = config_dict.Get("print_line_numbers")) { + n->print_line_numbers = Downcast(v)->value; + } + if (auto v = config_dict.Get("num_context_lines")) { + n->num_context_lines = Downcast(v)->value; + } + if (auto v = config_dict.Get("path_to_underline")) { + n->path_to_underline = Downcast(v); + } + this->data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(PrinterConfigNode); +TVM_REGISTER_GLOBAL("node.PrinterConfig").set_body_typed([](Map config_dict) { + return PrinterConfig(config_dict); +}); +TVM_REGISTER_GLOBAL("node.TVMScriptPrinterScript").set_body_typed(TVMScriptPrinter::Script); + +} // namespace tvm diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 38b8ef897740..a3a5c06ede0d 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -77,13 +77,13 @@ ByteSpan PopNextUnderline(UnderlineIter* next_underline, UnderlineIter end_under void PrintChunk(const std::pair& lines_range, const std::pair& underlines, const std::string& text, - const std::vector& line_starts, const DocPrinterOptions& options, + const std::vector& line_starts, const PrinterConfig& options, size_t line_number_width, std::string* out) { UnderlineIter next_underline = underlines.first; ByteSpan current_underline = PopNextUnderline(&next_underline, underlines.second); for (size_t line_idx = lines_range.first; line_idx < lines_range.second; ++line_idx) { - if (options.print_line_numbers) { + if (options->print_line_numbers) { std::string line_num_str = std::to_string(line_idx + 1); line_num_str.push_back(' '); for (size_t i = line_num_str.size(); i < line_number_width; ++i) { @@ -148,12 +148,12 @@ void PrintCut(size_t num_lines_skipped, std::string* out) { std::pair GetLinesForUnderline(const ByteSpan& underline, const std::vector& line_starts, - size_t num_lines, const DocPrinterOptions& options) { + size_t num_lines, const PrinterConfig& options) { size_t first_line_of_underline = GetLineIndex(underline.first, line_starts); - size_t first_line_of_chunk = MoveBack(first_line_of_underline, options.num_context_lines); + size_t first_line_of_chunk = MoveBack(first_line_of_underline, options->num_context_lines); size_t end_line_of_underline = GetLineIndex(underline.second - 1, line_starts) + 1; size_t end_line_of_chunk = - MoveForward(end_line_of_underline, options.num_context_lines, num_lines); + MoveForward(end_line_of_underline, options->num_context_lines, num_lines); return {first_line_of_chunk, end_line_of_chunk}; } @@ -181,8 +181,8 @@ size_t GetNumLines(const std::string& text, const std::vector& line_star } } -size_t GetLineNumberWidth(size_t num_lines, const DocPrinterOptions& options) { - if (options.print_line_numbers) { +size_t GetLineNumberWidth(size_t num_lines, const PrinterConfig& options) { + if (options->print_line_numbers) { return std::to_string(num_lines).size() + 1; } else { return 0; @@ -190,8 +190,7 @@ size_t GetLineNumberWidth(size_t num_lines, const DocPrinterOptions& options) { } std::string DecorateText(const std::string& text, const std::vector& line_starts, - const DocPrinterOptions& options, - const std::vector& underlines) { + const PrinterConfig& options, const std::vector& underlines) { size_t num_lines = GetNumLines(text, line_starts); size_t line_number_width = GetLineNumberWidth(num_lines, options); @@ -237,7 +236,7 @@ std::string DecorateText(const std::string& text, const std::vector& lin } // anonymous namespace -DocPrinter::DocPrinter(const DocPrinterOptions& options) : options_(options) { +DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { line_starts_.push_back(0); } diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index db1d733d96ad..7851ce061b0d 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -35,23 +35,6 @@ namespace printer { /*! \brief Range of byte offsets in a string */ using ByteSpan = std::pair; -/*! \brief Options to customize DocPrinter's output */ -struct DocPrinterOptions { - /*! \brief Number of spaces for one level of indentation */ - int indent_spaces = 4; - - /*! \brief Whether to print the line numbers */ - bool print_line_numbers = false; - - /*! - * \brief Number of context lines to print around the underlined text. - * - * If set to a non-default value `n`, only print `n` context lines before and after - * the underlined pieces of text. - */ - size_t num_context_lines = std::numeric_limits::max(); -}; - /*! * \brief DocPrinter is responsible for printing Doc tree into text format * \details This is the base class for translating Doc into string. @@ -67,7 +50,8 @@ class DocPrinter { * * \param options the option for printer */ - explicit DocPrinter(const DocPrinterOptions& options); + explicit DocPrinter(const PrinterConfig& options); + virtual ~DocPrinter() = default; /*! @@ -224,13 +208,13 @@ class DocPrinter { * \brief Increase the indent level of any content to be * printed after this call */ - void IncreaseIndent() { indent_ += options_.indent_spaces; } + void IncreaseIndent() { indent_ += options_->indent_spaces; } /*! * \brief Decrease the indent level of any content to be * printed after this call */ - void DecreaseIndent() { indent_ -= options_.indent_spaces; } + void DecreaseIndent() { indent_ -= options_->indent_spaces; } /*! * \brief Add a new line into the output stream @@ -258,7 +242,7 @@ class DocPrinter { void MarkSpan(const ByteSpan& span, const ObjectPath& path); /*! \brief Options to customize certain aspects of the output */ - DocPrinterOptions options_; + PrinterConfig options_; /*! \brief the current level of indent */ int indent_ = 0; diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 8634236df5c3..ce6b8e7f423c 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -142,7 +142,7 @@ ExprPrecedence GetExprPrecedence(const ExprDoc& doc) { class PythonDocPrinter : public DocPrinter { public: - explicit PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} + explicit PythonDocPrinter(const PrinterConfig& options) : DocPrinter(options) {} protected: using DocPrinter::PrintDoc; @@ -642,17 +642,12 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { NewLineWithoutIndent(); } -String DocToPythonScript(Doc doc, int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) { - DocPrinterOptions options; - options.indent_spaces = indent_spaces; - options.print_line_numbers = print_line_numbers; - if (num_context_lines >= 0) { - options.num_context_lines = num_context_lines; +String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { + if (cfg->num_context_lines < 0) { + cfg->num_context_lines = std::numeric_limits::max(); } - - PythonDocPrinter printer(options); - printer.Append(doc, path_to_underline); + PythonDocPrinter printer(cfg); + printer.Append(doc, cfg->path_to_underline); std::string result = printer.GetString(); int last_space = result.size(); while (last_space > 0 && std::isspace(result[last_space - 1])) { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index e438919f4b1b..4a246e169276 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -52,7 +52,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) BaseFunc func = kv.second; (*f)->stmts.push_back(d->AsDoc(func, p->Attr("functions")->MapValue(gv))); } - return ClassDoc(IdDoc("Module"), {IR("ir_module")}, (*f)->stmts); + return ClassDoc(IdDoc("Module"), {IR(d, "ir_module")}, (*f)->stmts); } }); @@ -63,43 +63,44 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](GlobalVar gv, ObjectPath p, IRDocsifier d) -> Doc { - return IR("GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); + return IR(d, "GlobalVar")->Call({LiteralDoc::Str(gv->name_hint, p->Attr("name_hint"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc { - return IR("Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); + return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TypeVar var, ObjectPath p, IRDocsifier d) -> Doc { - return IR("TypeVar")->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // - LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); + return IR(d, "TypeVar") + ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), // + LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](GlobalTypeVar var, ObjectPath p, IRDocsifier d) -> Doc { - return IR("GlobalTypeVar") + return IR(d, "GlobalTypeVar") ->Call({LiteralDoc::Str(var->name_hint, p->Attr("name_hint")), LiteralDoc::Str(TypeKind2String(var->kind), p->Attr("kind"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](RelayRefType ref, ObjectPath p, IRDocsifier d) -> Doc { - return IR("RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); + return IR(d, "RelayRef")->Call({d->AsDoc(ref->value, p->Attr("value"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](TensorType type, ObjectPath p, IRDocsifier d) -> Doc { - return IR("TensorType") + return IR(d, "TensorType") ->Call({d->AsDoc(type->shape, p->Attr("shape")), LiteralDoc::DataType(type->dtype, p->Attr("dtype"))}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](FuncType func_type, ObjectPath p, IRDocsifier d) -> Doc { - return IR("FuncType") + return IR(d, "FuncType") ->Call({ d->AsDoc(func_type->type_params, p->Attr("type_params")), d->AsDoc(func_type->arg_types, p->Attr("arg_types")), @@ -109,19 +110,17 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IncompleteType ty, ObjectPath p, IRDocsifier d) -> Doc { - return IR("IncompleteType")->Call({}); + return IR(d, "IncompleteType")->Call({}); }); -void ReprPrintIRModule(const ObjectRef& mod, ReprPrinter* p) { +std::string ReprPrintIRModule(const ObjectRef& mod, const PrinterConfig& cfg) { if (const auto* f = runtime::Registry::Get("relay.ir.PrintRelayModule")) { if (Optional s = (*f)(mod)) { - p->stream << s.value(); - return; + return s.value(); } } - std::string res = - DocToPythonScript(IRDocsifier()->AsDoc(Downcast(mod), ObjectPath::Root())); - p->stream << res; + Doc doc = IRDocsifier(cfg)->AsDoc(mod, ObjectPath::Root()); + return DocToPythonScript(doc, cfg); } TVM_SCRIPT_REPR(TypeVarNode, ReprPrintIR); diff --git a/src/script/printer/ir/script_method.cc b/src/script/printer/ir/script_method.cc deleted file mode 100644 index 01d3ede7ea6c..000000000000 --- a/src/script/printer/ir/script_method.cc +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "./utils.h" - -namespace tvm { - -std::string IRModuleNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), - indent_spaces, print_line_numbers, num_context_lines, path_to_underline); -} - -TVM_REGISTER_GLOBAL("ir.Module_Script").set_body_method(&IRModuleNode::Script); - -} // namespace tvm diff --git a/src/script/printer/ir/utils.h b/src/script/printer/ir/utils.h index 820fe13df3c6..d20756e6081a 100644 --- a/src/script/printer/ir/utils.h +++ b/src/script/printer/ir/utils.h @@ -23,9 +23,9 @@ #include #include #include -#include #include +#include #include #include "../utils.h" @@ -35,7 +35,9 @@ namespace script { namespace printer { /*! \brief Creates the IR common prefix, which is by default `I` */ -inline ExprDoc IR(const String& attr) { return IdDoc(Default::Prefix("ir"))->Attr(attr); } +inline ExprDoc IR(const IRDocsifier& d, const String& attr) { + return IdDoc(d->cfg->ir_prefix)->Attr(attr); +} class IRFrameNode : public FrameNode { public: @@ -57,15 +59,14 @@ class IRFrame : public Frame { TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IRFrame, Frame, IRFrameNode); }; -inline void ReprPrintIR(const ObjectRef& obj, ReprPrinter* p) { - IRDocsifier d; +/*! \brief Redirected method for the ReprPrinter */ +inline std::string ReprPrintIR(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); With f(d); (*f)->AddDispatchToken(d, "ir"); - try { - p->stream << DocToPythonScript(Docsify(obj, d, *f)); - } catch (const Error& e) { - HandleUnsupportedFallback(e, obj, p); - } + std::ostringstream oss; + oss << Docsify(obj, d, *f, cfg); + return oss.str(); } } // namespace printer diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 4c52ce890c9d..5a0d2bd6bbe0 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -144,8 +144,9 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, this->common_prefix = std::move(visitor.common_prefix); } -IRDocsifier::IRDocsifier() { +IRDocsifier::IRDocsifier(const PrinterConfig& cfg) { auto n = make_object(); + n->cfg = cfg; n->dispatch_tokens.push_back(""); data_ = std::move(n); } diff --git a/src/script/printer/printer.cc b/src/script/printer/printer.cc deleted file mode 100644 index 878b380a3717..000000000000 --- a/src/script/printer/printer.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include -#include - -namespace tvm { -namespace script { -namespace printer { - -Default* Default::Instance() { - static Default inst; - return &inst; -} - -TVM_REGISTER_GLOBAL("script.printer.DefaultIRPrefix") - .set_body_typed([](std::string ir, std::string prefix) { Default::Prefix(ir) = prefix; }); -TVM_REGISTER_GLOBAL("script.printer.DefaultBufferDType") - .set_body_typed([](runtime::DataType dtype) { Default::BufferDType() = dtype; }); -TVM_REGISTER_GLOBAL("script.printer.DefaultIntDType").set_body_typed([](runtime::DataType dtype) { - Default::IntDType() = dtype; -}); -TVM_REGISTER_GLOBAL("script.printer.DefaultFloatDType").set_body_typed([](runtime::DataType dtype) { - Default::FloatDType() = dtype; -}); -TVM_REGISTER_GLOBAL("script.printer.VerboseExpr").set_body_typed([](bool verbose_expr) { - Default::VerboseExpr() = verbose_expr; -}); - -} // namespace printer -} // namespace script -} // namespace tvm diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index f78e7037c3e0..a5b8d6609622 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -68,7 +68,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // auto print_single_iter_var = [&](int i) { tir::IterVar iter_var = block->iter_vars[i]; ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); - ExprDoc rhs = TIR("axis"); + ExprDoc rhs = TIR(d, "axis"); if (iter_var->iter_type == tir::IterVarType::kDataPar) { rhs = rhs->Attr("spatial"); } else if (iter_var->iter_type == tir::IterVarType::kCommReduce) { @@ -128,7 +128,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // binding_paths.push_back(iter_var_p->Attr("iter_type")); binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; } - ExprDoc rhs = TIR("axis")->Attr("remap"); + ExprDoc rhs = TIR(d, "axis")->Attr("remap"); ExprDoc binding_str = LiteralDoc::Str(binding_type, NullOpt); binding_str->source_paths = std::move(binding_paths); rhs = rhs->Call({binding_str, ListDoc(loop_var_doc)}); @@ -151,8 +151,9 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // if (realize) { ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); if (!tir::is_one(realize->predicate)) { - (*frame)->stmts.push_back(ExprStmtDoc(TIR("where")->Call( - {d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); + (*frame)->stmts.push_back(ExprStmtDoc( + TIR(d, "where") + ->Call({d->AsDoc(realize->predicate, realize_p->Attr("predicate"))}))); } } // Step 3. Handle block read/write regions @@ -161,17 +162,17 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // for (int i = 0, n = block->reads.size(); i < n; ++i) { reads.push_back(d->AsDoc(block->reads[i], block_p->Attr("reads")->ArrayIndex(i))); } - (*frame)->stmts.push_back(ExprStmtDoc(TIR("reads")->Call(reads))); + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "reads")->Call(reads))); Array writes; for (int i = 0, n = block->writes.size(); i < n; ++i) { writes.push_back(d->AsDoc(block->writes[i], block_p->Attr("writes")->ArrayIndex(i))); } - (*frame)->stmts.push_back(ExprStmtDoc(TIR("writes")->Call(writes))); + (*frame)->stmts.push_back(ExprStmtDoc(TIR(d, "writes")->Call(writes))); } // Step 4. Handle block attributes if (!block->annotations.empty()) { (*frame)->stmts.push_back(ExprStmtDoc( - TIR("block_attr") + TIR(d, "block_attr") ->Call({d->AsDoc(block->annotations, block_p->Attr("annotations"))}))); } // Step 5. Handle `alloc_buffer` @@ -194,7 +195,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // tir::Stmt init = block->init.value(); With init_frame(d, init); AsDocBody(init, block_p->Attr("init"), init_frame->get(), d); - (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR("init")->Call({}), (*init_frame)->stmts)); + (*frame)->stmts.push_back(ScopeDoc(NullOpt, TIR(d, "init")->Call({}), (*init_frame)->stmts)); } // Step 8. Handle block body AsDocBody(block->body, block_p->Attr("body"), frame->get(), d); @@ -205,7 +206,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // kwargs_values.push_back(LiteralDoc::Boolean(true, NullOpt)); } return ScopeDoc(NullOpt, - TIR("block") // + TIR(d, "block") // ->Call({LiteralDoc::Str(block->name_hint, block_p->Attr("name_hint"))}, kwargs_keys, kwargs_values), (*frame)->stmts); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index b947039b58de..b4429dc9afc9 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -55,7 +55,7 @@ Map BufferAttrs(const tir::Buffer& buffer, const ObjectPath& p, // Step 1. Handle `buffer.shape` array_out_line_var_def(buffer->shape, p->Attr("shape"), "shape"); // Step 2. Handle `buffer.dtype` - if (buffer->dtype != Default::BufferDType()) { + if (buffer->dtype != d->cfg->buffer_dtype) { kwargs.Set("dtype", LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); } // Step 3. Handle `buffer.data` @@ -123,7 +123,7 @@ ExprDoc BufferCall(const ExprDoc& prefix, const Map& attrs, Arr ExprDoc BufferDecl(const tir::Buffer& buffer, const String& method, const Array& args, const ObjectPath& p, const Frame& frame, const IRDocsifier& d) { - return BufferCall(/*prefix=*/TIR(method), + return BufferCall(/*prefix=*/TIR(d, method), /*attrs=*/BufferAttrs(buffer, p, frame, d), /*args=*/args); } @@ -134,7 +134,7 @@ ExprDoc BufferAttn(const tir::Buffer& buffer, const ObjectPath& p, const Frame& ExprDoc shape = attrs.Get("shape").value(); ExprDoc dtype = attrs.Get("dtype").value_or(LiteralDoc::DataType(buffer->dtype, p->Attr("dtype"))); - return TIR("Buffer")->Call({shape, dtype}, {}, {}); + return TIR(d, "Buffer")->Call({shape, dtype}, {}, {}); } Array BufferIndices(const Array& indices, const ObjectPath& p, @@ -251,7 +251,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc prefix = IdDoc(stmt->producer->GetNameHint()); prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)]; - prefix = TIR("ProducerRealize") + prefix = TIR(d, "ProducerRealize") ->Call({prefix, d->AsDoc(stmt->condition, p->Attr("condition"))}); With f(d, stmt); AsDocBody(stmt->body, p->Attr("body"), f->get(), d); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index 6e0cfd420262..ab91764b6a0b 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -34,7 +34,7 @@ Doc PrintVar(const tir::Var& var, const ObjectPath& var_p, const IRDocsifier& d) ExprDoc rhs = d->AsDoc(type, var_p->Attr("type_annotation")); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } else { - ExprDoc rhs = TIR("var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); + ExprDoc rhs = TIR(d, "var")->Call({LiteralDoc::DataType(var->dtype, var_p->Attr("dtype"))}); opt_f.value()->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } } @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::IterVar var, ObjectPath var_p, IRDocsifier d) -> Doc { - return TIR("iter_var") + return TIR(d, "iter_var") ->Call({ d->AsDoc(var->var, var_p->Attr("var")), d->AsDoc(var->dom, var_p->Attr("dom")), @@ -70,7 +70,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Not node, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc a = d->AsDoc(node->a, p->Attr("a")); if (a->IsInstance()) { - return TIR("Not")->Call({a}); + return TIR(d, "Not")->Call({a}); } return OperationDoc(OperationDocNode::Kind::kNot, {a}); }); @@ -84,21 +84,22 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Cast cast, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc dtype = LiteralDoc::DataType(cast->dtype, p->Attr("dtype")); ExprDoc value = d->AsDoc(cast->value, p->Attr("value")); - return TIR("Cast")->Call({dtype, value}); + return TIR(d, "Cast")->Call({dtype, value}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Select select, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Select")->Call({ - d->AsDoc(select->condition, p->Attr("condition")), - d->AsDoc(select->true_value, p->Attr("true_value")), - d->AsDoc(select->false_value, p->Attr("false_value")), - }); + return TIR(d, "Select") + ->Call({ + d->AsDoc(select->condition, p->Attr("condition")), + d->AsDoc(select->true_value, p->Attr("true_value")), + d->AsDoc(select->false_value, p->Attr("false_value")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Ramp ramp, ObjectPath ramp_p, IRDocsifier d) -> Doc { - return TIR("Ramp")->Call({ + return TIR(d, "Ramp")->Call({ d->AsDoc(ramp->base, ramp_p->Attr("base")), d->AsDoc(ramp->stride, ramp_p->Attr("stride")), LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")), @@ -107,7 +108,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Broadcast bc, ObjectPath bc_p, IRDocsifier d) -> Doc { - return TIR("Broadcast") + return TIR(d, "Broadcast") ->Call({ d->AsDoc(bc->value, bc_p->Attr("value")), LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")), @@ -117,10 +118,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Shuffle shuffle, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Shuffle")->Call({ - d->AsDoc(shuffle->vectors, p->Attr("vectors")), - d->AsDoc(shuffle->indices, p->Attr("indices")), - }); + return TIR(d, "Shuffle") + ->Call({ + d->AsDoc(shuffle->vectors, p->Attr("vectors")), + d->AsDoc(shuffle->indices, p->Attr("indices")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -152,12 +154,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } ExprDoc id = d->AsDoc(r->identity_element, p->Attr("identity_element")); - return TIR("comm_reducer")->Call({lambda, id}); + return TIR(d, "comm_reducer")->Call({lambda, id}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("let")->Call({ + return TIR(d, "let")->Call({ d->AsDoc(let->var, p->Attr("var")), d->AsDoc(let->value, p->Attr("value")), d->AsDoc(let->body, p->Attr("body")), @@ -194,7 +196,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (op_names.count(GetRef(op)) == 0) { LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } - prefix = TIR(name); + prefix = TIR(d, name); } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); } else { @@ -217,7 +219,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Any any, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Any")->Call({}); + return TIR(d, "Any")->Call({}); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -228,8 +230,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc axis = d->AsDoc(r->axis, p->Attr("axis")); ExprDoc condition = d->AsDoc(r->condition, p->Attr("condition")); ExprDoc value_index = LiteralDoc::Int(r->value_index, p->Attr("value_index")); - return TIR("reduce")->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, - {source, init, axis, condition, value_index}); + return TIR(d, "reduce") + ->Call({combiner}, {"source", "init", "axis", "condition", "value_index"}, + {source, init, axis, condition, value_index}); LOG(FATAL) << "ValueError: Reduce should never exist in TIR: " << r; }); @@ -244,7 +247,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) [](tir::NodeType node, ObjectPath p, IRDocsifier d) -> Doc { \ ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ - return TIR(OpString)->Call({a, b}); \ + return TIR(d, OpString)->Call({a, b}); \ }); #define TVM_SCRIPT_PRINTER_DEF_BINARY_WITH_SUGAR(NodeType, OpString, OpKind) \ @@ -254,7 +257,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) ExprDoc a = d->AsDoc(node->a, p->Attr("a")); \ ExprDoc b = d->AsDoc(node->b, p->Attr("b")); \ if (a->IsInstance() && b->IsInstance()) { \ - return TIR(OpString)->Call({a, b}); \ + return TIR(d, OpString)->Call({a, b}); \ } \ return OperationDoc(OperationDocNode::Kind::OpKind, {a, b}); \ }); diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 2a81c37061c6..7d21de27a1a2 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -59,7 +59,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) loop_p = loop_p->Attr("body"); } AsDocBody(grid.back()->body, loop_p, (*f).get(), d); - return ForDoc(TupleDoc(lhs), TIR("grid")->Call(rhs), (*f)->stmts); + return ForDoc(TupleDoc(lhs), TIR(d, "grid")->Call(rhs), (*f)->stmts); } // Step 3. If not `T.grid`, print loop kind accordingly ExprDoc lhs = DefineVar(loop->loop_var, *f, d); @@ -81,16 +81,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (loop->annotations.empty()) { prefix = IdDoc("range"); } else { - prefix = TIR("serial"); + prefix = TIR(d, "serial"); } } else if (loop->kind == tir::ForKind::kParallel) { - prefix = TIR("parallel"); + prefix = TIR(d, "parallel"); } else if (loop->kind == tir::ForKind::kUnrolled) { - prefix = TIR("unroll"); + prefix = TIR(d, "unroll"); } else if (loop->kind == tir::ForKind::kVectorized) { - prefix = TIR("vectorized"); + prefix = TIR(d, "vectorized"); } else if (loop->kind == tir::ForKind::kThreadBinding) { - prefix = TIR("thread_binding"); + prefix = TIR(d, "thread_binding"); thread = LiteralDoc::Str(loop->thread_binding.value()->thread_tag, loop_p->Attr("thread_binding")); } else { diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index ea7d56e1656d..fbcc2fca3b4b 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -111,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // Step 2. Handle `func->attrs` if (func->attrs.defined() && !func->attrs->dict.empty()) { (*frame)->stmts.push_back( - ExprStmtDoc(TIR("func_attr") // + ExprStmtDoc(TIR(d, "func_attr") // ->Call({d->AsDoc(func->attrs, p->Attr("attrs"))}))); } // Step 3. Handle `func->buffer_map` @@ -175,14 +175,14 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return FunctionDoc( /*name=*/IdDoc(FindFunctionName(d, func)), /*args=*/args, - /*decorators=*/{TIR("prim_func")}, + /*decorators=*/{TIR(d, "prim_func")}, /*return_type=*/ret_type, /*body=*/(*frame)->stmts); }); -void ReprPrintPrimFunc(const ObjectRef& obj, ReprPrinter* p) { - std::string res = DocToPythonScript(IRDocsifier()->AsDoc(obj, ObjectPath::Root())); - p->stream << res; +std::string ReprPrintPrimFunc(const ObjectRef& obj, const PrinterConfig& cfg) { + Doc doc = IRDocsifier(cfg)->AsDoc(obj, ObjectPath::Root()); + return DocToPythonScript(doc, cfg); } TVM_SCRIPT_REPR(tir::PrimFuncNode, ReprPrintPrimFunc); diff --git a/src/script/printer/tir/ir.cc b/src/script/printer/tir/ir.cc index 1214f822610c..76d3680fec81 100644 --- a/src/script/printer/tir/ir.cc +++ b/src/script/printer/tir/ir.cc @@ -29,12 +29,12 @@ TVM_REGISTER_NODE_TYPE(TIRFrameNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](IntImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; - if (dtype == Default::IntDType()) { + if (dtype == d->cfg->int_dtype) { return LiteralDoc::Int(imm->value, imm_p->Attr("value")); } else if (dtype == DataType::Bool()) { return LiteralDoc::Boolean(imm->value, imm_p->Attr("value")); } else { - return TIR(runtime::DLDataType2String(dtype)) // + return TIR(d, runtime::DLDataType2String(dtype)) // ->Call({LiteralDoc::Int(imm->value, imm_p->Attr("value"))}); } }); @@ -42,26 +42,27 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](FloatImm imm, ObjectPath imm_p, IRDocsifier d) -> Doc { DataType dtype = imm->dtype; - if (dtype == Default::FloatDType()) { + if (dtype == d->cfg->float_dtype) { return LiteralDoc::Float(imm->value, imm_p->Attr("value")); } else { - return TIR(runtime::DLDataType2String(dtype)) // + return TIR(d, runtime::DLDataType2String(dtype)) // ->Call({LiteralDoc::Float(imm->value, imm_p->Attr("value"))}); } }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Range range, ObjectPath p, IRDocsifier d) -> Doc { - return TIR("Range")->Call({ - d->AsDoc(range->min, p->Attr("min")), - d->AsDoc(range->extent, p->Attr("extent")), - }); + return TIR(d, "Range") + ->Call({ + d->AsDoc(range->min, p->Attr("min")), + d->AsDoc(range->extent, p->Attr("extent")), + }); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](PrimType ty, ObjectPath p, IRDocsifier d) -> Doc { std::string dtype = ty->dtype.is_void() ? "void" : runtime::DLDataType2String(ty->dtype); - return TIR(dtype); + return TIR(d, dtype); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -74,9 +75,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) element_type = d->AsDoc(ty->element_type, ty_p->Attr("element_type")); } if (ty->storage_scope == "") { - return TIR("Ptr")->Call({element_type}); + return TIR(d, "Ptr")->Call({element_type}); } else { - return TIR("Ptr")->Call( + return TIR(d, "Ptr")->Call( {element_type, LiteralDoc::Str(ty->storage_scope, ty_p->Attr("storage_scope"))}); } }); @@ -86,13 +87,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (ty->fields.empty()) { return LiteralDoc::None(p); } - return TIR("Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); + return TIR(d, "Tuple")->Call(d->AsDoc(ty->fields, p->Attr("fields"))->elements); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](Target target, ObjectPath p, IRDocsifier d) -> Doc { Map config = target->Export(); - return TIR("target")->Call({d->AsDoc(config, p)}); + return TIR(d, "target")->Call({d->AsDoc(config, p)}); }); TVM_SCRIPT_REPR(IntImmNode, ReprPrintTIR); diff --git a/src/script/printer/tir/script_method.cc b/src/script/printer/tir/script_method.cc deleted file mode 100644 index 5cda9a9626db..000000000000 --- a/src/script/printer/tir/script_method.cc +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include - -#include "./utils.h" - -namespace tvm { - -std::string PrimExprNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - IRDocsifier d; - ObjectRef obj = GetRef(this); - With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, - num_context_lines, path_to_underline); -} - -namespace tir { - -std::string StmtNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - IRDocsifier d; - ObjectRef obj = GetRef(this); - With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - return DocToPythonScript(Docsify(obj, d, *f), indent_spaces, print_line_numbers, - num_context_lines, path_to_underline); -} - -std::string PrimFuncNode::Script(int indent_spaces, bool print_line_numbers, int num_context_lines, - Optional path_to_underline) const { - using namespace tvm::script::printer; - return DocToPythonScript(IRDocsifier()->AsDoc(GetRef(this), ObjectPath::Root()), - indent_spaces, print_line_numbers, num_context_lines, path_to_underline); -} - -TVM_REGISTER_GLOBAL("tir.PrimFuncScript").set_body_method(&PrimFuncNode::Script); -TVM_REGISTER_GLOBAL("tir.StmtScript").set_body_method(&StmtNode::Script); -TVM_REGISTER_GLOBAL("tir.PrimExprScript").set_body_method(&PrimExprNode::Script); - -} // namespace tir -} // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index acdfd7da472b..2820f9ba6384 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -51,7 +51,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (eval->value->IsInstance()) { return ExprStmtDoc(value); } - return ExprStmtDoc(TIR("evaluate")->Call({value})); + return ExprStmtDoc(TIR(d, "evaluate")->Call({value})); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -75,7 +75,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); return StmtBlockDoc(*stmts); } else { - rhs = TIR("let")->Call({lhs, rhs}); + rhs = TIR(d, "let")->Call({lhs, rhs}); return ScopeDoc(NullOpt, rhs, *stmts); } }); @@ -93,7 +93,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssertDoc(cond, msg)); return StmtBlockDoc(*stmts); } - return ScopeDoc(NullOpt, TIR("Assert")->Call({cond, msg}), (*f)->stmts); + return ScopeDoc(NullOpt, TIR(d, "Assert")->Call({cond, msg}), (*f)->stmts); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -145,7 +145,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { - return ExprStmtDoc(TIR("prefetch") + return ExprStmtDoc(TIR(d, "prefetch") ->Call({ d->AsDoc(stmt->buffer, p->Attr("buffer")), d->AsDoc(stmt->bounds, p->Attr("bounds")), @@ -198,7 +198,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } ExprDoc lhs = DefineVar(stmt->buffer_var, d->frames.back(), d); With f(d, stmt); - ExprDoc rhs = TIR("allocate")->Call(args, kwargs_keys, kwargs_values); + ExprDoc rhs = TIR(d, "allocate")->Call(args, kwargs_keys, kwargs_values); AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); return DoConciseScoping(lhs, rhs, &(*f)->stmts, concise); }); @@ -277,7 +277,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) args.push_back(data_doc); args.push_back(LiteralDoc::DataType(stmt->dtype, stmt_p->Attr("dtype"))); args.push_back(d->AsDoc(stmt->extents, stmt_p->Attr("extents"))); - ExprDoc rhs = TIR("allocate_const")->Call(args, kwargs_keys, kwargs_values); + ExprDoc rhs = TIR(d, "allocate_const")->Call(args, kwargs_keys, kwargs_values); With f(d, stmt); ExprDoc lhs = DefineVar(stmt->buffer_var, *f, d); AsDocBody(stmt->body, stmt_p->Attr("body"), f->get(), d); @@ -310,7 +310,7 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalAsDoc(stmt->condition, p->Attr("condition"))); } - return TIR("realize")->Call(args, kwargs_keys, kwargs_values); + return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) @@ -351,12 +351,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) DefineVar(iter_var->var, f, d); f->stmts.push_back( AssignDoc(d->AsDoc(iter_var->var, iter_var_p->Attr("var")), - TIR("env_thread") + TIR(d, "env_thread") ->Call({LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"))}), // NullOpt)); } - rhs = TIR("launch_thread") + rhs = TIR(d, "launch_thread") ->Call({ d->AsDoc(iter_var->var, stmt_p->Attr("node")), d->AsDoc(stmt->value, stmt_p->Attr("value")), @@ -364,7 +364,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } } if (!rhs.defined()) { - rhs = TIR("attr")->Call({ + rhs = TIR(d, "attr")->Call({ d->AsDoc(stmt->node, stmt_p->Attr("node")), LiteralDoc::Str(stmt->attr_key, stmt_p->Attr("attr_key")), d->AsDoc(stmt->value, stmt_p->Attr("value")), diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index e1ffe135229e..88094ee816ca 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -20,7 +20,6 @@ #define TVM_SCRIPT_PRINTER_TIR_UTILS_H_ #include -#include #include #include #include @@ -74,7 +73,9 @@ class TIRFrame : public Frame { }; /*! \brief Creates the TIR common prefix, which is by default `T` */ -inline ExprDoc TIR(const String& attr) { return IdDoc(Default::Prefix("tir"))->Attr(attr); } +inline ExprDoc TIR(const IRDocsifier& d, const String& attr) { + return IdDoc(d->cfg->tir_prefix)->Attr(attr); +} /*! * \brief Defines a variable in the IRDocsifier at the given frame, @@ -187,14 +188,12 @@ inline TIRFrame MakeDispatchFrame(const IRDocsifier& d, const ObjectRef& root, } /*! \brief Redirected method for the ReprPrinter */ -inline void ReprPrintTIR(const ObjectRef& obj, ReprPrinter* p) { - IRDocsifier d; +inline std::string ReprPrintTIR(const ObjectRef& obj, const PrinterConfig& cfg) { + IRDocsifier d(cfg); With f(MakeDispatchFrame(d, obj, ObjectRef(nullptr))); - try { - p->stream << DocToPythonScript(Docsify(obj, d, *f)); - } catch (const tvm::Error& e) { - HandleUnsupportedFallback(e, obj, p); - } + std::ostringstream oss; + oss << Docsify(obj, d, *f, cfg); + return oss.str(); } /*! diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 9f9a7d8299c4..5161f1f9a268 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -20,13 +20,6 @@ #define TVM_SCRIPT_PRINTER_UTILS_H_ #include -#include -#include -#include -#include -#include -#include -#include #include #include @@ -37,13 +30,26 @@ namespace tvm { namespace script { namespace printer { -#define TVM_SCRIPT_REPR(ObjectType, Method) \ - TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch(Method); +#define TVM_SCRIPT_REPR(ObjectType, Method) \ + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) \ + .set_dispatch(RedirectedReprPrinterMethod); \ + TVM_STATIC_IR_FUNCTOR(TVMScriptPrinter, vtable).set_dispatch(Method); -inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f) { +inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { + try { + p->stream << TVMScriptPrinter::Script(obj, NullOpt); + } catch (const tvm::Error& e) { + LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" + << e.what(); + p->stream << AsLegacyRepr(obj); + } +} + +inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f, + const PrinterConfig& cfg) { Doc doc = d->AsDoc(obj, ObjectPath::Root()); if (const auto* expr_doc = doc.as()) { - if (!Default::VerboseExpr()) { + if (!cfg->verbose_expr) { f->stmts.clear(); } f->stmts.push_back(ExprStmtDoc(GetRef(expr_doc))); @@ -56,14 +62,7 @@ inline StmtBlockDoc Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fr } else { LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); } - return StmtBlockDoc(f->stmts); -} - -inline void HandleUnsupportedFallback(const tvm::Error& error, const ObjectRef& obj, - ReprPrinter* p) { - LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" - << error.what(); - p->stream << AsLegacyRepr(obj); + return DocToPythonScript(StmtBlockDoc(f->stmts), cfg); } } // namespace printer diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index c73ae291930c..71da86bff763 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -15,31 +15,15 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring -from contextlib import contextmanager - +import tvm.testing from tvm import ir, tir from tvm.ir import Range from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tir as T -from tvm.script.printer import default -import tvm.testing - - -@contextmanager -def verbose_expr(): - try: - default.verbose_expr(True) - yield - finally: - default.verbose_expr(False) def _assert_print(obj, expected): - with verbose_expr(): - if isinstance(obj, (tir.PrimFunc, tir.PrimExpr, tir.Stmt)): - assert obj.script().strip() == expected.strip() - assert str(obj).strip() == expected.strip() - assert repr(obj).strip() == expected.strip() + assert obj.script(verbose_expr=True).strip() == expected.strip() def test_prim_func():