From e6f992e74534516084e77bea52985f1a37150bc9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 9 Oct 2020 01:18:15 -0700 Subject: [PATCH 01/32] Add initial boilerplate for Rust diagnostic interface. --- python/tvm/ir/diagnostics/__init__.py | 2 +- rust/tvm/src/ir/diagnostics.rs | 239 ++++++++++++++++++++++++++ rust/tvm/src/ir/mod.rs | 1 + 3 files changed, 241 insertions(+), 1 deletion(-) create mode 100644 rust/tvm/src/ir/diagnostics.rs diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 6503743aaa51..0ad2a7aa6bfd 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -37,7 +37,7 @@ def get_renderer(): """ return _ffi_api.GetRenderer() - +@tvm.register_func("diagnostics.override_renderer") def override_renderer(render_func): """ Sets a custom renderer for diagnostics. diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs new file mode 100644 index 000000000000..799a10c71b00 --- /dev/null +++ b/rust/tvm/src/ir/diagnostics.rs @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/// The diagnostic interface to TVM, used for reporting and rendering +/// diagnostic information by the compiler. This module exposes +/// three key abstractions: a Diagnostic, the DiagnosticContext, +/// and the DiagnosticRenderer. + +use tvm_macros::{Object, external}; +use super::module::IRModule; +use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString}; +use crate::runtime::object::{Object, ObjectRef}; +use crate::runtime::function::Result; +use super::span::Span; + +type SourceName = ObjectRef; + +/// The diagnostic level, controls the printing of the message. +#[repr(C)] +pub enum DiagnosticLevel { + Bug = 10, + Error = 20, + Warning = 30, + Note = 40, + Help = 50, +} + +/// A compiler diagnostic. +#[repr(C)] +#[derive(Object)] +#[ref_name = "Diagnostic"] +#[type_key = "Diagnostic"] +pub struct DiagnosticNode { + pub base: Object, + /// The level. + pub level: DiagnosticLevel, + /// The span at which to report an error. + pub span: Span, + /// The diagnostic message. + pub message: TString, +} + +impl Diagnostic { + pub fn new(level: DiagnosticLevel, span: Span, message: TString) { + todo!() + } + + pub fn bug(span: Span) -> DiagnosticBuilder { + todo!() + } + + pub fn error(span: Span) -> DiagnosticBuilder { + todo!() + } + + pub fn warning(span: Span) -> DiagnosticBuilder { + todo!() + } + + pub fn note(span: Span) -> DiagnosticBuilder { + todo!() + } + + pub fn help(span: Span) -> DiagnosticBuilder { + todo!() + } +} + +/// A wrapper around std::stringstream to build a diagnostic. +pub struct DiagnosticBuilder { + /// The level. + pub level: DiagnosticLevel, + + /// The source name. + pub source_name: SourceName, + + /// The span of the diagnostic. + pub span: Span, +} + +// /*! \brief Display diagnostics in a given display format. +// * +// * A diagnostic renderer is responsible for converting the +// * raw diagnostics into consumable output. +// * +// * For example the terminal renderer will render a sequence +// * of compiler diagnostics to std::out and std::err in +// * a human readable form. +// */ +// class DiagnosticRendererNode : public Object { +// public: +// TypedPackedFunc renderer; + +// // override attr visitor +// void VisitAttrs(AttrVisitor* v) {} + +// static constexpr const char* _type_key = "DiagnosticRenderer"; +// TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); +// }; + +// class DiagnosticRenderer : public ObjectRef { +// public: +// TVM_DLL DiagnosticRenderer(TypedPackedFunc render); +// TVM_DLL DiagnosticRenderer() +// : DiagnosticRenderer(TypedPackedFunc()) {} + +// void Render(const DiagnosticContext& ctx); + +// DiagnosticRendererNode* operator->() { +// CHECK(get() != nullptr); +// return static_cast(get_mutable()); +// } + +// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode); +// }; + +// @tvm._ffi.register_object("DiagnosticRenderer") +// class DiagnosticRenderer(Object): +// """ +// A diagnostic renderer, which given a diagnostic context produces a "rendered" +// form of the diagnostics for either human or computer consumption. +// """ + +// def __init__(self, render_func): +// self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func) + +// def render(self, ctx): +// """ +// Render the provided context. + +// Params +// ------ +// ctx: DiagnosticContext +// The diagnostic context to render. +// """ +// return _ffi_api.DiagnosticRendererRender(self, ctx +pub type DiagnosticRenderer = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "DiagnosticContext"] +#[type_key = "DiagnosticContext"] +/// A diagnostic context for recording errors against a source file. +pub struct DiagnosticContextNode { + // The base type. + pub base: Object, + + /// The Module to report against. + pub module: IRModule, + + /// The set of diagnostics to report. + pub diagnostics: Array, + + /// The renderer set for the context. + pub renderer: DiagnosticRenderer, +} + +// Get the the diagnostic renderer. +external! { + #[name("node.ArrayGetItem")] + fn get_renderer() -> DiagnosticRenderer; + + #[name("diagnostics.DiagnosticRenderer")] + fn diagnostic_renderer(func: Function) -> DiagnosticRenderer; + + #[name("diagnostics.Emit")] + fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> (); + + #[name("diagnostics.DiagnosticContextRender")] + fn diagnostic_context_render(ctx: DiagnosticContext) -> (); +} + +/// A diagnostic context which records active errors +/// and contains a renderer. +impl DiagnosticContext { + pub fn new(module: IRModule, renderer: DiagnosticRenderer) { + todo!() + } + + pub fn default(module: IRModule) -> DiagnosticContext { + todo!() + } + + /// Emit a diagnostic. + pub fn emit(&mut self, diagnostic: Diagnostic) -> Result<()> { + emit(self.clone(), diagnostic) + } + + /// Render the errors and raise a DiagnosticError exception. + pub fn render(&mut self) -> Result<()> { + diagnostic_context_render(self.clone()) + } + + /// Emit a diagnostic and then immediately attempt to render all errors. + pub fn emit_fatal(&mut self, diagnostic: Diagnostic) -> Result<()> { + self.emit(diagnostic)?; + self.render()?; + Ok(()) + } +} + +// Sets a custom renderer for diagnostics. + +// Params +// ------ +// render_func: Option[Callable[[DiagnosticContext], None]] +// If the render_func is None it will remove the current custom renderer +// and return to default behavior. +fn override_renderer(opt_func: Option) -> Result<()> +where F: Fn(DiagnosticContext) -> () +{ + todo!() + // fn () + // diagnostic_renderer(func) + // if render_func: + + // def _render_factory(): + // return DiagnosticRenderer(render_func) + + // register_func("diagnostics.OverrideRenderer", _render_factory, override=True) + // else: + // _ffi_api.ClearRenderer() +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 126d0faccabb..8450bd790445 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -20,6 +20,7 @@ pub mod arith; pub mod attrs; pub mod expr; +pub mod diagnostics; pub mod function; pub mod module; pub mod op; From 4fe35b0444ac106c3bb6f91338e026f8c6d52f95 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 9 Oct 2020 23:59:51 -0700 Subject: [PATCH 02/32] Codespan example almost working --- rust/tvm-sys/src/packed_func.rs | 1 + rust/tvm/Cargo.toml | 2 + rust/tvm/src/bin/tyck.rs | 24 ++++++ rust/tvm/src/ir/diagnostics.rs | 121 +++++++++++++++++++++++-------- rust/tvm/src/ir/relay/visitor.rs | 24 ++++++ 5 files changed, 143 insertions(+), 29 deletions(-) create mode 100644 rust/tvm/src/bin/tyck.rs create mode 100644 rust/tvm/src/ir/relay/visitor.rs diff --git a/rust/tvm-sys/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs index f7b289c59675..7b8d5296d641 100644 --- a/rust/tvm-sys/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -101,6 +101,7 @@ macro_rules! TVMPODValue { TVMArgTypeCode_kTVMOpaqueHandle => Handle($value.v_handle), TVMArgTypeCode_kTVMDLTensorHandle => ArrayHandle($value.v_handle as TVMArrayHandle), TVMArgTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle), + TVMArgTypeCode_kTVMObjectRValueRefArg => ObjectHandle(*($value.v_handle as *mut *mut c_void)), TVMArgTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle), TVMArgTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle), TVMArgTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle), diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 55fc1790604e..71a4b937460f 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -41,6 +41,8 @@ paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" pyo3 = { version = "0.11.1", optional = true } +codespan-reporting = "0.9.5" +structopt = { version = "0.3" } [features] default = ["python"] diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs new file mode 100644 index 000000000000..9300412585a7 --- /dev/null +++ b/rust/tvm/src/bin/tyck.rs @@ -0,0 +1,24 @@ +use std::path::PathBuf; + +use anyhow::Result; +use structopt::StructOpt; + +use tvm::ir::diagnostics::codespan; +use tvm::ir::IRModule; + + +#[derive(Debug, StructOpt)] +#[structopt(name = "tyck", about = "Parse and type check a Relay program.")] +struct Opt { + /// Input file + #[structopt(parse(from_os_str))] + input: PathBuf, +} + +fn main() -> Result<()> { + codespan::init().expect("Rust based diagnostics"); + let opt = Opt::from_args(); + println!("{:?}", &opt); + let file = IRModule::parse_file(opt.input)?; + Ok(()) +} diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs index 799a10c71b00..e434d3f45767 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics.rs @@ -24,13 +24,31 @@ use tvm_macros::{Object, external}; use super::module::IRModule; -use crate::runtime::{function::{Function, Typed}, array::Array, string::String as TString}; -use crate::runtime::object::{Object, ObjectRef}; +use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString}; +use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; use crate::runtime::function::Result; use super::span::Span; type SourceName = ObjectRef; +// Get the the diagnostic renderer. +external! { + #[name("node.ArrayGetItem")] + fn get_renderer() -> DiagnosticRenderer; + + #[name("diagnostics.DiagnosticRenderer")] + fn diagnostic_renderer(func: Function) -> DiagnosticRenderer; + + #[name("diagnostics.Emit")] + fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> (); + + #[name("diagnostics.DiagnosticContextRender")] + fn diagnostic_context_render(ctx: DiagnosticContext) -> (); + + #[name("diagnostics.ClearRenderer")] + fn clear_renderer() -> (); +} + /// The diagnostic level, controls the printing of the message. #[repr(C)] pub enum DiagnosticLevel { @@ -171,26 +189,20 @@ pub struct DiagnosticContextNode { pub renderer: DiagnosticRenderer, } -// Get the the diagnostic renderer. -external! { - #[name("node.ArrayGetItem")] - fn get_renderer() -> DiagnosticRenderer; - - #[name("diagnostics.DiagnosticRenderer")] - fn diagnostic_renderer(func: Function) -> DiagnosticRenderer; - - #[name("diagnostics.Emit")] - fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> (); - - #[name("diagnostics.DiagnosticContextRender")] - fn diagnostic_context_render(ctx: DiagnosticContext) -> (); -} - /// A diagnostic context which records active errors /// and contains a renderer. impl DiagnosticContext { - pub fn new(module: IRModule, renderer: DiagnosticRenderer) { - todo!() + pub fn new(module: IRModule, render_func: F) -> DiagnosticContext + where F: Fn(DiagnosticContext) -> () + 'static + { + let renderer = diagnostic_renderer(render_func.to_function()).unwrap(); + let node = DiagnosticContextNode { + base: Object::base_object::(), + module, + diagnostics: Array::from_vec(vec![]).unwrap(), + renderer, + }; + DiagnosticContext(Some(ObjectPtr::new(node))) } pub fn default(module: IRModule) -> DiagnosticContext { @@ -223,17 +235,68 @@ impl DiagnosticContext { // If the render_func is None it will remove the current custom renderer // and return to default behavior. fn override_renderer(opt_func: Option) -> Result<()> -where F: Fn(DiagnosticContext) -> () +where F: Fn(DiagnosticContext) -> () + 'static { - todo!() - // fn () - // diagnostic_renderer(func) - // if render_func: - // def _render_factory(): - // return DiagnosticRenderer(render_func) + match opt_func { + None => clear_renderer(), + Some(func) => { + let func = func.to_function(); + let render_factory = move || { + diagnostic_renderer(func.clone()).unwrap() + }; + + function::register_override( + render_factory, + "diagnostics.OverrideRenderer", + true)?; + + Ok(()) + } + } +} + +pub mod codespan { + use super::*; + + use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; + use codespan_reporting::files::SimpleFiles; + use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; + + pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic { + let severity = match diag.level { + DiagnosticLevel::Error => Severity::Error, + DiagnosticLevel::Warning => Severity::Warning, + DiagnosticLevel::Note => Severity::Note, + DiagnosticLevel::Help => Severity::Help, + DiagnosticLevel::Bug => Severity::Bug, + }; + + let file_id = "foo".into(); // diag.span.source_name; + + let message: String = diag.message.as_str().unwrap().into(); + let inner_message: String = "expected `String`, found `Nat`".into(); + let diagnostic = CDiagnostic::new(severity) + .with_message(message) + .with_code("EXXX") + .with_labels(vec![ + Label::primary(file_id, 328..331).with_message(inner_message), + ]); + + diagnostic + } + + pub fn init() -> Result<()> { + let mut files: SimpleFiles = SimpleFiles::new(); + let render_fn = move |diag_ctx: DiagnosticContext| { + // let source_map = diag_ctx.module.source_map; + for diagnostic in diag_ctx.diagnostics { + + } + panic!("render_fn"); + }; - // register_func("diagnostics.OverrideRenderer", _render_factory, override=True) - // else: - // _ffi_api.ClearRenderer() + override_renderer(Some(render_fn))?; + Ok(()) + } } diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs new file mode 100644 index 000000000000..31661742c4fb --- /dev/null +++ b/rust/tvm/src/ir/relay/visitor.rs @@ -0,0 +1,24 @@ +use super::Expr; + +macro_rules! downcast_match { + ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { + $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ + { $default } + } +} + +trait ExprVisitorMut { + fn visit(&mut self, expr: Expr) { + downcast_match!(expr; { + else => { + panic!() + } + }); + } + + fn visit(&mut self, expr: Expr); +} + +// trait ExprTransformer { +// fn +// } From dfacf9e51d69322c7d287f083ccae51dc4346173 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 13 Oct 2020 11:04:37 -0700 Subject: [PATCH 03/32] WIP --- rust/tvm/src/ir/diagnostics.rs | 78 ++++++++++++++++------------------ 1 file changed, 37 insertions(+), 41 deletions(-) diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs index e434d3f45767..d30618593e1e 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics.rs @@ -24,7 +24,7 @@ use tvm_macros::{Object, external}; use super::module::IRModule; -use crate::runtime::{function::{self, Function, ToFunction, Typed}, array::Array, string::String as TString}; +use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString}; use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; use crate::runtime::function::Result; use super::span::Span; @@ -121,42 +121,19 @@ pub struct DiagnosticBuilder { // * of compiler diagnostics to std::out and std::err in // * a human readable form. // */ -// class DiagnosticRendererNode : public Object { -// public: -// TypedPackedFunc renderer; - -// // override attr visitor -// void VisitAttrs(AttrVisitor* v) {} - -// static constexpr const char* _type_key = "DiagnosticRenderer"; -// TVM_DECLARE_FINAL_OBJECT_INFO(DiagnosticRendererNode, Object); -// }; - -// class DiagnosticRenderer : public ObjectRef { -// public: -// TVM_DLL DiagnosticRenderer(TypedPackedFunc render); -// TVM_DLL DiagnosticRenderer() -// : DiagnosticRenderer(TypedPackedFunc()) {} - -// void Render(const DiagnosticContext& ctx); - -// DiagnosticRendererNode* operator->() { -// CHECK(get() != nullptr); -// return static_cast(get_mutable()); -// } - -// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DiagnosticRenderer, ObjectRef, DiagnosticRendererNode); -// }; - -// @tvm._ffi.register_object("DiagnosticRenderer") -// class DiagnosticRenderer(Object): -// """ -// A diagnostic renderer, which given a diagnostic context produces a "rendered" -// form of the diagnostics for either human or computer consumption. -// """ +#[repr(C)] +#[derive(Object)] +#[ref_name = "DiagnosticRenderer"] +#[type_key = "DiagnosticRenderer"] +/// A diagnostic renderer, which given a diagnostic context produces a "rendered" +/// form of the diagnostics for either human or computer consumption. +pub struct DiagnosticRendererNode { + /// The base type. + pub base: Object, + // TODO(@jroesch): we can't easily exposed packed functions due to + // memory layout +} -// def __init__(self, render_func): -// self.__init_handle_by_constructor__(_ffi_api.DiagnosticRenderer, render_func) // def render(self, ctx): // """ @@ -168,7 +145,6 @@ pub struct DiagnosticBuilder { // The diagnostic context to render. // """ // return _ffi_api.DiagnosticRendererRender(self, ctx -pub type DiagnosticRenderer = ObjectRef; #[repr(C)] #[derive(Object)] @@ -227,8 +203,7 @@ impl DiagnosticContext { } } -// Sets a custom renderer for diagnostics. - +// Override the global diagnostics renderer. // Params // ------ // render_func: Option[Callable[[DiagnosticContext], None]] @@ -263,6 +238,27 @@ pub mod codespan { use codespan_reporting::files::SimpleFiles; use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; + enum StartOrEnd { + Start, + End, + } + + struct SpanToBytes { + inner: HashMap { + file_id: FileId, + start_pos: usize, + end_pos: usize, + } + + // impl SpanToBytes { + // fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange { + + // } + // } + pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic { let severity = match diag.level { DiagnosticLevel::Error => Severity::Error, @@ -290,9 +286,9 @@ pub mod codespan { let mut files: SimpleFiles = SimpleFiles::new(); let render_fn = move |diag_ctx: DiagnosticContext| { // let source_map = diag_ctx.module.source_map; - for diagnostic in diag_ctx.diagnostics { + // for diagnostic in diag_ctx.diagnostics { - } + // } panic!("render_fn"); }; From e827660ed2fa6b241be13840ccf985c3d4147215 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 13 Oct 2020 15:25:56 -0700 Subject: [PATCH 04/32] Hacking on Rust inside of TVM --- CMakeLists.txt | 1 + cmake/modules/RustExt.cmake | 13 + rust/Cargo.toml | 1 + rust/compiler-ext/Cargo.toml | 13 + rust/compiler-ext/src/lib.rs | 7 + rust/tvm/src/ir/source_map.rs | 0 rust/tvm/test.rly | 2 + tests/python/relay/test_type_infer2.py | 419 +++++++++++++++++++++++++ 8 files changed, 456 insertions(+) create mode 100644 cmake/modules/RustExt.cmake create mode 100644 rust/compiler-ext/Cargo.toml create mode 100644 rust/compiler-ext/src/lib.rs create mode 100644 rust/tvm/src/ir/source_map.rs create mode 100644 rust/tvm/test.rly create mode 100644 tests/python/relay/test_type_infer2.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 6ce0cdc129db..decdd56537d5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,6 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF) tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) +tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake new file mode 100644 index 000000000000..45e46bd1bd62 --- /dev/null +++ b/cmake/modules/RustExt.cmake @@ -0,0 +1,13 @@ +if(USE_RUST_EXT) + set(RUST_SRC_DIR "rust") + set(CARGO_OUT_DIR "rust/target" + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib") + + add_custom_command( + OUTPUT "${COMPILER_EXT_PATH}" + COMMAND cargo build --release + MAIN_DEPENDENCY "${RUST_SRC_DIR}" + WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext") + + target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) +endif(USE_RUST_EXT) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 9935ce7c8b9f..6e14c2b02c2b 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -28,4 +28,5 @@ members = [ "tvm-graph-rt/tests/test_tvm_dso", "tvm-graph-rt/tests/test_wasm32", "tvm-graph-rt/tests/test_nn", + "compiler-ext", ] diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml new file mode 100644 index 000000000000..76d10eb2e49c --- /dev/null +++ b/rust/compiler-ext/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "compiler-ext" +version = "0.1.0" +authors = ["Jared Roesch "] +edition = "2018" +# TODO(@jroesch): would be cool to figure out how to statically link instead. + +[lib] +crate-type = ["cdylib"] + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs new file mode 100644 index 000000000000..31e1bb209f98 --- /dev/null +++ b/rust/compiler-ext/src/lib.rs @@ -0,0 +1,7 @@ +#[cfg(test)] +mod tests { + #[test] + fn it_works() { + assert_eq!(2 + 2, 4); + } +} diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly new file mode 100644 index 000000000000..d8b7c6960fef --- /dev/null +++ b/rust/tvm/test.rly @@ -0,0 +1,2 @@ +#[version = "0.0.5"] +fn @main(%x: int32) -> float32 { %x } diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py new file mode 100644 index 000000000000..6758d96773a2 --- /dev/null +++ b/tests/python/relay/test_type_infer2.py @@ -0,0 +1,419 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test that type checker correcly computes types + for expressions. +""" +import pytest +import tvm + +from tvm import IRModule, te, relay, parser +from tvm.relay import op, transform, analysis +from tvm.relay import Any + + +def infer_mod(mod, annotate_spans=True): + if annotate_spans: + mod = relay.transform.AnnotateSpans()(mod) + + mod = transform.InferType()(mod) + return mod + + +def infer_expr(expr, annotate_spans=True): + mod = IRModule.from_expr(expr) + mod = infer_mod(mod, annotate_spans) + mod = transform.InferType()(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def assert_has_type(expr, typ, mod=None): + if not mod: + mod = tvm.IRModule({}) + + mod["main"] = expr + mod = infer_mod(mod) + checked_expr = mod["main"] + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ)) + + +def initialize_box_adt(mod): + # initializes simple ADT for tests + box = relay.GlobalTypeVar("box") + tv = relay.TypeVar("tv") + constructor = relay.Constructor("constructor", [tv], box) + data = relay.TypeData(box, [tv], [constructor]) + mod[box] = data + return box, constructor + + +def test_monomorphic_let(): + "Program: let %x = 1; %x" + # TODO(@jroesch): this seems whack. + sb = relay.ScopeBuilder() + x = relay.var("x", dtype="float64", shape=()) + x = sb.let("x", relay.const(1.0, "float64")) + sb.ret(x) + xchecked = infer_expr(sb.get()) + assert xchecked.checked_type == relay.scalar_type("float64") + + +def test_single_op(): + "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }" + x = relay.var("x", shape=[]) + func = relay.Function([x], op.log(x)) + ttype = relay.TensorType([], dtype="float32") + assert_has_type(func, relay.FuncType([ttype], ttype)) + + +def test_add_broadcast_op(): + """ + Program: + fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) + -> Tensor[(5, 10, 4), float32] { + %x + %y + } + """ + x = relay.var("x", shape=(10, 4)) + y = relay.var("y", shape=(5, 10, 1)) + z = x + y + func = relay.Function([x, y], z) + t1 = relay.TensorType((10, 4), "float32") + t2 = relay.TensorType((5, 10, 1), "float32") + t3 = relay.TensorType((5, 10, 4), "float32") + expected_ty = relay.FuncType([t1, t2], t3) + assert_has_type(func, expected_ty) + + +def test_dual_op(): + """Program: + fn (%x : Tensor[(10, 10), float32]) { + let %t1 = log(x); + let %t2 = add(%t1, %x); + %t1 + } + """ + tp = relay.TensorType((10, 10), "float32") + x = relay.var("x", tp) + sb = relay.ScopeBuilder() + t1 = sb.let("t1", relay.log(x)) + t2 = sb.let("t2", relay.add(t1, x)) + sb.ret(t2) + f = relay.Function([x], sb.get()) + fchecked = infer_expr(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) + + +def test_decl(): + """Program: + def @f(%x : Tensor[(10, 10), float32]) { + log(%x) + } + """ + tp = relay.TensorType((10, 10)) + x = relay.var("x", tp) + f = relay.Function([x], relay.log(x)) + fchecked = infer_expr(f) + assert fchecked.checked_type == relay.FuncType([tp], tp) + + +def test_recursion(): + """ + Program: + def @f(%n: int32, %data: float32) -> float32 { + if (%n == 0) { + %data + } else { + @f(%n - 1, log(%data)) + } + } + """ + sb = relay.ScopeBuilder() + f = relay.GlobalVar("f") + ti32 = relay.scalar_type("int32") + tf32 = relay.scalar_type("float32") + n = relay.var("n", ti32) + data = relay.var("data", tf32) + + with sb.if_scope(relay.equal(n, relay.const(0, ti32))): + sb.ret(data) + with sb.else_scope(): + sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) + mod = tvm.IRModule() + mod[f] = relay.Function([n, data], sb.get()) + mod = infer_mod(mod) + assert "@f(%1, %2)" in mod.astext() + assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32) + + +def test_incomplete_call(): + tt = relay.scalar_type("int32") + x = relay.var("x", tt) + f = relay.var("f") + func = relay.Function([x, f], relay.Call(f, [x]), tt) + + ft = infer_expr(func) + f_type = relay.FuncType([tt], tt) + assert ft.checked_type == relay.FuncType([tt, f_type], tt) + + +def test_higher_order_argument(): + a = relay.TypeVar("a") + x = relay.Var("x", a) + id_func = relay.Function([x], x, a, [a]) + + b = relay.TypeVar("b") + f = relay.Var("f", relay.FuncType([b], b)) + y = relay.Var("y", b) + ho_func = relay.Function([f, y], f(y), b, [b]) + + # id func should be an acceptable argument to the higher-order + # function even though id_func takes a type parameter + ho_call = ho_func(id_func, relay.const(0, "int32")) + + hc = infer_expr(ho_call) + expected = relay.scalar_type("int32") + assert hc.checked_type == expected + + +def test_higher_order_return(): + a = relay.TypeVar("a") + x = relay.Var("x", a) + id_func = relay.Function([x], x, a, [a]) + + b = relay.TypeVar("b") + nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b]) + + ft = infer_expr(nested_id) + assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b]) + + +def test_higher_order_nested(): + a = relay.TypeVar("a") + x = relay.Var("x", a) + id_func = relay.Function([x], x, a, [a]) + + choice_t = relay.FuncType([], relay.scalar_type("bool")) + f = relay.Var("f", choice_t) + + b = relay.TypeVar("b") + z = relay.Var("z") + top = relay.Function( + [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b] + ) + + expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b]) + ft = infer_expr(top) + assert ft.checked_type == expected + + +def test_tuple(): + tp = relay.TensorType((10,)) + x = relay.var("x", tp) + res = relay.Tuple([x, x]) + assert infer_expr(res).checked_type == relay.TupleType([tp, tp]) + + +def test_ref(): + x = relay.var("x", "float32") + y = relay.var("y", "float32") + r = relay.RefCreate(x) + st = relay.scalar_type("float32") + assert infer_expr(r).checked_type == relay.RefType(st) + g = relay.RefRead(r) + assert infer_expr(g).checked_type == st + w = relay.RefWrite(r, y) + assert infer_expr(w).checked_type == relay.TupleType([]) + + +def test_free_expr(): + x = relay.var("x", "float32") + y = relay.add(x, x) + yy = infer_expr(y, annotate_spans=False) + assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True) + assert yy.checked_type == relay.scalar_type("float32") + assert x.vid.same_as(yy.args[0].vid) + + +def test_type_args(): + x = relay.var("x", shape=(10, 10)) + y = relay.var("y", shape=(1, 10)) + z = relay.add(x, y) + ty_z = infer_expr(z) + ty_args = ty_z.type_args + assert len(ty_args) == 2 + assert ty_args[0].dtype == "float32" + assert ty_args[1].dtype == "float32" + sh1 = ty_args[0].shape + sh2 = ty_args[1].shape + assert sh1[0].value == 10 + assert sh1[1].value == 10 + assert sh2[0].value == 1 + assert sh2[1].value == 10 + + +def test_global_var_recursion(): + mod = tvm.IRModule({}) + gv = relay.GlobalVar("main") + x = relay.var("x", shape=[]) + tt = relay.scalar_type("float32") + + func = relay.Function([x], relay.Call(gv, [x]), tt) + mod[gv] = func + mod = infer_mod(mod) + func_ty = mod["main"].checked_type + + assert func_ty == relay.FuncType([tt], tt) + + +def test_equal(): + i = relay.var("i", shape=[], dtype="int32") + eq = op.equal(i, relay.const(0, dtype="int32")) + func = relay.Function([i], eq) + ft = infer_expr(func) + expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool")) + assert ft.checked_type == expected + + assert ft.checked_type == relay.FuncType( + [relay.scalar_type("int32")], relay.scalar_type("bool") + ) + + +def test_constructor_type(): + mod = tvm.IRModule() + box, constructor = initialize_box_adt(mod) + + a = relay.TypeVar("a") + x = relay.Var("x", a) + func = relay.Function([x], constructor(x), box(a), [a]) + mod["main"] = func + mod = infer_mod(mod) + func_ty = mod["main"].checked_type + box = mod.get_global_type_var("box") + expected = relay.FuncType([a], box(a), [a]) + assert func_ty == expected + + +def test_constructor_call(): + mod = tvm.IRModule() + box, constructor = initialize_box_adt(mod) + + box_unit = constructor(relay.Tuple([])) + box_constant = constructor(relay.const(0, "float32")) + + func = relay.Function([], relay.Tuple([box_unit, box_constant])) + mod["main"] = func + mod = infer_mod(mod) + ret_type = mod["main"].checked_type.ret_type.fields + # NB(@jroesch): when we annotate spans the ast fragments before + # annotation the previous fragments will no longer be directly equal. + box = mod.get_global_type_var("box") + expected1 = box(relay.TupleType([])) + expected2 = box(relay.TensorType((), "float32")) + assert ret_type[0] == expected1 + assert ret_type[1] == expected2 + + +def test_adt_match(): + mod = tvm.IRModule() + box, constructor = initialize_box_adt(mod) + + v = relay.Var("v", relay.TensorType((), "float32")) + match = relay.Match( + constructor(relay.const(0, "float32")), + [ + relay.Clause( + relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([]) + ), + # redundant but shouldn't matter to typechecking + relay.Clause(relay.PatternWildcard(), relay.Tuple([])), + ], + ) + + func = relay.Function([], match) + mod["main"] = func + mod = infer_mod(mod) + actual = mod["main"].checked_type.ret_type + assert actual == relay.TupleType([]) + + +def test_adt_match_type_annotations(): + mod = tvm.IRModule() + box, constructor = initialize_box_adt(mod) + + # the only type annotation is inside the match pattern var + # but that should be enough info + tt = relay.TensorType((2, 2), "float32") + x = relay.Var("x") + mv = relay.Var("mv", tt) + match = relay.Match( + constructor(x), + [ + relay.Clause( + relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([]) + ) + ], + ) + + mod["main"] = relay.Function([x], match) + mod = infer_mod(mod) + ft = mod["main"].checked_type + assert ft == relay.FuncType([tt], relay.TupleType([])) + + +def test_let_polymorphism(): + id = relay.Var("id") + xt = relay.TypeVar("xt") + x = relay.Var("x", xt) + body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))]) + body = relay.Let(id, relay.Function([x], x, xt, [xt]), body) + body = infer_expr(body) + int32 = relay.TensorType((), "int32") + tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) + + +def test_if(): + choice_t = relay.FuncType([], relay.scalar_type("bool")) + f = relay.Var("f", choice_t) + true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32")) + false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32")) + top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch)) + ft = infer_expr(top) + tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32")) + + +def test_type_arg_infer(): + code = """ +#[version = "0.0.5"] +def @id[A](%x: A) -> A { + %x +} +def @main(%f: float32) -> float32 { + @id(%f) +} +""" + mod = tvm.parser.fromtext(code) + mod = transform.InferType()(mod) + tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")]) + + +if __name__ == "__main__": + import sys + + pytest.main(sys.argv) From a1a4f3e0304f5f2f303e7c0a727ce6e1f73b8246 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 13 Oct 2020 15:26:54 -0700 Subject: [PATCH 05/32] Borrow code from Egg --- rust/compiler-ext/src/lib.rs | 344 ++++++++++++++++++++++++++++++++++- 1 file changed, 337 insertions(+), 7 deletions(-) diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 31e1bb209f98..58bdd0cb29f8 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -1,7 +1,337 @@ -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + use std::os::raw::c_int; + use tvm::initialize; + use tvm::ir::{tir, PrimExpr}; + use tvm::runtime::function::register_override; + use tvm::runtime::map::Map; + use tvm::runtime::object::{IsObject, IsObjectRef}; + + use ordered_float::NotNan; + + mod interval; + mod math; + + use math::{BoundsMap, Expr, RecExpr}; + use tvm::ir::arith::ConstIntBound; + use tvm_rt::{ObjectRef, array::Array}; + + macro_rules! downcast_match { + ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { + $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ + { $default } + } + } + + #[derive(Default)] + struct VarMap { + vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>, + objs: Vec, + } + + impl VarMap { + // FIXME this should eventually do the right thing for TVM variables + // right now it depends on them having unique names + fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol { + let sym = egg::Symbol::from(var.name_hint.as_str().unwrap()); + for (_, sym2) in &self.vars { + if sym == *sym2 { + return sym; + } + } + + self.vars.push((var, sym)); + sym + } + + fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var { + for (v, sym2) in &self.vars { + if sym == *sym2 { + return v.clone(); + } + } + panic!("Should have found a var") + } + + fn push_obj(&mut self, obj: impl IsObjectRef) -> usize { + let i = self.objs.len(); + self.objs.push(obj.upcast()); + i + } + + fn get_obj(&self, i: usize) -> T { + self.objs[i].clone().downcast().expect("bad downcast") + } + } + + fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr { + fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id { + macro_rules! r { + ($e:expr) => { + build(vars, &$e, recexpr) + }; + } + + let dt = recexpr.add(Expr::DataType(p.datatype)); + let e = downcast_match!(p; { + tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]), + tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]), + tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]), + + tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]), + tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]), + tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]), + tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]), + + tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]), + tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]), + + tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]), + tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]), + + tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]), + tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]), + tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]), + tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]), + tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]), + tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]), + + tir::And => Expr::And([dt, r!(p.a), r!(p.b)]), + tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]), + tir::Not => Expr::Not([dt, r!(p.value)]), + + tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]), + + tir::Let => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); + Expr::Let([dt, sym, r!(p.value), r!(p.body)]) + } + tir::Var => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p))); + Expr::Var([dt, sym]) + } + tir::IntImm => { + let int = recexpr.add(Expr::Int(p.value)); + Expr::IntImm([dt, int]) + } + tir::FloatImm => { + let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap())); + Expr::FloatImm([dt, float]) + } + tir::Cast => Expr::Cast([dt, r!(p.value)]), + + tir::Call => { + let op = vars.push_obj(p.op.clone()); + let mut arg_ids = vec![dt]; + for i in 0..p.args.len() { + let arg: PrimExpr = p.args.get(i as isize).expect("array get fail"); + arg_ids.push(r!(arg)); + } + Expr::Call(op, arg_ids) + }, + tir::Load => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); + Expr::Load([dt, sym, r!(p.index), r!(p.predicate)]) + }, + else => { + println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap()); + Expr::Object(vars.push_obj(p.clone())) + } + }); + + recexpr.add(e) + } + + let mut recexpr = Default::default(); + build(vars, prim, &mut recexpr); + recexpr + } + + fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr { + fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr { + let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]); + let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap(); + let prim: PrimExpr = match nodes.last().expect("cannot be empty") { + Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }, + Expr::IntImm([dt, v]) => { + let value = nodes[usize::from(*v)].to_int().unwrap(); + tir::IntImm::new(get_dt(dt), value).upcast() + } + Expr::FloatImm([dt, v]) => { + let value = nodes[usize::from(*v)].to_float().unwrap(); + tir::FloatImm::new(get_dt(dt), value).upcast() + } + Expr::Let([dt, s, value, body]) => { + let var = match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }; + tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast() + } + Expr::Load([dt, s, value, body]) => { + let var = match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }; + tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast() + } + + Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(), + + Expr::Ramp([dt, a, b, c]) => { + let len = &nodes[usize::from(*c)]; + let i = len + .to_int() + .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len)); + tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast() + } + Expr::Broadcast([dt, val, lanes]) => { + let lanes = &nodes[usize::from(*lanes)]; + let lanes = lanes + .to_int() + .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes)); + println!("dt: {}", get_dt(dt)); + tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast() + } + + Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(), + Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(), + Expr::Call(expr, args) => { + let arg_exprs: Vec = args[1..].iter().map(go).collect(); + let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args"); + tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast() + } + + Expr::Object(i) => vars.get_obj(*i), + node => panic!("I don't know how to extract {:?}", node), + }; + assert_ne!(prim.datatype.bits(), 0); + assert_ne!(prim.datatype.lanes(), 0); + prim + } + build(vars, recexpr.as_ref()) + } + + fn run( + input: PrimExpr, + expected: Option, + map: Map, + ) -> Result { + use egg::{CostFunction, Extractor}; + + let mut bounds = BoundsMap::default(); + for (k, v) in map { + if let Ok(var) = k.downcast_clone::() { + let sym: egg::Symbol = var.name_hint.as_str().unwrap().into(); + bounds.insert(sym, (v.min_value, v.max_value)); + } else { + println!("Non var in bounds map: {}", tvm::ir::as_text(k)); + } + } + + let mut vars = VarMap::default(); + let expr = to_egg(&mut vars, &input); + let mut runner = math::default_runner(); + runner.egraph.analysis.bounds = bounds; + + let mut runner = runner.with_expr(&expr).run(&math::rules()); + // runner.print_report(); + let mut extractor = Extractor::new(&runner.egraph, math::CostFn); + let root = runner.egraph.find(runner.roots[0]); + let (cost, best) = extractor.find_best(root); + if let Some(expected) = expected { + let mut expected_vars = VarMap::default(); + let expected_expr = to_egg(&mut expected_vars, &expected); + let expected_root = runner.egraph.add_expr(&expected_expr); + if expected_root != root { + return Err(format!( + "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n", + expected_expr.pretty(40), + best.pretty(40) + )); + } + let expected_cost = math::CostFn.cost_rec(&expected_expr); + if expected_cost != cost { + let msg = format!( + "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n", + expected_cost, + expected_expr.pretty(40), + cost, + best.pretty(40) + ); + if cost < expected_cost { + println!("egg wins: {}", msg) + } else { + return Err(msg); + } + } + } + log::info!(" returning... {}", best.pretty(60)); + Ok(from_egg(&vars, &best)) + } + + fn simplify(prim: PrimExpr, map: Map) -> Result { + log::debug!("map: {:?}", map); + run(prim, None, map).map_err(tvm::Error::CallFailed) + } + + fn simplify_and_check( + prim: PrimExpr, + check: PrimExpr, + map: Map, + ) -> Result { + log::debug!("check map: {:?}", map); + run(prim, Some(check), map).map_err(tvm::Error::CallFailed) + } + + initialize!({ + let _ = env_logger::try_init(); + // NOTE this print prevents a segfault (on Linux) for now... + println!("Initializing simplifier... "); + register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier"); + register_override(simplify_and_check, "egg.simplify_and_check", true) + .expect("failed to initialize simplifier"); + log::debug!("done!"); + }); + \ No newline at end of file From 52177cc264955cef7693b6ade309c2d6c0c4fd1c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 15 Oct 2020 01:03:03 -0700 Subject: [PATCH 06/32] Update CMake and delete old API --- CMakeLists.txt | 1 + cmake/modules/RustExt.cmake | 25 ++- include/tvm/parser/source_map.h | 2 - rust/compiler-ext/Cargo.toml | 5 +- rust/compiler-ext/src/lib.rs | 334 ++------------------------------ rust/tvm-rt/Cargo.toml | 15 +- rust/tvm-sys/Cargo.toml | 1 + rust/tvm-sys/build.rs | 1 + rust/tvm/Cargo.toml | 22 ++- rust/tvm/src/bin/tyck.rs | 1 - rust/tvm/src/ir/diagnostics.rs | 42 ++-- rust/tvm/src/ir/mod.rs | 2 +- rust/tvm/src/ir/relay/mod.rs | 3 +- rust/tvm/src/ir/source_map.rs | 61 ++++++ rust/tvm/src/ir/span.rs | 95 +++++++-- src/ir/expr.cc | 11 ++ src/parser/source_map.cc | 11 -- 17 files changed, 237 insertions(+), 395 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index decdd56537d5..6fc19037e737 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -353,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) +include(cmake/modules/RustExt.cmake) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index 45e46bd1bd62..2ad726e94213 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -1,7 +1,14 @@ -if(USE_RUST_EXT) - set(RUST_SRC_DIR "rust") - set(CARGO_OUT_DIR "rust/target" - set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib") +if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF) + set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust") + set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target") + + if(USE_RUST_EXT STREQUAL "STATIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a") + elseif(USE_RUST_EXT STREQUAL "DYNAMIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so") + else() + message(FATAL_ERROR "invalid setting for RUST_EXT") + endif() add_custom_command( OUTPUT "${COMPILER_EXT_PATH}" @@ -9,5 +16,11 @@ if(USE_RUST_EXT) MAIN_DEPENDENCY "${RUST_SRC_DIR}" WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext") - target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) -endif(USE_RUST_EXT) + add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}") + + # TODO(@jroesch, @tkonolige): move this to CMake target + # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) + list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH}) + + add_definitions(-DRUST_COMPILER_EXT=1) +endif() diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 5595574265c6..5316c8bd2b33 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -103,8 +103,6 @@ class SourceMap : public ObjectRef { TVM_DLL SourceMap() : SourceMap({}) {} - TVM_DLL static SourceMap Global(); - void Add(const Source& source); SourceMapNode* operator->() { diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml index 76d10eb2e49c..3b13bc5200d9 100644 --- a/rust/compiler-ext/Cargo.toml +++ b/rust/compiler-ext/Cargo.toml @@ -6,8 +6,11 @@ edition = "2018" # TODO(@jroesch): would be cool to figure out how to statically link instead. [lib] -crate-type = ["cdylib"] +crate-type = ["staticlib", "cdylib"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tvm = { path = "../tvm", default-features = false, features = ["static-linking"] } +log = "*" +env_logger = "*" diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 58bdd0cb29f8..3e37d21e756a 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -17,321 +17,19 @@ * under the License. */ - use std::os::raw::c_int; - use tvm::initialize; - use tvm::ir::{tir, PrimExpr}; - use tvm::runtime::function::register_override; - use tvm::runtime::map::Map; - use tvm::runtime::object::{IsObject, IsObjectRef}; - - use ordered_float::NotNan; - - mod interval; - mod math; - - use math::{BoundsMap, Expr, RecExpr}; - use tvm::ir::arith::ConstIntBound; - use tvm_rt::{ObjectRef, array::Array}; - - macro_rules! downcast_match { - ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { - $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ - { $default } - } - } - - #[derive(Default)] - struct VarMap { - vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>, - objs: Vec, - } - - impl VarMap { - // FIXME this should eventually do the right thing for TVM variables - // right now it depends on them having unique names - fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol { - let sym = egg::Symbol::from(var.name_hint.as_str().unwrap()); - for (_, sym2) in &self.vars { - if sym == *sym2 { - return sym; - } - } - - self.vars.push((var, sym)); - sym - } - - fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var { - for (v, sym2) in &self.vars { - if sym == *sym2 { - return v.clone(); - } - } - panic!("Should have found a var") - } - - fn push_obj(&mut self, obj: impl IsObjectRef) -> usize { - let i = self.objs.len(); - self.objs.push(obj.upcast()); - i - } - - fn get_obj(&self, i: usize) -> T { - self.objs[i].clone().downcast().expect("bad downcast") - } - } - - fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr { - fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id { - macro_rules! r { - ($e:expr) => { - build(vars, &$e, recexpr) - }; - } - - let dt = recexpr.add(Expr::DataType(p.datatype)); - let e = downcast_match!(p; { - tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]), - tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]), - tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]), - - tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]), - tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]), - tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]), - tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]), - - tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]), - tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]), - - tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]), - tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]), - - tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]), - tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]), - tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]), - tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]), - tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]), - tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]), - - tir::And => Expr::And([dt, r!(p.a), r!(p.b)]), - tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]), - tir::Not => Expr::Not([dt, r!(p.value)]), - - tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]), - - tir::Let => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); - Expr::Let([dt, sym, r!(p.value), r!(p.body)]) - } - tir::Var => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p))); - Expr::Var([dt, sym]) - } - tir::IntImm => { - let int = recexpr.add(Expr::Int(p.value)); - Expr::IntImm([dt, int]) - } - tir::FloatImm => { - let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap())); - Expr::FloatImm([dt, float]) - } - tir::Cast => Expr::Cast([dt, r!(p.value)]), - - tir::Call => { - let op = vars.push_obj(p.op.clone()); - let mut arg_ids = vec![dt]; - for i in 0..p.args.len() { - let arg: PrimExpr = p.args.get(i as isize).expect("array get fail"); - arg_ids.push(r!(arg)); - } - Expr::Call(op, arg_ids) - }, - tir::Load => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); - Expr::Load([dt, sym, r!(p.index), r!(p.predicate)]) - }, - else => { - println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap()); - Expr::Object(vars.push_obj(p.clone())) - } - }); - - recexpr.add(e) - } - - let mut recexpr = Default::default(); - build(vars, prim, &mut recexpr); - recexpr - } - - fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr { - fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr { - let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]); - let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap(); - let prim: PrimExpr = match nodes.last().expect("cannot be empty") { - Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }, - Expr::IntImm([dt, v]) => { - let value = nodes[usize::from(*v)].to_int().unwrap(); - tir::IntImm::new(get_dt(dt), value).upcast() - } - Expr::FloatImm([dt, v]) => { - let value = nodes[usize::from(*v)].to_float().unwrap(); - tir::FloatImm::new(get_dt(dt), value).upcast() - } - Expr::Let([dt, s, value, body]) => { - let var = match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }; - tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast() - } - Expr::Load([dt, s, value, body]) => { - let var = match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }; - tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast() - } - - Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(), - - Expr::Ramp([dt, a, b, c]) => { - let len = &nodes[usize::from(*c)]; - let i = len - .to_int() - .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len)); - tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast() - } - Expr::Broadcast([dt, val, lanes]) => { - let lanes = &nodes[usize::from(*lanes)]; - let lanes = lanes - .to_int() - .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes)); - println!("dt: {}", get_dt(dt)); - tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast() - } - - Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(), - Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(), - Expr::Call(expr, args) => { - let arg_exprs: Vec = args[1..].iter().map(go).collect(); - let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args"); - tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast() - } - - Expr::Object(i) => vars.get_obj(*i), - node => panic!("I don't know how to extract {:?}", node), - }; - assert_ne!(prim.datatype.bits(), 0); - assert_ne!(prim.datatype.lanes(), 0); - prim - } - build(vars, recexpr.as_ref()) - } - - fn run( - input: PrimExpr, - expected: Option, - map: Map, - ) -> Result { - use egg::{CostFunction, Extractor}; - - let mut bounds = BoundsMap::default(); - for (k, v) in map { - if let Ok(var) = k.downcast_clone::() { - let sym: egg::Symbol = var.name_hint.as_str().unwrap().into(); - bounds.insert(sym, (v.min_value, v.max_value)); - } else { - println!("Non var in bounds map: {}", tvm::ir::as_text(k)); - } - } - - let mut vars = VarMap::default(); - let expr = to_egg(&mut vars, &input); - let mut runner = math::default_runner(); - runner.egraph.analysis.bounds = bounds; - - let mut runner = runner.with_expr(&expr).run(&math::rules()); - // runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, math::CostFn); - let root = runner.egraph.find(runner.roots[0]); - let (cost, best) = extractor.find_best(root); - if let Some(expected) = expected { - let mut expected_vars = VarMap::default(); - let expected_expr = to_egg(&mut expected_vars, &expected); - let expected_root = runner.egraph.add_expr(&expected_expr); - if expected_root != root { - return Err(format!( - "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n", - expected_expr.pretty(40), - best.pretty(40) - )); - } - let expected_cost = math::CostFn.cost_rec(&expected_expr); - if expected_cost != cost { - let msg = format!( - "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n", - expected_cost, - expected_expr.pretty(40), - cost, - best.pretty(40) - ); - if cost < expected_cost { - println!("egg wins: {}", msg) - } else { - return Err(msg); - } - } - } - log::info!(" returning... {}", best.pretty(60)); - Ok(from_egg(&vars, &best)) - } - - fn simplify(prim: PrimExpr, map: Map) -> Result { - log::debug!("map: {:?}", map); - run(prim, None, map).map_err(tvm::Error::CallFailed) - } - - fn simplify_and_check( - prim: PrimExpr, - check: PrimExpr, - map: Map, - ) -> Result { - log::debug!("check map: {:?}", map); - run(prim, Some(check), map).map_err(tvm::Error::CallFailed) - } - - initialize!({ - let _ = env_logger::try_init(); - // NOTE this print prevents a segfault (on Linux) for now... - println!("Initializing simplifier... "); - register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier"); - register_override(simplify_and_check, "egg.simplify_and_check", true) - .expect("failed to initialize simplifier"); - log::debug!("done!"); - }); - \ No newline at end of file +use env_logger; +use tvm; +use tvm::runtime::function::register_override; + +fn test_fn() -> Result<(), tvm::Error> { + println!("Hello from Rust!"); + Ok(()) +} + +#[no_mangle] +fn compiler_ext_initialize() -> i32 { + let _ = env_logger::try_init(); + register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier"); + log::debug!("done!"); + return 0; +} diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index acece5aeec48..9660943da50d 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -28,19 +28,26 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["dynamic-linking"] +dynamic-linking = ["tvm-sys/bindings"] +static-linking = [] +blas = ["ndarray/blas"] + [dependencies] thiserror = "^1.0" ndarray = "0.12" num-traits = "0.2" -tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } tvm-macros = { version = "0.1", path = "../tvm-macros" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" memoffset = "0.5.6" +[dependencies.tvm-sys] +version = "0.1" +default-features = false +path = "../tvm-sys/" + [dev-dependencies] anyhow = "^1.0" - -[features] -blas = ["ndarray/blas"] diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index 4e3fc98b4e75..c25a5bf3d957 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -23,6 +23,7 @@ license = "Apache-2.0" edition = "2018" [features] +default = ["bindings"] bindings = [] [dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 05806c0d5ce0..2d86c4b9b5ed 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -60,6 +60,7 @@ fn main() -> Result<()> { if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-lib=dylib=llvm-10"); println!("cargo:rustc-link-search={}/build", tvm_home); } diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 71a4b937460f..153a1950e46b 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -28,14 +28,24 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["python", "dynamic-linking"] +dynamic-linking = ["tvm-rt/dynamic-linking"] +static-linking = ["tvm-rt/static-linking"] +blas = ["ndarray/blas"] +python = ["pyo3"] + +[dependencies.tvm-rt] +version = "0.1" +default-features = false +path = "../tvm-rt/" + [dependencies] thiserror = "^1.0" anyhow = "^1.0" lazy_static = "1.1" ndarray = "0.12" num-traits = "0.2" -tvm-rt = { version = "0.1", path = "../tvm-rt/" } -tvm-sys = { version = "0.1", path = "../tvm-sys/" } tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" @@ -44,8 +54,6 @@ pyo3 = { version = "0.11.1", optional = true } codespan-reporting = "0.9.5" structopt = { version = "0.3" } -[features] -default = ["python"] - -blas = ["ndarray/blas"] -python = ["pyo3"] +[[bin]] +name = "tyck" +required-features = ["dynamic-linking"] diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index 9300412585a7..b869012e4e6c 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -6,7 +6,6 @@ use structopt::StructOpt; use tvm::ir::diagnostics::codespan; use tvm::ir::IRModule; - #[derive(Debug, StructOpt)] #[structopt(name = "tyck", about = "Parse and type check a Relay program.")] struct Opt { diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs index d30618593e1e..b76e43f5b6de 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics.rs @@ -17,17 +17,20 @@ * under the License. */ +use super::module::IRModule; +use super::span::Span; +use crate::runtime::function::Result; +use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; +use crate::runtime::{ + array::Array, + function::{self, Function, ToFunction}, + string::String as TString, +}; /// The diagnostic interface to TVM, used for reporting and rendering /// diagnostic information by the compiler. This module exposes /// three key abstractions: a Diagnostic, the DiagnosticContext, /// and the DiagnosticRenderer. - -use tvm_macros::{Object, external}; -use super::module::IRModule; -use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString}; -use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; -use crate::runtime::function::Result; -use super::span::Span; +use tvm_macros::{external, Object}; type SourceName = ObjectRef; @@ -134,7 +137,6 @@ pub struct DiagnosticRendererNode { // memory layout } - // def render(self, ctx): // """ // Render the provided context. @@ -169,7 +171,8 @@ pub struct DiagnosticContextNode { /// and contains a renderer. impl DiagnosticContext { pub fn new(module: IRModule, render_func: F) -> DiagnosticContext - where F: Fn(DiagnosticContext) -> () + 'static + where + F: Fn(DiagnosticContext) -> () + 'static, { let renderer = diagnostic_renderer(render_func.to_function()).unwrap(); let node = DiagnosticContextNode { @@ -210,21 +213,16 @@ impl DiagnosticContext { // If the render_func is None it will remove the current custom renderer // and return to default behavior. fn override_renderer(opt_func: Option) -> Result<()> -where F: Fn(DiagnosticContext) -> () + 'static +where + F: Fn(DiagnosticContext) -> () + 'static, { - match opt_func { None => clear_renderer(), Some(func) => { let func = func.to_function(); - let render_factory = move || { - diagnostic_renderer(func.clone()).unwrap() - }; + let render_factory = move || diagnostic_renderer(func.clone()).unwrap(); - function::register_override( - render_factory, - "diagnostics.OverrideRenderer", - true)?; + function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?; Ok(()) } @@ -243,9 +241,9 @@ pub mod codespan { End, } - struct SpanToBytes { - inner: HashMap { file_id: FileId, @@ -276,7 +274,7 @@ pub mod codespan { .with_message(message) .with_code("EXXX") .with_labels(vec![ - Label::primary(file_id, 328..331).with_message(inner_message), + Label::primary(file_id, 328..331).with_message(inner_message) ]); diagnostic diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 8450bd790445..401b6c289966 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -19,8 +19,8 @@ pub mod arith; pub mod attrs; -pub mod expr; pub mod diagnostics; +pub mod expr; pub mod function; pub mod module; pub mod op; diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index e539221d1db6..4b091285d245 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -28,6 +28,7 @@ use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::ty::{Type, TypeNode}; +use super::span::Span; use tvm_macros::Object; use tvm_rt::NDArray; @@ -51,7 +52,7 @@ impl ExprNode { span: ObjectRef::null(), checked_type: Type::from(TypeNode { base: Object::base_object::(), - span: ObjectRef::null(), + span: Span::empty(), }), } } diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index e69de29bb2d1..e6c037108cf0 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::map::Map; +use crate::runtime::object::Object; + +/// A program source in any language. +/// +/// Could represent the source from an ML framework or a source of an IRModule. +#[repr(C)] +#[derive(Object)] +#[type_key = "Source"] +#[ref_key = "Source"] +struct SourceNode { + pub base: Object, + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The raw source. */ + String source; + + /*! \brief A mapping of line breaks into the raw source. */ + std::vector> line_map; +} + + +// class Source : public ObjectRef { +// public: +// TVM_DLL Source(SourceName src_name, std::string source); +// TVM_DLL tvm::String GetLine(int line); + +// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); +// }; + + +/// A mapping from a unique source name to source fragments. +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceMap"] +#[ref_key = "SourceMap"] +struct SourceMapNode { + pub base: Object, + /// The source mapping. + pub source_map: Map, +} diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index d2e19a25a950..c54fd5143c12 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -1,22 +1,75 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::runtime::ObjectRef; - -pub type Span = ObjectRef; +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the + +* specific language governing permissions and limitations +* under the License. +*/ + +use crate::runtime::{ObjectRef, Object, String as TString}; +use tvm_macros::Object; + +/// A source file name, contained in a Span. + +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceName"] +#[ref_name = "SourceName"] +pub struct SourceNameNode { + pub base: Object, + pub name: TString, +} + +// /*! +// * \brief The source name of a file span. +// * \sa SourceNameNode, Span +// */ +// class SourceName : public ObjectRef { +// public: +// /*! +// * \brief Get an SourceName for a given operator name. +// * Will raise an error if the source name has not been registered. +// * \param name Name of the operator. +// * \return SourceName valid throughout program lifetime. +// */ +// TVM_DLL static SourceName Get(const String& name); + +// TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); +// }; + +/// Span information for diagnostic purposes. +#[repr(C)] +#[derive(Object)] +#[type_key = "Span"] +#[ref_name = "Span"] +pub struct SpanNode { + pub base: Object, + /// The source name. + pub source_name: SourceName, + /// The line number. + pub line: i32, + /// The column offset. + pub column: i32, + /// The end line number. + pub end_line: i32, + /// The end column number. + pub end_column: i32, +} + +impl Span { + pub fn empty() -> Span { + todo!() + } +} diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 05d41cf204d6..62e28483ffd5 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -192,4 +192,15 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { return ss.str(); }); + + } // namespace tvm + +#ifdef RUST_COMPILER_EXT + +extern "C" { + int compiler_ext_initialize(); + static int test = compiler_ext_initialize(); +} + +#endif diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 40998b0c9dc4..c6ea808733e1 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) { return line_text; } -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast(ref.get()); -// p->stream << "SourceName(" << node->name << ", " << node << ")"; -// }); - TVM_REGISTER_NODE_TYPE(SourceMapNode); SourceMap::SourceMap(Map source_map) { @@ -91,11 +85,6 @@ SourceMap::SourceMap(Map source_map) { data_ = std::move(n); } -// TODO(@jroesch): fix this -static SourceMap global_source_map = SourceMap(Map()); - -SourceMap SourceMap::Global() { return global_source_map; } - void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) { From c5b40617639ea1b6cc729d8147ea8a3008aceacc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 15 Oct 2020 01:42:14 -0700 Subject: [PATCH 07/32] Fix Linux build --- cmake/modules/LLVM.cmake | 7 ++++++- rust/tvm-sys/Cargo.toml | 2 +- rust/tvm-sys/build.rs | 3 +-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 5f8ace17111f..ca4ecd6db1ca 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -16,7 +16,12 @@ # under the License. # LLVM rules -add_definitions(-DDMLC_USE_FOPEN64=0) +# Due to LLVM debug symbols you can sometimes face linking issues on +# certain compiler, platform combinations if you don't set NDEBUG. +# +# See https://github.com/imageworks/OpenShadingLanguage/issues/1069 +# for more discussion. +add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) # Test if ${USE_LLVM} is not an explicit boolean false # It may be a boolean or a string diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index c25a5bf3d957..2952aa4938d7 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -23,7 +23,7 @@ license = "Apache-2.0" edition = "2018" [features] -default = ["bindings"] +default = [] bindings = [] [dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 2d86c4b9b5ed..159023463e8d 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -60,8 +60,7 @@ fn main() -> Result<()> { if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm"); - println!("cargo:rustc-link-lib=dylib=llvm-10"); - println!("cargo:rustc-link-search={}/build", tvm_home); + println!("cargo:rustc-link-search=native={}/build", tvm_home); } // @see rust-bindgen#550 for `blacklist_type` From 29754ae570b55701eee30b0e66641cdbfa955d02 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 15 Oct 2020 17:03:00 -0700 Subject: [PATCH 08/32] Clean up exporting to show off new diagnostics --- rust/compiler-ext/src/lib.rs | 12 ++++++++++-- rust/tvm-rt/src/array.rs | 32 ++++++++++++++++++++++++++++++++ rust/tvm/src/bin/tyck.rs | 7 ++++++- rust/tvm/src/ir/diagnostics.rs | 10 +++++----- rust/tvm/src/ir/mod.rs | 1 + rust/tvm/src/ir/module.rs | 3 +++ rust/tvm/src/ir/source_map.rs | 26 +++++++++++++++----------- rust/tvm/src/lib.rs | 24 ++++++++++++++++++++++++ 8 files changed, 96 insertions(+), 19 deletions(-) diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 3e37d21e756a..c136d06f492c 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -22,14 +22,22 @@ use tvm; use tvm::runtime::function::register_override; fn test_fn() -> Result<(), tvm::Error> { - println!("Hello from Rust!"); + println!("Hello Greg from Rust!"); Ok(()) } +fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> { + println!("The message: {}", message); + Ok(()) +} + +tvm::export!(test_fn, test_fn2); + #[no_mangle] fn compiler_ext_initialize() -> i32 { let _ = env_logger::try_init(); - register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier"); + tvm_export("rust_ext") + .expect("failed to initialize Rust compiler_ext"); log::debug!("done!"); return 0; } diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 5e19cefd8e97..032ca79bf744 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -19,6 +19,7 @@ use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; +use std::iter::{IntoIterator, Iterator}; use crate::errors::Error; use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; @@ -81,6 +82,37 @@ impl Array { } } +pub struct IntoIter { + array: Array, + pos: isize, + size: isize, +} + +impl Iterator for IntoIter { + type Item = T; + + fn next(&mut self) -> Option { + if self.pos < self.size { + let item = self.array.get(self.pos) + .expect("should not fail"); + self.pos += 1; + Some(item) + } else { + None + } + } +} + +impl IntoIterator for Array { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + let size = self.len() as isize; + IntoIter { array: self, pos: 0, size: size } + } +} + impl From> for ArgValue<'static> { fn from(array: Array) -> ArgValue<'static> { array.object.into() diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index b869012e4e6c..e0c71369fe0a 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -18,6 +18,11 @@ fn main() -> Result<()> { codespan::init().expect("Rust based diagnostics"); let opt = Opt::from_args(); println!("{:?}", &opt); - let file = IRModule::parse_file(opt.input)?; + let module = IRModule::parse_file(opt.input)?; + + // for (k, v) in module.functions { + // println!("Function name: {:?}", v); + // } + Ok(()) } diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs index b76e43f5b6de..4975a45f45ee 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics.rs @@ -135,6 +135,7 @@ pub struct DiagnosticRendererNode { pub base: Object, // TODO(@jroesch): we can't easily exposed packed functions due to // memory layout + // missing field here } // def render(self, ctx): @@ -283,11 +284,10 @@ pub mod codespan { pub fn init() -> Result<()> { let mut files: SimpleFiles = SimpleFiles::new(); let render_fn = move |diag_ctx: DiagnosticContext| { - // let source_map = diag_ctx.module.source_map; - // for diagnostic in diag_ctx.diagnostics { - - // } - panic!("render_fn"); + let source_map = diag_ctx.module.source_map.clone(); + for diagnostic in diag_ctx.diagnostics.clone() { + println!("Diagnostic: {}", diagnostic.message); + } }; override_renderer(Some(render_fn))?; diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 401b6c289966..df9bc688cb32 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -26,6 +26,7 @@ pub mod module; pub mod op; pub mod relay; pub mod span; +pub mod source_map; pub mod tir; pub mod ty; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index e0444b3101da..5156e7445012 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -25,6 +25,7 @@ use crate::runtime::{external, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; +use super::source_map::SourceMap; use std::io::Result as IOResult; use std::path::Path; @@ -43,6 +44,8 @@ pub struct IRModuleNode { pub base: Object, pub functions: Map, pub type_definitions: Map, + pub source_map: SourceMap, + // TODO(@jroesch): this is missing some fields } external! { diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index e6c037108cf0..56c0830f4a77 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -12,7 +12,7 @@ * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the + * KIND, either exprss or implied. See the License for the * specific language governing permissions and limitations * under the License. */ @@ -20,23 +20,27 @@ use crate::runtime::map::Map; use crate::runtime::object::Object; +use super::span::{SourceName, Span}; + +use tvm_macros::Object; + /// A program source in any language. /// /// Could represent the source from an ML framework or a source of an IRModule. #[repr(C)] #[derive(Object)] #[type_key = "Source"] -#[ref_key = "Source"] -struct SourceNode { +#[ref_name = "Source"] +pub struct SourceNode { pub base: Object, - /*! \brief The source name. */ - SourceName source_name; + /// The source name. */ + pub source_name: SourceName, - /*! \brief The raw source. */ - String source; + /// The raw source. */ + source: String, - /*! \brief A mapping of line breaks into the raw source. */ - std::vector> line_map; + // A mapping of line breaks into the raw source. + // std::vector> line_map; } @@ -53,8 +57,8 @@ struct SourceNode { #[repr(C)] #[derive(Object)] #[type_key = "SourceMap"] -#[ref_key = "SourceMap"] -struct SourceMapNode { +#[ref_name = "SourceMap"] +pub struct SourceMapNode { pub base: Object, /// The source mapping. pub source_map: Map, diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index 36c750328249..d193f09803ae 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -47,3 +47,27 @@ pub mod runtime; pub mod transform; pub use runtime::version; + +#[macro_export] +macro_rules! export { + ($($fn_names:expr),*) => { + pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> { + $( + register_override($fn_name, concat!($ns, stringfy!($fn_name)), true)?; + )* + Ok(()) + } + } +} + +#[macro_export] +macro_rules! export_mod { + ($ns:expr, $($mod_name:expr),*) => { + pub fn tvm_mod_export() -> Result<(), tvm::Error> { + $( + $mod_names::tvm_export($ns)?; + )* + Ok(()) + } + } +} From 39c90daf60b1cb5db9fdddf81212730d21642bab Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 15 Oct 2020 20:24:43 -0700 Subject: [PATCH 09/32] Improve Rust bindings --- rust/tvm/src/ir/diagnostics/codespan.rs | 131 ++++++++++++++++++ .../ir/{diagnostics.rs => diagnostics/mod.rs} | 69 +-------- rust/tvm/src/ir/source_map.rs | 3 +- rust/tvm/test.rly | 3 +- src/ir/diagnostic.cc | 1 + 5 files changed, 138 insertions(+), 69 deletions(-) create mode 100644 rust/tvm/src/ir/diagnostics/codespan.rs rename rust/tvm/src/ir/{diagnostics.rs => diagnostics/mod.rs} (76%) diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs new file mode 100644 index 000000000000..80a8784a219a --- /dev/null +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -0,0 +1,131 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; + +use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; +use codespan_reporting::files::SimpleFiles; +use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; + +use crate::ir::source_map::*; +use super::*; + +enum StartOrEnd { + Start, + End, +} + +enum FileSpanToByteRange { + AsciiSource, + Utf8 { + /// Map character regions which are larger then 1-byte to length. + lengths: HashMap, + source: String, + } +} + +impl FileSpanToByteRange { + fn new(source: String) -> FileSpanToByteRange { + let mut last_index = 0; + let mut is_ascii = true; + if source.is_ascii() { + FileSpanToByteRange::AsciiSource + } else { + panic!() + } + + // for (index, _) in source.char_indices() { + // if last_index - 1 != last_index { + // is_ascii = false; + // } else { + // panic!(); + // } + // last_index = index; + // } + } +} + +struct SpanToByteRange { + map: HashMap +} + +impl SpanToByteRange { + fn new() -> SpanToByteRange { + SpanToByteRange { map: HashMap::new() } + } + + pub fn add_source(&mut self, source: Source) { + let source_name: String = source.source_name.name.as_str().expect("foo").into(); + + if self.map.contains_key(&source_name) { + panic!() + } else { + let source = source.source.as_str().expect("fpp").into(); + self.map.insert(source_name, FileSpanToByteRange::new(source)); + } + } +} + +struct ByteRange { + file_id: FileId, + start_pos: usize, + end_pos: usize, +} + + +pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic { + let severity = match diag.level { + DiagnosticLevel::Error => Severity::Error, + DiagnosticLevel::Warning => Severity::Warning, + DiagnosticLevel::Note => Severity::Note, + DiagnosticLevel::Help => Severity::Help, + DiagnosticLevel::Bug => Severity::Bug, + }; + + let file_id = "foo".into(); // diag.span.source_name; + + let message: String = diag.message.as_str().unwrap().into(); + let inner_message: String = "expected `String`, found `Nat`".into(); + let diagnostic = CDiagnostic::new(severity) + .with_message(message) + .with_code("EXXX") + .with_labels(vec![ + Label::primary(file_id, 328..331).with_message(inner_message) + ]); + + diagnostic +} + +struct DiagnosticState { + files: SimpleFiles, + span_map: SpanToByteRange, +} + +impl DiagnosticState { + fn new() -> DiagnosticState { + DiagnosticState { + files: SimpleFiles::new(), + span_map: SpanToByteRange::new(), + } + } +} + +fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { + let source_map = diag_ctx.module.source_map.clone(); + for diagnostic in diag_ctx.diagnostics.clone() { + match source_map.source_map.get(&diagnostic.span.source_name) { + Err(err) => panic!(), + Ok(source) => state.span_map.add_source(source), + } + println!("Diagnostic: {}", diagnostic.message); + } +} + +pub fn init() -> Result<()> { + let diag_state = Arc::new(Mutex::new(DiagnosticState::new())); + let render_fn = move |diag_ctx: DiagnosticContext| { + // let mut guard = diag_state.lock().unwrap(); + // renderer(&mut *guard, diag_ctx); + }; + + override_renderer(Some(render_fn))?; + Ok(()) +} diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics/mod.rs similarity index 76% rename from rust/tvm/src/ir/diagnostics.rs rename to rust/tvm/src/ir/diagnostics/mod.rs index 4975a45f45ee..fce214a69c6a 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -18,7 +18,7 @@ */ use super::module::IRModule; -use super::span::Span; +use super::span::*; use crate::runtime::function::Result; use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; use crate::runtime::{ @@ -32,7 +32,7 @@ use crate::runtime::{ /// and the DiagnosticRenderer. use tvm_macros::{external, Object}; -type SourceName = ObjectRef; +pub mod codespan; // Get the the diagnostic renderer. external! { @@ -229,68 +229,3 @@ where } } } - -pub mod codespan { - use super::*; - - use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; - use codespan_reporting::files::SimpleFiles; - use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; - - enum StartOrEnd { - Start, - End, - } - - // struct SpanToBytes { - // inner: HashMap { - file_id: FileId, - start_pos: usize, - end_pos: usize, - } - - // impl SpanToBytes { - // fn to_byte_pos(&self, span: tvm::ir::Span) -> ByteRange { - - // } - // } - - pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic { - let severity = match diag.level { - DiagnosticLevel::Error => Severity::Error, - DiagnosticLevel::Warning => Severity::Warning, - DiagnosticLevel::Note => Severity::Note, - DiagnosticLevel::Help => Severity::Help, - DiagnosticLevel::Bug => Severity::Bug, - }; - - let file_id = "foo".into(); // diag.span.source_name; - - let message: String = diag.message.as_str().unwrap().into(); - let inner_message: String = "expected `String`, found `Nat`".into(); - let diagnostic = CDiagnostic::new(severity) - .with_message(message) - .with_code("EXXX") - .with_labels(vec![ - Label::primary(file_id, 328..331).with_message(inner_message) - ]); - - diagnostic - } - - pub fn init() -> Result<()> { - let mut files: SimpleFiles = SimpleFiles::new(); - let render_fn = move |diag_ctx: DiagnosticContext| { - let source_map = diag_ctx.module.source_map.clone(); - for diagnostic in diag_ctx.diagnostics.clone() { - println!("Diagnostic: {}", diagnostic.message); - } - }; - - override_renderer(Some(render_fn))?; - Ok(()) - } -} diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index 56c0830f4a77..ebe7e46804fa 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -19,6 +19,7 @@ use crate::runtime::map::Map; use crate::runtime::object::Object; +use crate::runtime::string::{String as TString}; use super::span::{SourceName, Span}; @@ -37,7 +38,7 @@ pub struct SourceNode { pub source_name: SourceName, /// The raw source. */ - source: String, + pub source: TString, // A mapping of line breaks into the raw source. // std::vector> line_map; diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly index d8b7c6960fef..e9407b029787 100644 --- a/rust/tvm/test.rly +++ b/rust/tvm/test.rly @@ -1,2 +1,3 @@ #[version = "0.0.5"] -fn @main(%x: int32) -> float32 { %x } + +def @main(%x: int32) -> float32 { %x } diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 148831dc3ab6..c79ed3dd6969 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -113,6 +113,7 @@ TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender") }); DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) { + CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function"; auto n = make_object(); n->module = module; n->renderer = renderer; From c1b994c81ca9f2854a5727a935effdd5ca5cf909 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 15 Oct 2020 21:37:34 -0700 Subject: [PATCH 10/32] Fix calling --- rust/tvm-rt/src/function.rs | 28 +++++++++++++++++----------- rust/tvm/src/bin/tyck.rs | 2 +- rust/tvm/src/ir/diagnostics/mod.rs | 4 +--- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index bae06e929361..c7aebdd46d57 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -33,6 +33,7 @@ use std::{ }; use crate::errors::Error; +use crate::object::ObjectPtr; pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -120,21 +121,26 @@ impl Function { let mut ret_val = ffi::TVMValue { v_int64: 0 }; let mut ret_type_code = 0i32; - check_call!(ffi::TVMFuncCall( - self.handle, - values.as_mut_ptr() as *mut ffi::TVMValue, - type_codes.as_mut_ptr() as *mut c_int, - num_args as c_int, - &mut ret_val as *mut _, - &mut ret_type_code as *mut _ - )); + let ret_code = unsafe { + ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr() as *mut ffi::TVMValue, + type_codes.as_mut_ptr() as *mut c_int, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + ) + }; + + if ret_code != 0 { + return Err(Error::CallFailed(crate::get_last_error().into())); + } let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); match rv { RetValue::ObjectHandle(object) => { - let optr = crate::object::ObjectPtr::from_raw(object as _).unwrap(); - // println!("after wrapped call: {}", optr.count()); - crate::object::ObjectPtr::leak(optr); + let optr = ObjectPtr::from_raw(object as _).unwrap(); + ObjectPtr::leak(optr); } _ => {} }; diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index e0c71369fe0a..fbab027ab10e 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -18,7 +18,7 @@ fn main() -> Result<()> { codespan::init().expect("Rust based diagnostics"); let opt = Opt::from_args(); println!("{:?}", &opt); - let module = IRModule::parse_file(opt.input)?; + let module = IRModule::parse_file(opt.input); // for (k, v) in module.functions { // println!("Function name: {:?}", v); diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index fce214a69c6a..039d1ed347cd 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -207,9 +207,7 @@ impl DiagnosticContext { } } -// Override the global diagnostics renderer. -// Params -// ------ +/// Override the global diagnostics renderer. // render_func: Option[Callable[[DiagnosticContext], None]] // If the render_func is None it will remove the current custom renderer // and return to default behavior. From 7ea0c341374c3377901123e4bb9af4dcbab33afb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 02:11:53 -0700 Subject: [PATCH 11/32] Fix --- rust/tvm/src/ir/module.rs | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 5156e7445012..11d6c491842c 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -16,6 +16,11 @@ * specific language governing permissions and limitations * under the License. */ +use std::io::Result as IOResult; +use std::path::Path; + +use thiserror::Error; +use tvm_macros::Object; use crate::runtime::array::Array; use crate::runtime::function::Result; @@ -27,15 +32,19 @@ use super::expr::GlobalVar; use super::function::BaseFunc; use super::source_map::SourceMap; -use std::io::Result as IOResult; -use std::path::Path; - -use tvm_macros::Object; // TODO(@jroesch): define type type TypeData = ObjectRef; type GlobalTypeVar = ObjectRef; +#[derive(Error, Debug)] +pub enum Error { + #[error("{0}")] + IO(#[from] std::io::Error), + #[error("{0}")] + TVM(#[from] crate::runtime::Error), +} + #[repr(C)] #[derive(Object)] #[ref_name = "IRModule"] @@ -116,19 +125,19 @@ external! { // }); impl IRModule { - pub fn parse(file_name: N, source: S) -> IRModule + pub fn parse(file_name: N, source: S) -> Result where N: Into, S: Into, { - parse_module(file_name.into(), source.into()).expect("failed to call parser") + parse_module(file_name.into(), source.into()) } - pub fn parse_file>(file_path: P) -> IOResult { + pub fn parse_file>(file_path: P) -> std::result::Result { let file_path = file_path.as_ref(); let file_path_as_str = file_path.to_str().unwrap().to_string(); let source = std::fs::read_to_string(file_path)?; - let module = IRModule::parse(file_path_as_str, source); + let module = IRModule::parse(file_path_as_str, source)?; Ok(module) } From 46c46ad85452f53a0d9ebb6945bfb21d28491e95 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 14:17:01 -0700 Subject: [PATCH 12/32] Rust Diagnostics work --- rust/tvm-rt/src/errors.rs | 15 +++ rust/tvm-rt/src/function.rs | 7 +- rust/tvm/src/bin/tyck.rs | 13 +-- rust/tvm/src/ir/diagnostics/codespan.rs | 126 +++++++++++++++++------- 4 files changed, 117 insertions(+), 44 deletions(-) diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index c884c56fed44..3de9f3cf3ee9 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -68,6 +68,21 @@ pub enum Error { Infallible(#[from] std::convert::Infallible), #[error("a panic occurred while executing a Rust packed function")] Panic, + #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")] + DiagnosticError(String), + #[error("{0}")] + Raw(String), +} + +impl Error { + pub fn from_raw_tvm(raw: &str) -> Error { + let err_header = raw.find(":").unwrap_or(0); + let (err_ty, err_content) = raw.split_at(err_header); + match err_ty { + "DiagnosticError" => Error::DiagnosticError((&err_content[1..]).into()), + _ => Error::Raw(raw.into()), + } + } } impl Error { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index c7aebdd46d57..173b60a222ba 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -133,7 +133,12 @@ impl Function { }; if ret_code != 0 { - return Err(Error::CallFailed(crate::get_last_error().into())); + let raw_error = crate::get_last_error(); + let error = match Error::from_raw_tvm(raw_error) { + Error::Raw(string) => Error::CallFailed(string), + e => e, + }; + return Err(error); } let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index fbab027ab10e..13470e776195 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -4,7 +4,8 @@ use anyhow::Result; use structopt::StructOpt; use tvm::ir::diagnostics::codespan; -use tvm::ir::IRModule; +use tvm::ir::{self, IRModule}; +use tvm::runtime::Error; #[derive(Debug, StructOpt)] #[structopt(name = "tyck", about = "Parse and type check a Relay program.")] @@ -18,11 +19,11 @@ fn main() -> Result<()> { codespan::init().expect("Rust based diagnostics"); let opt = Opt::from_args(); println!("{:?}", &opt); - let module = IRModule::parse_file(opt.input); - - // for (k, v) in module.functions { - // println!("Function name: {:?}", v); - // } + let _module = match IRModule::parse_file(opt.input) { + Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) }, + Err(e) => { return Err(e.into()); }, + Ok(module) => module + }; Ok(()) } diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index 80a8784a219a..9fc1ee00ae51 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -4,6 +4,7 @@ use std::sync::{Arc, Mutex}; use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; use codespan_reporting::files::SimpleFiles; use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; +use codespan_reporting::term::{self, ColorArg}; use crate::ir::source_map::*; use super::*; @@ -13,8 +14,14 @@ enum StartOrEnd { End, } +struct ByteRange { + file_id: FileId, + start_pos: usize, + end_pos: usize, +} + enum FileSpanToByteRange { - AsciiSource, + AsciiSource(Vec), Utf8 { /// Map character regions which are larger then 1-byte to length. lengths: HashMap, @@ -27,7 +34,12 @@ impl FileSpanToByteRange { let mut last_index = 0; let mut is_ascii = true; if source.is_ascii() { - FileSpanToByteRange::AsciiSource + let line_lengths = + source + .lines() + .map(|line| line.len()) + .collect(); + FileSpanToByteRange::AsciiSource(line_lengths) } else { panic!() } @@ -41,6 +53,21 @@ impl FileSpanToByteRange { // last_index = index; // } } + + fn lookup(&self, span: &Span) -> ByteRange { + use FileSpanToByteRange::*; + + let source_name: String = span.source_name.name.as_str().unwrap().into(); + + match self { + AsciiSource(ref line_lengths) => { + let start_pos = (&line_lengths[0..(span.line - 1) as usize]).into_iter().sum::() + (span.column) as usize; + let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]).into_iter().sum::() + (span.end_column) as usize; + ByteRange { file_id: source_name, start_pos, end_pos } + }, + _ => panic!() + } + } } struct SpanToByteRange { @@ -62,41 +89,22 @@ impl SpanToByteRange { self.map.insert(source_name, FileSpanToByteRange::new(source)); } } -} - -struct ByteRange { - file_id: FileId, - start_pos: usize, - end_pos: usize, -} - - -pub fn to_diagnostic(diag: super::Diagnostic) -> CDiagnostic { - let severity = match diag.level { - DiagnosticLevel::Error => Severity::Error, - DiagnosticLevel::Warning => Severity::Warning, - DiagnosticLevel::Note => Severity::Note, - DiagnosticLevel::Help => Severity::Help, - DiagnosticLevel::Bug => Severity::Bug, - }; - - let file_id = "foo".into(); // diag.span.source_name; - let message: String = diag.message.as_str().unwrap().into(); - let inner_message: String = "expected `String`, found `Nat`".into(); - let diagnostic = CDiagnostic::new(severity) - .with_message(message) - .with_code("EXXX") - .with_labels(vec![ - Label::primary(file_id, 328..331).with_message(inner_message) - ]); + pub fn lookup(&self, span: &Span) -> ByteRange { + let source_name: String = span.source_name.name.as_str().expect("foo").into(); - diagnostic + match self.map.get(&source_name) { + Some(file_span_to_bytes) => file_span_to_bytes.lookup(span), + None => panic!(), + } + } } struct DiagnosticState { files: SimpleFiles, span_map: SpanToByteRange, + // todo unify wih source name + source_to_id: HashMap, } impl DiagnosticState { @@ -104,26 +112,70 @@ impl DiagnosticState { DiagnosticState { files: SimpleFiles::new(), span_map: SpanToByteRange::new(), + source_to_id: HashMap::new(), } } + + fn add_source(&mut self, source: Source) { + let source_str: String = source.source.as_str().unwrap().into(); + let source_name: String = source.source_name.name.as_str().unwrap().into(); + self.span_map.add_source(source); + let file_id = self.files.add(source_name.clone(), source_str); + self.source_to_id.insert(source_name, file_id); + } + + fn to_diagnostic(&self, diag: super::Diagnostic) -> CDiagnostic { + let severity = match diag.level { + DiagnosticLevel::Error => Severity::Error, + DiagnosticLevel::Warning => Severity::Warning, + DiagnosticLevel::Note => Severity::Note, + DiagnosticLevel::Help => Severity::Help, + DiagnosticLevel::Bug => Severity::Bug, + }; + + let source_name: String = diag.span.source_name.name.as_str().unwrap().into(); + let file_id = *self.source_to_id.get(&source_name).unwrap(); + + let message: String = diag.message.as_str().unwrap().into(); + + let byte_range = self.span_map.lookup(&diag.span); + + let diagnostic = CDiagnostic::new(severity) + .with_message(message) + .with_code("EXXX") + .with_labels(vec![ + Label::primary(file_id, byte_range.start_pos..byte_range.end_pos) + ]); + + diagnostic + } } fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { let source_map = diag_ctx.module.source_map.clone(); - for diagnostic in diag_ctx.diagnostics.clone() { - match source_map.source_map.get(&diagnostic.span.source_name) { - Err(err) => panic!(), - Ok(source) => state.span_map.add_source(source), + let writer = StandardStream::stderr(ColorChoice::Always); + let config = codespan_reporting::term::Config::default(); + for diagnostic in diag_ctx.diagnostics.clone() { + match source_map.source_map.get(&diagnostic.span.source_name) { + Err(err) => panic!(err), + Ok(source) => { + state.add_source(source); + let diagnostic = state.to_diagnostic(diagnostic); + term::emit( + &mut writer.lock(), + &config, + &state.files, + &diagnostic).unwrap(); } - println!("Diagnostic: {}", diagnostic.message); } + } } pub fn init() -> Result<()> { let diag_state = Arc::new(Mutex::new(DiagnosticState::new())); let render_fn = move |diag_ctx: DiagnosticContext| { - // let mut guard = diag_state.lock().unwrap(); - // renderer(&mut *guard, diag_ctx); + let mut guard = diag_state.lock().unwrap(); + renderer(&mut *guard, diag_ctx); }; override_renderer(Some(render_fn))?; From 8f219a6cd0ec5cbb9b44ae9e4b2e3e9af2561f43 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 15:56:10 -0700 Subject: [PATCH 13/32] Remove type checker --- tests/python/relay/test_type_infer2.py | 419 ------------------------- 1 file changed, 419 deletions(-) delete mode 100644 tests/python/relay/test_type_infer2.py diff --git a/tests/python/relay/test_type_infer2.py b/tests/python/relay/test_type_infer2.py deleted file mode 100644 index 6758d96773a2..000000000000 --- a/tests/python/relay/test_type_infer2.py +++ /dev/null @@ -1,419 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Test that type checker correcly computes types - for expressions. -""" -import pytest -import tvm - -from tvm import IRModule, te, relay, parser -from tvm.relay import op, transform, analysis -from tvm.relay import Any - - -def infer_mod(mod, annotate_spans=True): - if annotate_spans: - mod = relay.transform.AnnotateSpans()(mod) - - mod = transform.InferType()(mod) - return mod - - -def infer_expr(expr, annotate_spans=True): - mod = IRModule.from_expr(expr) - mod = infer_mod(mod, annotate_spans) - mod = transform.InferType()(mod) - entry = mod["main"] - return entry if isinstance(expr, relay.Function) else entry.body - - -def assert_has_type(expr, typ, mod=None): - if not mod: - mod = tvm.IRModule({}) - - mod["main"] = expr - mod = infer_mod(mod) - checked_expr = mod["main"] - checked_type = checked_expr.checked_type - if checked_type != typ: - raise RuntimeError("Type mismatch %s vs %s" % (checked_type, typ)) - - -def initialize_box_adt(mod): - # initializes simple ADT for tests - box = relay.GlobalTypeVar("box") - tv = relay.TypeVar("tv") - constructor = relay.Constructor("constructor", [tv], box) - data = relay.TypeData(box, [tv], [constructor]) - mod[box] = data - return box, constructor - - -def test_monomorphic_let(): - "Program: let %x = 1; %x" - # TODO(@jroesch): this seems whack. - sb = relay.ScopeBuilder() - x = relay.var("x", dtype="float64", shape=()) - x = sb.let("x", relay.const(1.0, "float64")) - sb.ret(x) - xchecked = infer_expr(sb.get()) - assert xchecked.checked_type == relay.scalar_type("float64") - - -def test_single_op(): - "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }" - x = relay.var("x", shape=[]) - func = relay.Function([x], op.log(x)) - ttype = relay.TensorType([], dtype="float32") - assert_has_type(func, relay.FuncType([ttype], ttype)) - - -def test_add_broadcast_op(): - """ - Program: - fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) - -> Tensor[(5, 10, 4), float32] { - %x + %y - } - """ - x = relay.var("x", shape=(10, 4)) - y = relay.var("y", shape=(5, 10, 1)) - z = x + y - func = relay.Function([x, y], z) - t1 = relay.TensorType((10, 4), "float32") - t2 = relay.TensorType((5, 10, 1), "float32") - t3 = relay.TensorType((5, 10, 4), "float32") - expected_ty = relay.FuncType([t1, t2], t3) - assert_has_type(func, expected_ty) - - -def test_dual_op(): - """Program: - fn (%x : Tensor[(10, 10), float32]) { - let %t1 = log(x); - let %t2 = add(%t1, %x); - %t1 - } - """ - tp = relay.TensorType((10, 10), "float32") - x = relay.var("x", tp) - sb = relay.ScopeBuilder() - t1 = sb.let("t1", relay.log(x)) - t2 = sb.let("t2", relay.add(t1, x)) - sb.ret(t2) - f = relay.Function([x], sb.get()) - fchecked = infer_expr(f) - assert fchecked.checked_type == relay.FuncType([tp], tp) - - -def test_decl(): - """Program: - def @f(%x : Tensor[(10, 10), float32]) { - log(%x) - } - """ - tp = relay.TensorType((10, 10)) - x = relay.var("x", tp) - f = relay.Function([x], relay.log(x)) - fchecked = infer_expr(f) - assert fchecked.checked_type == relay.FuncType([tp], tp) - - -def test_recursion(): - """ - Program: - def @f(%n: int32, %data: float32) -> float32 { - if (%n == 0) { - %data - } else { - @f(%n - 1, log(%data)) - } - } - """ - sb = relay.ScopeBuilder() - f = relay.GlobalVar("f") - ti32 = relay.scalar_type("int32") - tf32 = relay.scalar_type("float32") - n = relay.var("n", ti32) - data = relay.var("data", tf32) - - with sb.if_scope(relay.equal(n, relay.const(0, ti32))): - sb.ret(data) - with sb.else_scope(): - sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) - mod = tvm.IRModule() - mod[f] = relay.Function([n, data], sb.get()) - mod = infer_mod(mod) - assert "@f(%1, %2)" in mod.astext() - assert mod["f"].checked_type == relay.FuncType([ti32, tf32], tf32) - - -def test_incomplete_call(): - tt = relay.scalar_type("int32") - x = relay.var("x", tt) - f = relay.var("f") - func = relay.Function([x, f], relay.Call(f, [x]), tt) - - ft = infer_expr(func) - f_type = relay.FuncType([tt], tt) - assert ft.checked_type == relay.FuncType([tt, f_type], tt) - - -def test_higher_order_argument(): - a = relay.TypeVar("a") - x = relay.Var("x", a) - id_func = relay.Function([x], x, a, [a]) - - b = relay.TypeVar("b") - f = relay.Var("f", relay.FuncType([b], b)) - y = relay.Var("y", b) - ho_func = relay.Function([f, y], f(y), b, [b]) - - # id func should be an acceptable argument to the higher-order - # function even though id_func takes a type parameter - ho_call = ho_func(id_func, relay.const(0, "int32")) - - hc = infer_expr(ho_call) - expected = relay.scalar_type("int32") - assert hc.checked_type == expected - - -def test_higher_order_return(): - a = relay.TypeVar("a") - x = relay.Var("x", a) - id_func = relay.Function([x], x, a, [a]) - - b = relay.TypeVar("b") - nested_id = relay.Function([], id_func, relay.FuncType([b], b), [b]) - - ft = infer_expr(nested_id) - assert ft.checked_type == relay.FuncType([], relay.FuncType([b], b), [b]) - - -def test_higher_order_nested(): - a = relay.TypeVar("a") - x = relay.Var("x", a) - id_func = relay.Function([x], x, a, [a]) - - choice_t = relay.FuncType([], relay.scalar_type("bool")) - f = relay.Var("f", choice_t) - - b = relay.TypeVar("b") - z = relay.Var("z") - top = relay.Function( - [f], relay.If(f(), id_func, relay.Function([z], z)), relay.FuncType([b], b), [b] - ) - - expected = relay.FuncType([choice_t], relay.FuncType([b], b), [b]) - ft = infer_expr(top) - assert ft.checked_type == expected - - -def test_tuple(): - tp = relay.TensorType((10,)) - x = relay.var("x", tp) - res = relay.Tuple([x, x]) - assert infer_expr(res).checked_type == relay.TupleType([tp, tp]) - - -def test_ref(): - x = relay.var("x", "float32") - y = relay.var("y", "float32") - r = relay.RefCreate(x) - st = relay.scalar_type("float32") - assert infer_expr(r).checked_type == relay.RefType(st) - g = relay.RefRead(r) - assert infer_expr(g).checked_type == st - w = relay.RefWrite(r, y) - assert infer_expr(w).checked_type == relay.TupleType([]) - - -def test_free_expr(): - x = relay.var("x", "float32") - y = relay.add(x, x) - yy = infer_expr(y, annotate_spans=False) - assert tvm.ir.structural_equal(yy.args[0], x, map_free_vars=True) - assert yy.checked_type == relay.scalar_type("float32") - assert x.vid.same_as(yy.args[0].vid) - - -def test_type_args(): - x = relay.var("x", shape=(10, 10)) - y = relay.var("y", shape=(1, 10)) - z = relay.add(x, y) - ty_z = infer_expr(z) - ty_args = ty_z.type_args - assert len(ty_args) == 2 - assert ty_args[0].dtype == "float32" - assert ty_args[1].dtype == "float32" - sh1 = ty_args[0].shape - sh2 = ty_args[1].shape - assert sh1[0].value == 10 - assert sh1[1].value == 10 - assert sh2[0].value == 1 - assert sh2[1].value == 10 - - -def test_global_var_recursion(): - mod = tvm.IRModule({}) - gv = relay.GlobalVar("main") - x = relay.var("x", shape=[]) - tt = relay.scalar_type("float32") - - func = relay.Function([x], relay.Call(gv, [x]), tt) - mod[gv] = func - mod = infer_mod(mod) - func_ty = mod["main"].checked_type - - assert func_ty == relay.FuncType([tt], tt) - - -def test_equal(): - i = relay.var("i", shape=[], dtype="int32") - eq = op.equal(i, relay.const(0, dtype="int32")) - func = relay.Function([i], eq) - ft = infer_expr(func) - expected = relay.FuncType([relay.scalar_type("int32")], relay.scalar_type("bool")) - assert ft.checked_type == expected - - assert ft.checked_type == relay.FuncType( - [relay.scalar_type("int32")], relay.scalar_type("bool") - ) - - -def test_constructor_type(): - mod = tvm.IRModule() - box, constructor = initialize_box_adt(mod) - - a = relay.TypeVar("a") - x = relay.Var("x", a) - func = relay.Function([x], constructor(x), box(a), [a]) - mod["main"] = func - mod = infer_mod(mod) - func_ty = mod["main"].checked_type - box = mod.get_global_type_var("box") - expected = relay.FuncType([a], box(a), [a]) - assert func_ty == expected - - -def test_constructor_call(): - mod = tvm.IRModule() - box, constructor = initialize_box_adt(mod) - - box_unit = constructor(relay.Tuple([])) - box_constant = constructor(relay.const(0, "float32")) - - func = relay.Function([], relay.Tuple([box_unit, box_constant])) - mod["main"] = func - mod = infer_mod(mod) - ret_type = mod["main"].checked_type.ret_type.fields - # NB(@jroesch): when we annotate spans the ast fragments before - # annotation the previous fragments will no longer be directly equal. - box = mod.get_global_type_var("box") - expected1 = box(relay.TupleType([])) - expected2 = box(relay.TensorType((), "float32")) - assert ret_type[0] == expected1 - assert ret_type[1] == expected2 - - -def test_adt_match(): - mod = tvm.IRModule() - box, constructor = initialize_box_adt(mod) - - v = relay.Var("v", relay.TensorType((), "float32")) - match = relay.Match( - constructor(relay.const(0, "float32")), - [ - relay.Clause( - relay.PatternConstructor(constructor, [relay.PatternVar(v)]), relay.Tuple([]) - ), - # redundant but shouldn't matter to typechecking - relay.Clause(relay.PatternWildcard(), relay.Tuple([])), - ], - ) - - func = relay.Function([], match) - mod["main"] = func - mod = infer_mod(mod) - actual = mod["main"].checked_type.ret_type - assert actual == relay.TupleType([]) - - -def test_adt_match_type_annotations(): - mod = tvm.IRModule() - box, constructor = initialize_box_adt(mod) - - # the only type annotation is inside the match pattern var - # but that should be enough info - tt = relay.TensorType((2, 2), "float32") - x = relay.Var("x") - mv = relay.Var("mv", tt) - match = relay.Match( - constructor(x), - [ - relay.Clause( - relay.PatternConstructor(constructor, [relay.PatternVar(mv)]), relay.Tuple([]) - ) - ], - ) - - mod["main"] = relay.Function([x], match) - mod = infer_mod(mod) - ft = mod["main"].checked_type - assert ft == relay.FuncType([tt], relay.TupleType([])) - - -def test_let_polymorphism(): - id = relay.Var("id") - xt = relay.TypeVar("xt") - x = relay.Var("x", xt) - body = relay.Tuple([id(relay.const(1)), id(relay.Tuple([]))]) - body = relay.Let(id, relay.Function([x], x, xt, [xt]), body) - body = infer_expr(body) - int32 = relay.TensorType((), "int32") - tvm.ir.assert_structural_equal(body.checked_type, relay.TupleType([int32, relay.TupleType([])])) - - -def test_if(): - choice_t = relay.FuncType([], relay.scalar_type("bool")) - f = relay.Var("f", choice_t) - true_branch = relay.Var("True", relay.TensorType([Any(), 1], dtype="float32")) - false_branch = relay.Var("False", relay.TensorType([Any(), Any()], dtype="float32")) - top = relay.Function([f, true_branch, false_branch], relay.If(f(), true_branch, false_branch)) - ft = infer_expr(top) - tvm.ir.assert_structural_equal(ft.ret_type, relay.TensorType([Any(), 1], dtype="float32")) - - -def test_type_arg_infer(): - code = """ -#[version = "0.0.5"] -def @id[A](%x: A) -> A { - %x -} -def @main(%f: float32) -> float32 { - @id(%f) -} -""" - mod = tvm.parser.fromtext(code) - mod = transform.InferType()(mod) - tvm.ir.assert_structural_equal(mod["main"].body.type_args, [relay.TensorType((), "float32")]) - - -if __name__ == "__main__": - import sys - - pytest.main(sys.argv) From 6f2841448a5e7b4c3efe088753f2719678bf03cd Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 16:31:37 -0700 Subject: [PATCH 14/32] Format and cleanup --- python/tvm/ir/diagnostics/__init__.py | 1 + rust/compiler-ext/src/lib.rs | 3 +- rust/tvm-rt/src/array.rs | 11 ++-- rust/tvm-rt/src/errors.rs | 4 +- rust/tvm-rt/src/function.rs | 2 +- rust/tvm/src/bin/tyck.rs | 8 ++- rust/tvm/src/ir/diagnostics/codespan.rs | 87 ++++++++++++++----------- rust/tvm/src/ir/mod.rs | 2 +- rust/tvm/src/ir/module.rs | 5 +- rust/tvm/src/ir/relay/mod.rs | 2 +- rust/tvm/src/ir/relay/visitor.rs | 24 ------- rust/tvm/src/ir/source_map.rs | 13 ++-- rust/tvm/src/ir/span.rs | 2 +- src/ir/expr.cc | 11 ---- 14 files changed, 79 insertions(+), 96 deletions(-) delete mode 100644 rust/tvm/src/ir/relay/visitor.rs diff --git a/python/tvm/ir/diagnostics/__init__.py b/python/tvm/ir/diagnostics/__init__.py index 0ad2a7aa6bfd..3a6402c0359d 100644 --- a/python/tvm/ir/diagnostics/__init__.py +++ b/python/tvm/ir/diagnostics/__init__.py @@ -37,6 +37,7 @@ def get_renderer(): """ return _ffi_api.GetRenderer() + @tvm.register_func("diagnostics.override_renderer") def override_renderer(render_func): """ diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index c136d06f492c..346f40fa2ce0 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -36,8 +36,7 @@ tvm::export!(test_fn, test_fn2); #[no_mangle] fn compiler_ext_initialize() -> i32 { let _ = env_logger::try_init(); - tvm_export("rust_ext") - .expect("failed to initialize Rust compiler_ext"); + tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext"); log::debug!("done!"); return 0; } diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 032ca79bf744..66e32a7e7177 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -18,8 +18,8 @@ */ use std::convert::{TryFrom, TryInto}; -use std::marker::PhantomData; use std::iter::{IntoIterator, Iterator}; +use std::marker::PhantomData; use crate::errors::Error; use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; @@ -93,8 +93,7 @@ impl Iterator for IntoIter { fn next(&mut self) -> Option { if self.pos < self.size { - let item = self.array.get(self.pos) - .expect("should not fail"); + let item = self.array.get(self.pos).expect("should not fail"); self.pos += 1; Some(item) } else { @@ -109,7 +108,11 @@ impl IntoIterator for Array { fn into_iter(self) -> Self::IntoIter { let size = self.len() as isize; - IntoIter { array: self, pos: 0, size: size } + IntoIter { + array: self, + pos: 0, + size: size, + } } } diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 3de9f3cf3ee9..31ce385ef662 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -68,7 +68,9 @@ pub enum Error { Infallible(#[from] std::convert::Infallible), #[error("a panic occurred while executing a Rust packed function")] Panic, - #[error("one or more error diagnostics were emitted, please check diagnostic render for output.")] + #[error( + "one or more error diagnostics were emitted, please check diagnostic render for output." + )] DiagnosticError(String), #[error("{0}")] Raw(String), diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 173b60a222ba..4c6f56ea1e76 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -128,7 +128,7 @@ impl Function { type_codes.as_mut_ptr() as *mut c_int, num_args as c_int, &mut ret_val as *mut _, - &mut ret_type_code as *mut _ + &mut ret_type_code as *mut _, ) }; diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index 13470e776195..e9c26630281c 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -20,9 +20,11 @@ fn main() -> Result<()> { let opt = Opt::from_args(); println!("{:?}", &opt); let _module = match IRModule::parse_file(opt.input) { - Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => { return Ok(()) }, - Err(e) => { return Err(e.into()); }, - Ok(module) => module + Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()), + Err(e) => { + return Err(e.into()); + } + Ok(module) => module, }; Ok(()) diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index 9fc1ee00ae51..9a31691728b9 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -1,3 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/// A TVM diagnostics renderer which uses the Rust `codespan` +/// library to produce error messages. use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -6,13 +27,8 @@ use codespan_reporting::files::SimpleFiles; use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; use codespan_reporting::term::{self, ColorArg}; -use crate::ir::source_map::*; use super::*; - -enum StartOrEnd { - Start, - End, -} +use crate::ir::source_map::*; struct ByteRange { file_id: FileId, @@ -26,7 +42,7 @@ enum FileSpanToByteRange { /// Map character regions which are larger then 1-byte to length. lengths: HashMap, source: String, - } + }, } impl FileSpanToByteRange { @@ -34,24 +50,11 @@ impl FileSpanToByteRange { let mut last_index = 0; let mut is_ascii = true; if source.is_ascii() { - let line_lengths = - source - .lines() - .map(|line| line.len()) - .collect(); + let line_lengths = source.lines().map(|line| line.len()).collect(); FileSpanToByteRange::AsciiSource(line_lengths) } else { panic!() } - - // for (index, _) in source.char_indices() { - // if last_index - 1 != last_index { - // is_ascii = false; - // } else { - // panic!(); - // } - // last_index = index; - // } } fn lookup(&self, span: &Span) -> ByteRange { @@ -61,22 +64,34 @@ impl FileSpanToByteRange { match self { AsciiSource(ref line_lengths) => { - let start_pos = (&line_lengths[0..(span.line - 1) as usize]).into_iter().sum::() + (span.column) as usize; - let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]).into_iter().sum::() + (span.end_column) as usize; - ByteRange { file_id: source_name, start_pos, end_pos } - }, - _ => panic!() + let start_pos = (&line_lengths[0..(span.line - 1) as usize]) + .into_iter() + .sum::() + + (span.column) as usize; + let end_pos = (&line_lengths[0..(span.end_line - 1) as usize]) + .into_iter() + .sum::() + + (span.end_column) as usize; + ByteRange { + file_id: source_name, + start_pos, + end_pos, + } + } + _ => panic!(), } } } struct SpanToByteRange { - map: HashMap + map: HashMap, } impl SpanToByteRange { fn new() -> SpanToByteRange { - SpanToByteRange { map: HashMap::new() } + SpanToByteRange { + map: HashMap::new(), + } } pub fn add_source(&mut self, source: Source) { @@ -86,7 +101,8 @@ impl SpanToByteRange { panic!() } else { let source = source.source.as_str().expect("fpp").into(); - self.map.insert(source_name, FileSpanToByteRange::new(source)); + self.map + .insert(source_name, FileSpanToByteRange::new(source)); } } @@ -143,9 +159,10 @@ impl DiagnosticState { let diagnostic = CDiagnostic::new(severity) .with_message(message) .with_code("EXXX") - .with_labels(vec![ - Label::primary(file_id, byte_range.start_pos..byte_range.end_pos) - ]); + .with_labels(vec![Label::primary( + file_id, + byte_range.start_pos..byte_range.end_pos, + )]); diagnostic } @@ -161,11 +178,7 @@ fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { Ok(source) => { state.add_source(source); let diagnostic = state.to_diagnostic(diagnostic); - term::emit( - &mut writer.lock(), - &config, - &state.files, - &diagnostic).unwrap(); + term::emit(&mut writer.lock(), &config, &state.files, &diagnostic).unwrap(); } } } diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index df9bc688cb32..6d5158005497 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -25,8 +25,8 @@ pub mod function; pub mod module; pub mod op; pub mod relay; -pub mod span; pub mod source_map; +pub mod span; pub mod tir; pub mod ty; diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 11d6c491842c..443915ff27b7 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -32,7 +32,6 @@ use super::expr::GlobalVar; use super::function::BaseFunc; use super::source_map::SourceMap; - // TODO(@jroesch): define type type TypeData = ObjectRef; type GlobalTypeVar = ObjectRef; @@ -133,7 +132,9 @@ impl IRModule { parse_module(file_name.into(), source.into()) } - pub fn parse_file>(file_path: P) -> std::result::Result { + pub fn parse_file>( + file_path: P, + ) -> std::result::Result { let file_path = file_path.as_ref(); let file_path_as_str = file_path.to_str().unwrap().to_string(); let source = std::fs::read_to_string(file_path)?; diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 4b091285d245..530b1203bd98 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -27,8 +27,8 @@ use crate::runtime::{object::*, String as TString}; use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; -use super::ty::{Type, TypeNode}; use super::span::Span; +use super::ty::{Type, TypeNode}; use tvm_macros::Object; use tvm_rt::NDArray; diff --git a/rust/tvm/src/ir/relay/visitor.rs b/rust/tvm/src/ir/relay/visitor.rs deleted file mode 100644 index 31661742c4fb..000000000000 --- a/rust/tvm/src/ir/relay/visitor.rs +++ /dev/null @@ -1,24 +0,0 @@ -use super::Expr; - -macro_rules! downcast_match { - ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { - $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ - { $default } - } -} - -trait ExprVisitorMut { - fn visit(&mut self, expr: Expr) { - downcast_match!(expr; { - else => { - panic!() - } - }); - } - - fn visit(&mut self, expr: Expr); -} - -// trait ExprTransformer { -// fn -// } diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index ebe7e46804fa..b0cc0d898e94 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -19,7 +19,7 @@ use crate::runtime::map::Map; use crate::runtime::object::Object; -use crate::runtime::string::{String as TString}; +use crate::runtime::string::String as TString; use super::span::{SourceName, Span}; @@ -39,12 +39,10 @@ pub struct SourceNode { /// The raw source. */ pub source: TString, - - // A mapping of line breaks into the raw source. - // std::vector> line_map; + // A mapping of line breaks into the raw source. + // std::vector> line_map; } - // class Source : public ObjectRef { // public: // TVM_DLL Source(SourceName src_name, std::string source); @@ -53,7 +51,6 @@ pub struct SourceNode { // TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); // }; - /// A mapping from a unique source name to source fragments. #[repr(C)] #[derive(Object)] @@ -61,6 +58,6 @@ pub struct SourceNode { #[ref_name = "SourceMap"] pub struct SourceMapNode { pub base: Object, - /// The source mapping. - pub source_map: Map, + /// The source mapping. + pub source_map: Map, } diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index c54fd5143c12..afcbe9c01698 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -18,7 +18,7 @@ * under the License. */ -use crate::runtime::{ObjectRef, Object, String as TString}; +use crate::runtime::{Object, ObjectRef, String as TString}; use tvm_macros::Object; /// A source file name, contained in a Span. diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 62e28483ffd5..05d41cf204d6 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -192,15 +192,4 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { return ss.str(); }); - - } // namespace tvm - -#ifdef RUST_COMPILER_EXT - -extern "C" { - int compiler_ext_initialize(); - static int test = compiler_ext_initialize(); -} - -#endif From af518e1c6f42a3f3dcfd838fe21409384d54a94a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 16:32:13 -0700 Subject: [PATCH 15/32] Fix the extension code --- src/contrib/rust_extension.cc | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 src/contrib/rust_extension.cc diff --git a/src/contrib/rust_extension.cc b/src/contrib/rust_extension.cc new file mode 100644 index 000000000000..075cbc670f66 --- /dev/null +++ b/src/contrib/rust_extension.cc @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/rust_extension.cc + * \brief Expose Rust extensions initialization. + */ +#ifdef RUST_COMPILER_EXT + +extern "C" { + int compiler_ext_initialize(); + static int test = compiler_ext_initialize(); +} + +#endif From beb8f1c8c8d6b696cfd137f429a6c74cc95b8b1b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 16 Oct 2020 16:45:12 -0700 Subject: [PATCH 16/32] More cleanup --- rust/compiler-ext/src/lib.rs | 22 ++++++++-------------- rust/tvm/src/lib.rs | 7 ++++--- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 346f40fa2ce0..5f83f7b3b8ef 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -18,25 +18,19 @@ */ use env_logger; -use tvm; -use tvm::runtime::function::register_override; +use tvm::export; -fn test_fn() -> Result<(), tvm::Error> { - println!("Hello Greg from Rust!"); - Ok(()) +fn diagnostics() -> Result<(), tvm::Error> { + tvm::ir::diagnostics::codespan::init() } -fn test_fn2(message: tvm::runtime::string::String) -> Result<(), tvm::Error> { - println!("The message: {}", message); - Ok(()) -} - -tvm::export!(test_fn, test_fn2); +export!(diagnostics); #[no_mangle] -fn compiler_ext_initialize() -> i32 { +extern fn compiler_ext_initialize() -> i32 { let _ = env_logger::try_init(); - tvm_export("rust_ext").expect("failed to initialize Rust compiler_ext"); - log::debug!("done!"); + tvm_export("rust_ext") + .expect("failed to initialize the Rust compiler extensions."); + log::debug!("Loaded the Rust compiler extension."); return 0; } diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index d193f09803ae..ec80ece1e37a 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -50,10 +50,11 @@ pub use runtime::version; #[macro_export] macro_rules! export { - ($($fn_names:expr),*) => { + ($($fn_name:expr),*) => { pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> { $( - register_override($fn_name, concat!($ns, stringfy!($fn_name)), true)?; + let name = String::from(ns) + ::std::stringify!($fn_name); + tvm::runtime::function::register_override($fn_name, name, true)?; )* Ok(()) } @@ -65,7 +66,7 @@ macro_rules! export_mod { ($ns:expr, $($mod_name:expr),*) => { pub fn tvm_mod_export() -> Result<(), tvm::Error> { $( - $mod_names::tvm_export($ns)?; + $mod_name::tvm_export($ns)?; )* Ok(()) } From 657c708f298a70b9627dc1ffbed38a5aba6ee59c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 19 Oct 2020 19:52:20 -0700 Subject: [PATCH 17/32] Fix some CR --- rust/tvm/src/ir/diagnostics/codespan.rs | 6 ++++-- rust/tvm/src/lib.rs | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index 9a31691728b9..54fd33617b66 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -17,8 +17,10 @@ * under the License. */ -/// A TVM diagnostics renderer which uses the Rust `codespan` -/// library to produce error messages. +/// A TVM diagnostics renderer which uses the Rust `codespan` library +/// to produce error messages. +/// +/// use std::collections::HashMap; use std::sync::{Arc, Mutex}; diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs index ec80ece1e37a..7e0682b86b33 100644 --- a/rust/tvm/src/lib.rs +++ b/rust/tvm/src/lib.rs @@ -24,7 +24,7 @@ //! One particular use case is that given optimized deep learning model artifacts, //! (compiled with TVM) which include a shared library //! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them -//! in Rust idomatically to create a TVM Graph Runtime and +//! in Rust idiomatically to create a TVM Graph Runtime and //! run the model for some inputs and get the //! desired predictions *all in Rust*. //! @@ -53,7 +53,7 @@ macro_rules! export { ($($fn_name:expr),*) => { pub fn tvm_export(ns: &str) -> Result<(), tvm::Error> { $( - let name = String::from(ns) + ::std::stringify!($fn_name); + let name = String::fromwe(ns) + ::std::stringify!($fn_name); tvm::runtime::function::register_override($fn_name, name, true)?; )* Ok(()) From 06cdc4753434d19a06f5c2702abd5d759d20ec0a Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 20 Oct 2020 15:06:06 -0700 Subject: [PATCH 18/32] Add docs and address feedback --- cmake/modules/RustExt.cmake | 2 +- rust/tvm/src/ir/diagnostics/codespan.rs | 26 +++++++++++++++++++++---- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index 2ad726e94213..4c1534bcc4c0 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -1,4 +1,4 @@ -if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF) +if(USE_RUST_EXT) set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust") set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target") diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index 54fd33617b66..ebedc18fb02e 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -17,10 +17,11 @@ * under the License. */ -/// A TVM diagnostics renderer which uses the Rust `codespan` library -/// to produce error messages. -/// -/// +//! A TVM diagnostics renderer which uses the Rust `codespan` library +//! to produce error messages. +//! +//! This is an example of using the exposed API surface of TVM to +//! customize the compiler behavior. use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -32,22 +33,29 @@ use codespan_reporting::term::{self, ColorArg}; use super::*; use crate::ir::source_map::*; +/// A representation of a TVM Span as a range of bytes in a file. struct ByteRange { + /// The file in which the range occurs. file_id: FileId, + /// The range start. start_pos: usize, + /// The range end. end_pos: usize, } +/// A mapping from Span to ByteRange for a single file. enum FileSpanToByteRange { AsciiSource(Vec), Utf8 { /// Map character regions which are larger then 1-byte to length. lengths: HashMap, + /// The source of the program. source: String, }, } impl FileSpanToByteRange { + /// Construct a span to byte range mapping from the program source. fn new(source: String) -> FileSpanToByteRange { let mut last_index = 0; let mut is_ascii = true; @@ -59,6 +67,7 @@ impl FileSpanToByteRange { } } + /// Lookup the corresponding ByteRange for a given Span. fn lookup(&self, span: &Span) -> ByteRange { use FileSpanToByteRange::*; @@ -85,6 +94,7 @@ impl FileSpanToByteRange { } } +/// A mapping for all files in a source map to byte ranges. struct SpanToByteRange { map: HashMap, } @@ -96,6 +106,7 @@ impl SpanToByteRange { } } + /// Add a source file to the span mapping. pub fn add_source(&mut self, source: Source) { let source_name: String = source.source_name.name.as_str().expect("foo").into(); @@ -108,6 +119,9 @@ impl SpanToByteRange { } } + /// Lookup a span to byte range mapping. + /// + /// First resolves the Span to a file, and then maps the span to a byte range in the file. pub fn lookup(&self, span: &Span) -> ByteRange { let source_name: String = span.source_name.name.as_str().expect("foo").into(); @@ -118,6 +132,7 @@ impl SpanToByteRange { } } +/// The state of the `codespan` based diagnostics. struct DiagnosticState { files: SimpleFiles, span_map: SpanToByteRange, @@ -186,6 +201,9 @@ fn renderer(state: &mut DiagnosticState, diag_ctx: DiagnosticContext) { } } +/// Initialize the `codespan` based diagnostics. +/// +/// Calling this function will globally override the TVM diagnostics renderer. pub fn init() -> Result<()> { let diag_state = Arc::new(Mutex::new(DiagnosticState::new())); let render_fn = move |diag_ctx: DiagnosticContext| { From db6b35503245e41924ddb4e5aaee806214866dfb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 21 Oct 2020 02:40:44 -0700 Subject: [PATCH 19/32] WIP more improvments --- rust/tvm/src/ir/diagnostics/mod.rs | 32 +++++++++++++++++------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index 039d1ed347cd..e49c4e83cece 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -34,7 +34,6 @@ use tvm_macros::{external, Object}; pub mod codespan; -// Get the the diagnostic renderer. external! { #[name("node.ArrayGetItem")] fn get_renderer() -> DiagnosticRenderer; @@ -48,6 +47,9 @@ external! { #[name("diagnostics.DiagnosticContextRender")] fn diagnostic_context_render(ctx: DiagnosticContext) -> (); + #[name("diagnostics.DiagnosticRendererRender")] + fn diagnositc_renderer_render(renderer: DiagnosticRenderer,ctx: DiagnosticContext) -> (); + #[name("diagnostics.ClearRenderer")] fn clear_renderer() -> (); } @@ -108,11 +110,17 @@ pub struct DiagnosticBuilder { /// The level. pub level: DiagnosticLevel, - /// The source name. - pub source_name: SourceName, - /// The span of the diagnostic. pub span: Span, + + /// The in progress message. + pub message: String, +} + +impl DiagnosticBuilder { + pub fn new(level: DiagnosticLevel, span: Span) -> DiagnosticBuilder { + DiagnosticBuilder { level, span, message: "".into() } + } } // /*! \brief Display diagnostics in a given display format. @@ -138,16 +146,12 @@ pub struct DiagnosticRendererNode { // missing field here } -// def render(self, ctx): -// """ -// Render the provided context. - -// Params -// ------ -// ctx: DiagnosticContext -// The diagnostic context to render. -// """ -// return _ffi_api.DiagnosticRendererRender(self, ctx +impl DiagnosticRenderer { + /// Render the provided context. + pub fn render(&self, ctx: DiagnosticContext) -> Result<()> { + diagnositc_renderer_render(self.clone(), ctx) + } +} #[repr(C)] #[derive(Object)] From 9aa1a091ad94fc257078aed897fced9784d9eeee Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 24 Oct 2020 00:28:32 -0700 Subject: [PATCH 20/32] Update cmake/modules/RustExt.cmake Co-authored-by: Robert Kimball --- cmake/modules/RustExt.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index 4c1534bcc4c0..f6cb9199015b 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -7,7 +7,7 @@ if(USE_RUST_EXT) elseif(USE_RUST_EXT STREQUAL "DYNAMIC") set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so") else() - message(FATAL_ERROR "invalid setting for RUST_EXT") + message(FATAL_ERROR "invalid setting for USE_RUST_EXT") endif() add_custom_command( From 7e038e08c9c1aeda2b79d63bf7ce6214882b00f7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 24 Oct 2020 00:30:48 -0700 Subject: [PATCH 21/32] Update rust/tvm/src/ir/diagnostics/mod.rs Co-authored-by: Robert Kimball --- rust/tvm/src/ir/diagnostics/mod.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index e49c4e83cece..4b1a2e290207 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -123,15 +123,15 @@ impl DiagnosticBuilder { } } -// /*! \brief Display diagnostics in a given display format. -// * -// * A diagnostic renderer is responsible for converting the -// * raw diagnostics into consumable output. -// * -// * For example the terminal renderer will render a sequence -// * of compiler diagnostics to std::out and std::err in -// * a human readable form. -// */ +/// \brief Display diagnostics in a given display format. +/// +/// A diagnostic renderer is responsible for converting the +/// raw diagnostics into consumable output. +/// +/// For example the terminal renderer will render a sequence +/// of compiler diagnostics to std::out and std::err in +/// a human readable form. +/// #[repr(C)] #[derive(Object)] #[ref_name = "DiagnosticRenderer"] From 9a0e727b0f5aa6225ec6952ffb9dc601d9d2ad8f Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 24 Oct 2020 00:43:30 -0700 Subject: [PATCH 22/32] Clean up PR --- rust/compiler-ext/Cargo.toml | 1 - rust/compiler-ext/src/lib.rs | 5 ++- rust/tvm-rt/src/object/object_ptr.rs | 14 ++++----- rust/tvm-rt/src/string.rs | 4 +-- rust/tvm/src/ir/arith.rs | 2 +- rust/tvm/src/ir/diagnostics/codespan.rs | 6 ++-- rust/tvm/src/ir/diagnostics/mod.rs | 42 ++++++++++++++++--------- rust/tvm/src/ir/expr.rs | 2 +- rust/tvm/src/ir/module.rs | 1 - rust/tvm/src/ir/relay/mod.rs | 8 ++--- rust/tvm/src/ir/source_map.rs | 2 +- rust/tvm/src/ir/span.rs | 2 +- rust/tvm/src/ir/ty.rs | 2 +- rust/tvm/src/transform.rs | 2 +- src/ir/diagnostic.cc | 5 +++ 15 files changed, 56 insertions(+), 42 deletions(-) diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml index 3b13bc5200d9..1633e1dfee1d 100644 --- a/rust/compiler-ext/Cargo.toml +++ b/rust/compiler-ext/Cargo.toml @@ -3,7 +3,6 @@ name = "compiler-ext" version = "0.1.0" authors = ["Jared Roesch "] edition = "2018" -# TODO(@jroesch): would be cool to figure out how to statically link instead. [lib] crate-type = ["staticlib", "cdylib"] diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 5f83f7b3b8ef..278060ef4897 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -27,10 +27,9 @@ fn diagnostics() -> Result<(), tvm::Error> { export!(diagnostics); #[no_mangle] -extern fn compiler_ext_initialize() -> i32 { +extern "C" fn compiler_ext_initialize() -> i32 { let _ = env_logger::try_init(); - tvm_export("rust_ext") - .expect("failed to initialize the Rust compiler extensions."); + tvm_export("rust_ext").expect("failed to initialize the Rust compiler extensions."); log::debug!("Loaded the Rust compiler extension."); return 0; } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 77254d2fbca2..8d535368c352 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -125,7 +125,7 @@ impl Object { /// By using associated constants and generics we can provide a /// type indexed abstraction over allocating objects with the /// correct index and deleter. - pub fn base_object() -> Object { + pub fn base() -> Object { let index = Object::get_type_index::(); Object::new(index, delete::) } @@ -351,7 +351,7 @@ mod tests { #[test] fn test_new_object() -> anyhow::Result<()> { - let object = Object::base_object::(); + let object = Object::base::(); let ptr = ObjectPtr::new(object); assert_eq!(ptr.count(), 1); Ok(()) @@ -359,7 +359,7 @@ mod tests { #[test] fn test_leak() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let object = ObjectPtr::leak(ptr); assert_eq!(object.count(), 1); @@ -368,7 +368,7 @@ mod tests { #[test] fn test_clone() -> anyhow::Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ptr2 = ptr.clone(); assert_eq!(ptr2.count(), 2); @@ -379,7 +379,7 @@ mod tests { #[test] fn roundtrip_retvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ret_value: RetValue = ptr.clone().into(); let ptr2: ObjectPtr = ret_value.try_into()?; @@ -401,7 +401,7 @@ mod tests { #[test] fn roundtrip_argvalue() -> Result<()> { - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let ptr_clone = ptr.clone(); assert_eq!(ptr.count(), 2); @@ -435,7 +435,7 @@ mod tests { fn test_ref_count_boundary3() { use super::*; use crate::function::{register, Function}; - let ptr = ObjectPtr::new(Object::base_object::()); + let ptr = ObjectPtr::new(Object::base::()); assert_eq!(ptr.count(), 1); let stay = ptr.clone(); assert_eq!(ptr.count(), 2); diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 6ff24bef3a60..3cd33a226d44 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -38,7 +38,7 @@ impl From for String { fn from(s: std::string::String) -> Self { let size = s.len() as u64; let data = Box::into_raw(s.into_boxed_str()).cast(); - let base = Object::base_object::(); + let base = Object::base::(); StringObj { base, data, size }.into() } } @@ -47,7 +47,7 @@ impl From<&'static str> for String { fn from(s: &'static str) -> Self { let size = s.len() as u64; let data = s.as_bytes().as_ptr(); - let base = Object::base_object::(); + let base = Object::base::(); StringObj { base, data, size }.into() } } diff --git a/rust/tvm/src/ir/arith.rs b/rust/tvm/src/ir/arith.rs index f589f2ac25c6..92a1de69ff78 100644 --- a/rust/tvm/src/ir/arith.rs +++ b/rust/tvm/src/ir/arith.rs @@ -34,7 +34,7 @@ macro_rules! define_node { impl $name { pub fn new($($id : $t,)*) -> $name { - let base = Object::base_object::<$node>(); + let base = Object::base::<$node>(); let node = $node { base, $($id),* }; $name(Some(ObjectPtr::new(node))) } diff --git a/rust/tvm/src/ir/diagnostics/codespan.rs b/rust/tvm/src/ir/diagnostics/codespan.rs index ebedc18fb02e..c411c0cd31a7 100644 --- a/rust/tvm/src/ir/diagnostics/codespan.rs +++ b/rust/tvm/src/ir/diagnostics/codespan.rs @@ -28,7 +28,7 @@ use std::sync::{Arc, Mutex}; use codespan_reporting::diagnostic::{Diagnostic as CDiagnostic, Label, Severity}; use codespan_reporting::files::SimpleFiles; use codespan_reporting::term::termcolor::{ColorChoice, StandardStream}; -use codespan_reporting::term::{self, ColorArg}; +use codespan_reporting::term::{self}; use super::*; use crate::ir::source_map::*; @@ -36,6 +36,7 @@ use crate::ir::source_map::*; /// A representation of a TVM Span as a range of bytes in a file. struct ByteRange { /// The file in which the range occurs. + #[allow(dead_code)] file_id: FileId, /// The range start. start_pos: usize, @@ -46,6 +47,7 @@ struct ByteRange { /// A mapping from Span to ByteRange for a single file. enum FileSpanToByteRange { AsciiSource(Vec), + #[allow(dead_code)] Utf8 { /// Map character regions which are larger then 1-byte to length. lengths: HashMap, @@ -57,8 +59,6 @@ enum FileSpanToByteRange { impl FileSpanToByteRange { /// Construct a span to byte range mapping from the program source. fn new(source: String) -> FileSpanToByteRange { - let mut last_index = 0; - let mut is_ascii = true; if source.is_ascii() { let line_lengths = source.lines().map(|line| line.len()).collect(); FileSpanToByteRange::AsciiSource(line_lengths) diff --git a/rust/tvm/src/ir/diagnostics/mod.rs b/rust/tvm/src/ir/diagnostics/mod.rs index 4b1a2e290207..051bb9eb16c4 100644 --- a/rust/tvm/src/ir/diagnostics/mod.rs +++ b/rust/tvm/src/ir/diagnostics/mod.rs @@ -20,7 +20,7 @@ use super::module::IRModule; use super::span::*; use crate::runtime::function::Result; -use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; +use crate::runtime::object::{Object, ObjectPtr}; use crate::runtime::{ array::Array, function::{self, Function, ToFunction}, @@ -44,6 +44,9 @@ external! { #[name("diagnostics.Emit")] fn emit(ctx: DiagnosticContext, diagnostic: Diagnostic) -> (); + #[name("diagnostics.DiagnosticContextDefault")] + fn diagnostic_context_default(module: IRModule) -> DiagnosticContext; + #[name("diagnostics.DiagnosticContextRender")] fn diagnostic_context_render(ctx: DiagnosticContext) -> (); @@ -80,28 +83,34 @@ pub struct DiagnosticNode { } impl Diagnostic { - pub fn new(level: DiagnosticLevel, span: Span, message: TString) { - todo!() + pub fn new(level: DiagnosticLevel, span: Span, message: TString) -> Diagnostic { + let node = DiagnosticNode { + base: Object::base::(), + level, + span, + message, + }; + ObjectPtr::new(node).into() } pub fn bug(span: Span) -> DiagnosticBuilder { - todo!() + DiagnosticBuilder::new(DiagnosticLevel::Bug, span) } pub fn error(span: Span) -> DiagnosticBuilder { - todo!() + DiagnosticBuilder::new(DiagnosticLevel::Error, span) } pub fn warning(span: Span) -> DiagnosticBuilder { - todo!() + DiagnosticBuilder::new(DiagnosticLevel::Warning, span) } pub fn note(span: Span) -> DiagnosticBuilder { - todo!() + DiagnosticBuilder::new(DiagnosticLevel::Note, span) } pub fn help(span: Span) -> DiagnosticBuilder { - todo!() + DiagnosticBuilder::new(DiagnosticLevel::Help, span) } } @@ -119,19 +128,22 @@ pub struct DiagnosticBuilder { impl DiagnosticBuilder { pub fn new(level: DiagnosticLevel, span: Span) -> DiagnosticBuilder { - DiagnosticBuilder { level, span, message: "".into() } + DiagnosticBuilder { + level, + span, + message: "".into(), + } } } -/// \brief Display diagnostics in a given display format. -/// +/// Display diagnostics in a given display format. +/// /// A diagnostic renderer is responsible for converting the /// raw diagnostics into consumable output. -/// +/// /// For example the terminal renderer will render a sequence /// of compiler diagnostics to std::out and std::err in /// a human readable form. -/// #[repr(C)] #[derive(Object)] #[ref_name = "DiagnosticRenderer"] @@ -181,7 +193,7 @@ impl DiagnosticContext { { let renderer = diagnostic_renderer(render_func.to_function()).unwrap(); let node = DiagnosticContextNode { - base: Object::base_object::(), + base: Object::base::(), module, diagnostics: Array::from_vec(vec![]).unwrap(), renderer, @@ -190,7 +202,7 @@ impl DiagnosticContext { } pub fn default(module: IRModule) -> DiagnosticContext { - todo!() + diagnostic_context_default(module).unwrap() } /// Emit a diagnostic. diff --git a/rust/tvm/src/ir/expr.rs b/rust/tvm/src/ir/expr.rs index 91c42f0edbcf..f74522d91c70 100644 --- a/rust/tvm/src/ir/expr.rs +++ b/rust/tvm/src/ir/expr.rs @@ -35,7 +35,7 @@ pub struct BaseExprNode { impl BaseExprNode { pub fn base() -> BaseExprNode { BaseExprNode { - base: Object::base_object::(), + base: Object::base::(), } } } diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 443915ff27b7..190b477b98f2 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -use std::io::Result as IOResult; use std::path::Path; use thiserror::Error; diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 530b1203bd98..a51ff8fe82c5 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -51,7 +51,7 @@ impl ExprNode { base: BaseExprNode::base::(), span: ObjectRef::null(), checked_type: Type::from(TypeNode { - base: Object::base_object::(), + base: Object::base::(), span: Span::empty(), }), } @@ -84,7 +84,7 @@ pub struct IdNode { impl Id { fn new(name_hint: TString) -> Id { let node = IdNode { - base: Object::base_object::(), + base: Object::base::(), name_hint: name_hint, }; Id(Some(ObjectPtr::new(node))) @@ -352,7 +352,7 @@ pub struct PatternNode { impl PatternNode { pub fn base() -> PatternNode { PatternNode { - base: Object::base_object::(), + base: Object::base::(), span: ObjectRef::null(), } } @@ -451,7 +451,7 @@ pub struct ClauseNode { impl Clause { pub fn new(lhs: Pattern, rhs: Expr, _span: ObjectRef) -> Clause { let node = ClauseNode { - base: Object::base_object::(), + base: Object::base::(), lhs, rhs, }; diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index b0cc0d898e94..abc8bb2e5542 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -21,7 +21,7 @@ use crate::runtime::map::Map; use crate::runtime::object::Object; use crate::runtime::string::String as TString; -use super::span::{SourceName, Span}; +use super::span::SourceName; use tvm_macros::Object; diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index afcbe9c01698..fe2cf286c737 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -18,7 +18,7 @@ * under the License. */ -use crate::runtime::{Object, ObjectRef, String as TString}; +use crate::runtime::{Object, String as TString}; use tvm_macros::Object; /// A source file name, contained in a Span. diff --git a/rust/tvm/src/ir/ty.rs b/rust/tvm/src/ir/ty.rs index b6a47f553da4..d12f094a63ea 100644 --- a/rust/tvm/src/ir/ty.rs +++ b/rust/tvm/src/ir/ty.rs @@ -36,7 +36,7 @@ pub struct TypeNode { impl TypeNode { fn base(span: Span) -> Self { TypeNode { - base: Object::base_object::(), + base: Object::base::(), span, } } diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 59fc60450825..c5a65c417c93 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -50,7 +50,7 @@ impl PassInfo { let required = Array::from_vec(required)?; let node = PassInfoNode { - base: Object::base_object::(), + base: Object::base::(), opt_level, name: name.into(), required, diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index c79ed3dd6969..522524e33dcd 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -168,6 +168,11 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } +TVM_REGISTER_GLOBAL("diagnostics.Default") + .set_body_typed([](const IRModule& module) { + return DiagnosticContext::Default(module); + }); + std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, std::string msg) { rang::fg diagnostic_color = rang::fg::reset; From d086193cff3f4ef87a2854dac6ee60fc90f681d4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 24 Oct 2020 00:43:52 -0700 Subject: [PATCH 23/32] Format all --- src/contrib/rust_extension.cc | 4 ++-- src/ir/diagnostic.cc | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/contrib/rust_extension.cc b/src/contrib/rust_extension.cc index 075cbc670f66..46e94fffdf55 100644 --- a/src/contrib/rust_extension.cc +++ b/src/contrib/rust_extension.cc @@ -24,8 +24,8 @@ #ifdef RUST_COMPILER_EXT extern "C" { - int compiler_ext_initialize(); - static int test = compiler_ext_initialize(); +int compiler_ext_initialize(); +static int test = compiler_ext_initialize(); } #endif diff --git a/src/ir/diagnostic.cc b/src/ir/diagnostic.cc index 522524e33dcd..e533972cc71a 100644 --- a/src/ir/diagnostic.cc +++ b/src/ir/diagnostic.cc @@ -168,10 +168,9 @@ DiagnosticContext DiagnosticContext::Default(const IRModule& module) { return DiagnosticContext(module, renderer); } -TVM_REGISTER_GLOBAL("diagnostics.Default") - .set_body_typed([](const IRModule& module) { - return DiagnosticContext::Default(module); - }); +TVM_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) { + return DiagnosticContext::Default(module); +}); std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level, std::string msg) { From 62f3e39d51f4ca95cb9f9732ca65d0e5515879d6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 24 Oct 2020 00:45:37 -0700 Subject: [PATCH 24/32] Remove dead comment --- rust/tvm/src/ir/source_map.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index abc8bb2e5542..4d35c17976af 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -34,29 +34,24 @@ use tvm_macros::Object; #[ref_name = "Source"] pub struct SourceNode { pub base: Object, - /// The source name. */ + /// The source name. pub source_name: SourceName, - /// The raw source. */ + /// The raw source. pub source: TString, + + // TODO(@jroesch): Non-ABI compat field // A mapping of line breaks into the raw source. // std::vector> line_map; } -// class Source : public ObjectRef { -// public: -// TVM_DLL Source(SourceName src_name, std::string source); -// TVM_DLL tvm::String GetLine(int line); - -// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); -// }; - /// A mapping from a unique source name to source fragments. #[repr(C)] #[derive(Object)] #[type_key = "SourceMap"] #[ref_name = "SourceMap"] pub struct SourceMapNode { + /// The base object. pub base: Object, /// The source mapping. pub source_map: Map, From 0b5645edbc4c807ab55c6d3cf55a6a424cff40eb Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 25 Oct 2020 23:35:41 -0700 Subject: [PATCH 25/32] Code review comments and apache headers --- CMakeLists.txt | 2 +- cmake/modules/RustExt.cmake | 17 +++++++++++++++++ rust/compiler-ext/Cargo.toml | 17 +++++++++++++++++ rust/tvm/src/bin/tyck.rs | 19 +++++++++++++++++++ rust/tvm/src/ir/span.rs | 18 ------------------ 5 files changed, 54 insertions(+), 19 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6fc19037e737..d67f7fd59aee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -79,7 +79,7 @@ tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) tvm_option(USE_TENSORRT_CODEGEN "Build with TensorRT Codegen support" OFF) tvm_option(USE_TENSORRT_RUNTIME "Build with TensorRT runtime" OFF) -tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions" OFF) +tvm_option(USE_RUST_EXT "Build with Rust based compiler extensions, STATIC, DYNAMIC, or OFF" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index f6cb9199015b..a0215fce2fac 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + if(USE_RUST_EXT) set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust") set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target") diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml index 1633e1dfee1d..c41552097025 100644 --- a/rust/compiler-ext/Cargo.toml +++ b/rust/compiler-ext/Cargo.toml @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + [package] name = "compiler-ext" version = "0.1.0" diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index e9c26630281c..5737603b1cc1 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use std::path::PathBuf; use anyhow::Result; diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index fe2cf286c737..6d1a51f50e62 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -22,7 +22,6 @@ use crate::runtime::{Object, String as TString}; use tvm_macros::Object; /// A source file name, contained in a Span. - #[repr(C)] #[derive(Object)] #[type_key = "SourceName"] @@ -32,23 +31,6 @@ pub struct SourceNameNode { pub name: TString, } -// /*! -// * \brief The source name of a file span. -// * \sa SourceNameNode, Span -// */ -// class SourceName : public ObjectRef { -// public: -// /*! -// * \brief Get an SourceName for a given operator name. -// * Will raise an error if the source name has not been registered. -// * \param name Name of the operator. -// * \return SourceName valid throughout program lifetime. -// */ -// TVM_DLL static SourceName Get(const String& name); - -// TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); -// }; - /// Span information for diagnostic purposes. #[repr(C)] #[derive(Object)] From 2f778db07ac6bd0ffda2a408c7e62902107a48bf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sun, 25 Oct 2020 23:44:07 -0700 Subject: [PATCH 26/32] Purge test file --- rust/tvm/test.rly | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 rust/tvm/test.rly diff --git a/rust/tvm/test.rly b/rust/tvm/test.rly deleted file mode 100644 index e9407b029787..000000000000 --- a/rust/tvm/test.rly +++ /dev/null @@ -1,3 +0,0 @@ -#[version = "0.0.5"] - -def @main(%x: int32) -> float32 { %x } From e8fd9a57b6715aeb146f43c6fdf839d45952be7d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 30 Oct 2020 13:51:21 -0700 Subject: [PATCH 27/32] Update cmake/modules/LLVM.cmake Co-authored-by: Tristan Konolige --- cmake/modules/LLVM.cmake | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index ca4ecd6db1ca..3b26058daf51 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -21,7 +21,8 @@ # # See https://github.com/imageworks/OpenShadingLanguage/issues/1069 # for more discussion. -add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) +add_definitions(-DDMLC_USE_FOPEN64=0) +target_add_definitions(tvm PRIVATE NDEBUG=1) # Test if ${USE_LLVM} is not an explicit boolean false # It may be a boolean or a string From e92adcc668319fe254698f0b34830be985a53724 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 30 Oct 2020 13:50:55 -0700 Subject: [PATCH 28/32] Format Rust --- rust/tvm/src/ir/source_map.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index 4d35c17976af..54e16dac62ac 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -39,7 +39,6 @@ pub struct SourceNode { /// The raw source. pub source: TString, - // TODO(@jroesch): Non-ABI compat field // A mapping of line breaks into the raw source. // std::vector> line_map; From 5731a6037af32411e87c52990bce7993aa23bd17 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 30 Oct 2020 13:54:00 -0700 Subject: [PATCH 29/32] Add TK's suggestion --- cmake/modules/RustExt.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index a0215fce2fac..2922bc48dee2 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -24,7 +24,7 @@ if(USE_RUST_EXT) elseif(USE_RUST_EXT STREQUAL "DYNAMIC") set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so") else() - message(FATAL_ERROR "invalid setting for USE_RUST_EXT") + message(FATAL_ERROR "invalid setting for USE_RUST_EXT, STATIC, DYNAMIC or OFF") endif() add_custom_command( From 0065966f3f3c2e03876282ab5d29e15ed11f61ad Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 31 Oct 2020 15:06:27 -0700 Subject: [PATCH 30/32] More CR and cleanup --- cmake/modules/LLVM.cmake | 5 +++-- rust/compiler-ext/Cargo.toml | 2 +- rust/tvm-rt/src/array.rs | 4 +++- rust/tvm-rt/src/function.rs | 8 -------- rust/tvm/src/ir/relay/mod.rs | 7 ++++--- rust/tvm/src/ir/span.rs | 20 +++++++++++++++++--- 6 files changed, 28 insertions(+), 18 deletions(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 3b26058daf51..ac870b17faeb 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -21,8 +21,9 @@ # # See https://github.com/imageworks/OpenShadingLanguage/issues/1069 # for more discussion. -add_definitions(-DDMLC_USE_FOPEN64=0) -target_add_definitions(tvm PRIVATE NDEBUG=1) +add_definitions(-DDMLC_USE_FOPEN64=0 -DNDEBUG=1) +# TODO(@jroesch, @tkonolige): if we actually use targets we can do this. +# target_compile_definitions(tvm PRIVATE NDEBUG=1) # Test if ${USE_LLVM} is not an explicit boolean false # It may be a boolean or a string diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml index c41552097025..b830b7a84135 100644 --- a/rust/compiler-ext/Cargo.toml +++ b/rust/compiler-ext/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "compiler-ext" version = "0.1.0" -authors = ["Jared Roesch "] +authors = ["TVM Contributors"] edition = "2018" [lib] diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 66e32a7e7177..98414f9c5b34 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -93,7 +93,9 @@ impl Iterator for IntoIter { fn next(&mut self) -> Option { if self.pos < self.size { - let item = self.array.get(self.pos).expect("should not fail"); + let item = + self.array.get(self.pos) + .expect("Can not index as in-bounds position after bounds checking.\nNote: this error can only be do to an uncaught issue with API bindings."); self.pos += 1; Some(item) } else { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 4c6f56ea1e76..aec4a8ad44de 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -33,7 +33,6 @@ use std::{ }; use crate::errors::Error; -use crate::object::ObjectPtr; pub use super::to_function::{ToFunction, Typed}; pub use tvm_sys::{ffi, ArgValue, RetValue}; @@ -142,13 +141,6 @@ impl Function { } let rv = RetValue::from_tvm_value(ret_val, ret_type_code as u32); - match rv { - RetValue::ObjectHandle(object) => { - let optr = ObjectPtr::from_raw(object as _).unwrap(); - ObjectPtr::leak(optr); - } - _ => {} - }; Ok(rv) } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index a51ff8fe82c5..cc1a76bef7e3 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -22,7 +22,7 @@ pub mod attrs; use std::hash::Hash; use crate::runtime::array::Array; -use crate::runtime::{object::*, String as TString}; +use crate::runtime::{object::*, IsObjectRef, String as TString}; use super::attrs::Attrs; use super::expr::BaseExprNode; @@ -52,7 +52,7 @@ impl ExprNode { span: ObjectRef::null(), checked_type: Type::from(TypeNode { base: Object::base::(), - span: Span::empty(), + span: Span::null(), }), } } @@ -554,7 +554,8 @@ def @main() -> float32 { 0.01639530062675476f } "#, - ); + ) + .unwrap(); let main = module .lookup(module.get_global_var("main".to_string().into()).unwrap()) .unwrap(); diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index 6d1a51f50e62..eb6821af69dc 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -18,7 +18,7 @@ * under the License. */ -use crate::runtime::{Object, String as TString}; +use crate::runtime::{Object, ObjectPtr, String as TString}; use tvm_macros::Object; /// A source file name, contained in a Span. @@ -51,7 +51,21 @@ pub struct SpanNode { } impl Span { - pub fn empty() -> Span { - todo!() + pub fn new( + source_name: SourceName, + line: i32, + end_line: i32, + column: i32, + end_column: i32, + ) -> Span { + let span_node = SpanNode { + base: Object::base::(), + source_name, + line, + end_line, + column, + end_column, + }; + Span(Some(ObjectPtr::new(span_node))) } } From 5f2ad03e3d4e8865f926748832f9fc5c02c5fab5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 31 Oct 2020 15:09:27 -0700 Subject: [PATCH 31/32] Fix tyck line --- rust/tvm/src/bin/tyck.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index 5737603b1cc1..e9eff2c73bca 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -35,9 +35,9 @@ struct Opt { } fn main() -> Result<()> { - codespan::init().expect("Rust based diagnostics"); + codespan::init() + .expect("Failed to initialize Rust based diagnostics."); let opt = Opt::from_args(); - println!("{:?}", &opt); let _module = match IRModule::parse_file(opt.input) { Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()), Err(e) => { From 9700d81ed0d169a645eaf96f32b0f7f93c0bc6bc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 31 Oct 2020 15:09:36 -0700 Subject: [PATCH 32/32] Format --- rust/tvm/src/bin/tyck.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index e9eff2c73bca..839a6bd1c17f 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -35,8 +35,7 @@ struct Opt { } fn main() -> Result<()> { - codespan::init() - .expect("Failed to initialize Rust based diagnostics."); + codespan::init().expect("Failed to initialize Rust based diagnostics."); let opt = Opt::from_args(); let _module = match IRModule::parse_file(opt.input) { Err(ir::module::Error::TVM(Error::DiagnosticError(_))) => return Ok(()),