From c8b922ff943c835086944631f812c9a33a4bc501 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 17 Mar 2020 12:45:31 -0700 Subject: [PATCH 01/31] Start on memory planning WIP Move to test_memory_passes.py Work on memory planning Post-rebase and VM changes Plumb through the offsets Basic tests all pass, fix offset to data buffer. Fix compile errors Fix ws Apply suggestions from code review Co-Authored-By: Haichen Shen Address CR Update src/runtime/vm/vm.cc Co-Authored-By: Haichen Shen Fix another comment Fix lint Fix Fix Fix Lint is done? Fix More fix Trying to debug No clue Fix lint --- include/tvm/runtime/vm.h | 15 +- python/tvm/relay/__init__.py | 6 + python/tvm/relay/op/memory/memory.py | 7 +- python/tvm/relay/transform/__init__.py | 1 - python/tvm/relay/transform/memory_alloc.py | 14 +- python/tvm/relay/transform/memory_plan.py | 231 ++++++++++++++++++ src/relay/backend/vm/compiler.cc | 77 ++++-- src/relay/op/memory/memory.cc | 40 +-- src/runtime/vm/executable.cc | 39 +-- src/runtime/vm/memory_manager.cc | 26 +- src/runtime/vm/vm.cc | 25 +- ..._memory_alloc.py => test_memory_passes.py} | 49 +++- 12 files changed, 439 insertions(+), 91 deletions(-) create mode 100644 python/tvm/relay/transform/memory_plan.py rename tests/python/relay/{test_pass_memory_alloc.py => test_memory_passes.py} (64%) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 62534d9ca6a9..5aa6ce7ddf74 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_ -#include #include +#include #include #include @@ -136,6 +136,8 @@ struct Instruction { struct /* AllocTensor Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The number of dimensions. */ uint32_t ndim; /*! \brief The shape of tensor. */ @@ -146,6 +148,8 @@ struct Instruction { struct /* AllocTensorReg Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The register to read the shape out of. */ RegName shape_register; /*! \brief The datatype of tensor to be allocated. */ @@ -272,18 +276,19 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, const std::vector& shape, - DLDataType dtype, RegName dst); + static Instruction AllocTensor(RegName storage, Index offset, + const std::vector& shape, DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. + * \param offset The offset into the storage to allocate from. * \param shape_register The register containing the shape. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, - RegName dst); + static Instruction AllocTensorReg(RegName storage, Index offset, + RegName shape_register, DLDataType dtype, RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 4663866b1452..8e48e509490b 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -58,6 +58,12 @@ # Dialects from . 
import qnn +from .scope_builder import ScopeBuilder + +# Load Memory Passes +from .transform import memory_alloc +from .transform import memory_plan + # Required to traverse large programs setrecursionlimit(10000) diff --git a/python/tvm/relay/op/memory/memory.py b/python/tvm/relay/op/memory/memory.py index 509db354b42c..4092545d552c 100644 --- a/python/tvm/relay/op/memory/memory.py +++ b/python/tvm/relay/op/memory/memory.py @@ -40,7 +40,7 @@ def invoke_tvm_op(func, inputs, outputs): """ return _make.invoke_tvm_op(func, inputs, outputs) -def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): +def alloc_tensor(storage, offset, shape, dtype='float32', assert_shape=None): """Allocate a tensor with the provided shape, and dtype. Parameters @@ -48,6 +48,9 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): storage : tvm.relay.Expr The storage to allocate from. + offset : tvm.relay.Expr + The offset to allocate from. + shape : tvm.relay.Expr The shape of the tensor to allocate. @@ -61,7 +64,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): result : tvm.relay.Expr The alloc_tensor expression. """ - return _make.alloc_tensor(storage, shape, dtype, assert_shape) + return _make.alloc_tensor(storage, offset, shape, dtype, assert_shape) def alloc_storage(size, alignment, ctx, dtype_hint='float32'): """Allocate a piece of tensor storage. diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 93d4341635a0..138a36611c6f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,5 +18,4 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * - from . import memory_alloc diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 611fb1babf55..5fb4909b57c7 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -28,19 +28,21 @@ from ..backend import compile_engine from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu +from ..op.memory import alloc_storage +def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): + offset = expr.const(0, dtype="int64") + return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 class ManifestAllocPass(ExprMutator): - """A pass for explictly manifesting all memory allocations in Relay.""" + """A pass for explicitly manifesting all memory allocations in Relay.""" def __init__(self, target_host): self.invoke_tvm = op.memory.invoke_tvm_op - self.alloc_storage = op.memory.alloc_storage - self.alloc_tensor = op.memory.alloc_tensor self.shape_func = op.memory.shape_func self.scopes = [ScopeBuilder()] self.target_host = target_host @@ -101,10 +103,10 @@ def make_static_allocation(self, scope, tensor_type, i): size = self.compute_storage(tensor_type) alignment = self.compute_alignment(tensor_type.dtype) dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(i), self.alloc_storage( + sto = scope.let("storage_{0}".format(i), alloc_storage( size, alignment, self.default_context, dtype)) # TODO(@jroesch): There is a bug with typing based on the constant shape. 
- tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape) + tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) return scope.let("tensor_{0}".format(i), tensor) def visit_let(self, let): @@ -179,7 +181,7 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): outs = [] sh_ty_storage = zip(out_shapes, out_types, storages) for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): - alloc = self.alloc_tensor( + alloc = alloc_tensor( storage, out_shape, out_type.dtype, diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py new file mode 100644 index 000000000000..e3625f8241df --- /dev/null +++ b/python/tvm/relay/transform/memory_plan.py @@ -0,0 +1,231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks +""" +A pass for manifesting explicit memory allocations. +""" +from typing import Optional, Dict +import attr + +from ..expr_functor import ExprMutator +from .. import op, expr +from ..function import Function +from ... import register_func, ir, cpu +from ..._ffi.runtime_ctypes import TVMContext +from . import function_pass + + +def is_primitive(call): + return ( + hasattr(call, "op") + and hasattr(call.op, "attrs") + and hasattr(call.op.attrs, "Primitive") + and int(call.op.attrs.Primitive) == 1 + ) + + +@attr.s(auto_attribs=True) +class Region: + """ + Represents a control-free allocation region. + + The below pass groups sets of allocations into regions, + then replaces the region with a single allocation. + """ + var: expr.Var + size: expr.Expr + alignment: Optional[expr.Expr] + dtype: Optional[str] + ctx: TVMContext + offsets: Dict[expr.Var, expr.Expr] = {} + + def grow( + self, old_storage: expr.Var, + size: expr.Expr, alignment: expr.Expr, + ctx: TVMContext, + dtype: str) -> None: + """Grow the region by a given allocation as well as track the old storage + for later rewriting the program to use the allocated region. + """ + if self.dtype: + assert self.dtype == dtype, "must have matching dtypes in a region" + else: + self.dtype = dtype + + if self.alignment: + assert ir.structural_equal( + self.alignment, alignment + ), "must have matching alignments in a region" + else: + self.alignment = alignment + + if self.ctx: + assert (self.ctx.device_type == ctx.device_type and + self.ctx.device_id == ctx.device_id), "must have matching context" + else: + assert ctx + self.ctx = ctx + + # Record the offset at which we allocate the storage. 
+ self.offsets[old_storage] = self.size + + self.size = self.size + size + + def to_expr(self) -> expr.Expr: + if self.ctx is None: + self.ctx = cpu(0) + + return op.memory.alloc_storage(self.size, self.alignment, self.ctx, self.dtype) + + +def iterative_let(let, each_binding, kont): + bindings = [] + while isinstance(let, expr.Let): + lhs = let.var + rhs = let.value + bindings.append(each_binding(lhs, rhs)) + let = let.body + + return kont(bindings, let) + + +def mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value + assert body + body = expr.Let(var, value, body) + return body + + +class StorageCoalesce(ExprMutator): + """ + A pass for coalescing allocations into region/arena allocations. + + After this pass each allocation comes from the same backing storage, + but will never overlap even in time, i.e. the allocations are just + packed into a contiguous block of memory. + + A secondary part of memory planning will perform liveness analysis to + overlap these in time, i.e when an early tensor dies we will attempt + to reuse its slot. + """ + + def __init__(self): + super().__init__() + self.regions = [] + + def enter_scope(self): + zero = expr.const(0, dtype="int64") + region_var = expr.var(f"region{len(self.regions)}") + region = Region(region_var, zero, None, None, None) + self.regions.append(region) + + def exit_scope(self, body: expr.Expr) -> expr.Expr: + """When leaving a scope build a region allocation for the scope.""" + region = self.regions.pop() + if len(region.offsets) == 0: + return body + else: + storage_expr = region.to_expr() + assert storage_expr, "can not be None" + assert region.var + assert storage_expr + assert body + return expr.Let(region.var, storage_expr, body) + + def current_region(self) -> Region: + return self.regions[-1] + + def visit_function(self, fn): + """Transform the function body to use region allocation scheme.""" + func = fn + if func.attrs and getattr(func.attrs, "Primitive", 0) == 1: + return func + else: + self.enter_scope() + body = self.visit(func.body) + body = self.exit_scope(body) + return Function( + func.params, + body, + func.ret_type, + func.type_params, + func.attrs, + ) + + def visit_if(self, ite): + self.enter_scope() + true_branch = self.visit(ite.true_branch) + true_branch = self.exit_scope(true_branch) + + self.enter_scope() + false_branch = self.visit(ite.false_branch) + false_branch = self.exit_scope(false_branch) + + return expr.If(ite.cond, true_branch, false_branch) + + def visit_let(self, let): + def _each_binding(lhs, rhs): + if isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_storage" + ): + return self.process_alloc_storage(lhs, rhs) + elif isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_tensor" + ): + return self.process_alloc_tensor(lhs, rhs) + else: + return lhs, rhs + + result = iterative_let(let, _each_binding, mk_let) + assert result + return result + + def process_alloc_storage(self, lhs, call): + size, alignment = call.args + dtype = call.attrs.dtype + ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) + region = self.current_region() + region.grow(lhs, size, alignment, ctx, dtype) + return lhs, region.var + + def process_alloc_tensor(self, lhs, call): + region = self.current_region() + storage, old_offset, shape = call.args + offset = region.offsets[storage] + assert ( + old_offset.data.asnumpy().item() == 0 + ), "no offsets should yet be allocated" + return ( + lhs, + expr.Call(call.op, [region.var, offset, shape], 
call.attrs), + ) + +@function_pass(opt_level=0) +class MemoryPlan: + """An explicit pass wrapper around ManifestAlloc.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + sc = StorageCoalesce() + before = func + func = sc.visit(func) + return before + + +register_func("relay.transform.MemoryPlan", MemoryPlan) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b2a5e83ef43c..e3a47e76590c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -42,8 +42,9 @@ #include #include "../../backend/compile_engine.h" -#include "../../op/op_common.h" #include "../../transforms/pass_util.h" +#include "../../op/op_common.h" +#include "compiler.h" #include "../utils.h" namespace tvm { @@ -56,8 +57,14 @@ Pass InlinePrimitives(); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); - CHECK(f != nullptr) << "could not load memory allocation pass"; + CHECK(f != nullptr) << "unable to load allocation manifestation pass"; return (*f)(target_host); +Pass MemoryPlan() { +} + + auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); + CHECK(f != nullptr) << "unable to load the memory planning pass"; + return (*f)(); } } // namespace transform @@ -503,8 +510,8 @@ class VMFunctionCompiler : ExprFunctor { }) .Match( "memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 2); + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); // Get the attributes. auto alloc_attrs = attrs.as(); @@ -512,23 +519,27 @@ class VMFunctionCompiler : ExprFunctor { auto dtype = alloc_attrs->dtype; // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; + this->VisitExpr(args[1]); + auto offset_register = last_register_; // If the shape is constant then we will emit a static tensor allocation // instruction. - auto const_shape = args[1].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - // TODO(@jroesch): we need to get an RFC done to standarize shape dtype - std::vector raw_shape = ToAllocTensorShape(shape); - // Add context field. - Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); + auto const_shape = args[2].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. + Emit(Instruction::AllocTensor( + storage_register, + offset_register, + raw_shape, + dtype, NewRegister())); } else { this->VisitExpr(args[1]); auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg(storage_register, shape_register, dtype, + Emit(Instruction::AllocTensorReg(storage_register, shape_register, offset_register, dtype, NewRegister())); } }) @@ -830,6 +841,33 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe } } +transform::Sequential MemoryOpt(tvm::Target host_target) { + Array pass_seqs; + // Manifest the allocations. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Inline the functions that are lifted to the module scope. 
We perform this + // pass after all other optimization passes but before the memory allocation + // pass. This is because memory allocation pass will insert `invoke_tvm_op` + // and we use these ops to invoke the symbols in the module generated by + // external codegen. + pass_seqs.push_back(transform::Inline()); + + // Manifest the allocations needed for the shape functions. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Perform memory planning in order to coalesce/reduce allocations. + pass_seqs.push_back(transform::MemoryPlan()); + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + return transform::Sequential(pass_seqs); +} + IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { Array pass_seqs; Array entry_functions{"main"}; @@ -890,15 +928,8 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); - // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); + pass_seqs.push_back(MemoryOpt(this->target_host_)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index ec96e23a01fb..170c087cc27d 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -92,18 +92,19 @@ RELAY_REGISTER_OP("memory.alloc_storage") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") - .set_body_typed([](Expr storage, tvm::relay::Expr shape, DataType dtype, - Array assert_shape) { - auto attrs = make_object(); - attrs->dtype = dtype; - if (assert_shape.defined()) { - attrs->assert_shape = assert_shape; - } else { - attrs->const_shape = Downcast(shape); - } - static const Op& op = Op::Get("memory.alloc_tensor"); - return Call(op, {storage, shape}, Attrs(attrs), {}); - }); + .set_body_typed( + [](Expr storage, Expr offset, tvm::relay::Expr shape, + DataType dtype, Array assert_shape) { + auto attrs = make_object(); + attrs->dtype = dtype; + if (assert_shape.defined()) { + attrs->assert_shape = assert_shape; + } else { + attrs->const_shape = Downcast(shape); + } + static const Op& op = Op::Get("memory.alloc_tensor"); + return Call(op, {storage, offset, shape}, Attrs(attrs), {}); + }); std::vector FromConstShape(Constant konst) { runtime::NDArray shape = konst->data; @@ -132,7 +133,7 @@ std::vector FromConstShape(Constant konst) { bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3u); + CHECK_EQ(types.size(), 4u); auto alloc_attrs = attrs.as(); CHECK(alloc_attrs != nullptr) << "must be alloc_tensor attributes"; // First argument should be storage. @@ -141,8 +142,12 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs auto storage_name = mod->GetGlobalTypeVar("Storage"); auto storage = relay::TypeCall(storage_name, {}); reporter->Assign(types[0], storage); - // Second argument should be shape tensor. - auto tt = types[1].as(); + // Second argument should be the offset. 
+ auto offset_type = types[1].as(); + CHECK(offset_type != nullptr) << "must be a scalar type"; + + // Third argument should be shape tensor. + auto tt = types[2].as(); CHECK(tt != nullptr) << "must be tensor type"; auto rank = tt->shape[0].as(); CHECK(rank != nullptr); @@ -165,14 +170,15 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs return true; } - reporter->Assign(types[2], alloc_type); + reporter->Assign(types[3], alloc_type); return true; } RELAY_REGISTER_OP("memory.alloc_tensor") .describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(3) .add_argument("storage", "Storage", "The storage to allocate from.") + .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) .set_support_level(10) diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c72e70fd6f66..6576ef5cb81d 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -307,9 +307,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim + // Number of fields = 6 + instr.alloc_tensor.ndim fields.push_back(instr.alloc_tensor.storage); - + fields.push_back(instr.alloc_tensor.offset); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor.dtype; fields.push_back(dtype.code); @@ -330,8 +330,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensorReg: { - // Number of fields = 6 + // Number of fields = 7 fields.push_back(instr.alloc_tensor_reg.storage); + fields.push_back(instr.alloc_tensor_reg.offset); fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. 
const auto& dtype = instr.alloc_tensor_reg.dtype; @@ -550,38 +551,40 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { } case Opcode::AllocTensor: { // Number of fields = 6 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 6U); - DCHECK_EQ(instr.fields.size(), 6U + static_cast(instr.fields[4])); + DCHECK_GE(instr.fields.size(), 7U); + DCHECK_EQ(instr.fields.size(), 7U + static_cast(instr.fields[4])); RegName storage_reg = instr.fields[0]; + RegName offset = instr.fields[1]; DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; + dtype.code = instr.fields[2]; + dtype.bits = instr.fields[3]; + dtype.lanes = instr.fields[4]; - Index ndim = instr.fields[4]; - RegName dst = instr.fields[5]; + Index ndim = instr.fields[5]; + RegName dst = instr.fields[6]; std::vector shape = ExtractFields(instr.fields, 6, ndim); - return Instruction::AllocTensor(storage_reg, shape, dtype, dst); + return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst); } case Opcode::AllocTensorReg: { // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 6U); + DCHECK_EQ(instr.fields.size(), 7U); RegName storage_reg = instr.fields[0]; - Index shape_register = instr.fields[1]; + RegName offset = instr.fields[1]; + Index shape_register = instr.fields[2]; DLDataType dtype; - dtype.code = instr.fields[2]; - dtype.bits = instr.fields[3]; - dtype.lanes = instr.fields[4]; + dtype.code = instr.fields[3]; + dtype.bits = instr.fields[4]; + dtype.lanes = instr.fields[5]; - RegName dst = instr.fields[5]; + RegName dst = instr.fields[6]; - return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst); + return Instruction::AllocTensorReg(storage_reg, offset, shape_register, dtype, dst); } case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index c0fd441bb0ca..1508e0311db0 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -77,8 +77,6 @@ inline size_t GetDataAlignment(const DLTensor& arr) { } NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDataType dtype) { - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK_EQ(offset, 0u); VerifyDataType(dtype); // crtical zone: allocate header, cannot throw @@ -87,14 +85,28 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa container->SetDeleter(StorageObj::Deleter); size_t needed_size = GetDataSize(container->dl_tensor); this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. container->manager_ctx = reinterpret_cast(this); - container->dl_tensor.data = this->buffer.data; - NDArray ret(GetObjectPtr(container)); + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start. + auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + NDArray ret(GetObjectPtr(container)); // RAII in effect, now run the check. - // TODO(@jroesch): generalize later to non-overlapping allocations. 
- CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + + CHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " + << needed_size << " at offset " + << offset << " in region that is " + << this->buffer.size << "bytes"; return ret; } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 0714709a0718..e1326771aa30 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -85,6 +85,7 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); @@ -92,6 +93,7 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; @@ -174,7 +176,8 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->result = instr.result; return *this; case Opcode::AllocTensor: - this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.storage = this->alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); @@ -182,6 +185,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { return *this; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; @@ -308,11 +312,15 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out } Instruction Instruction::AllocTensor(RegName storage, const std::vector& shape, + RegName storage, + RegName offset, + const std::vector& shape, DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; instr.alloc_tensor.storage = storage; + instr.alloc_tensor.offset = offset; instr.alloc_tensor.ndim = shape.size(); instr.alloc_tensor.shape = new int64_t[shape.size()]; for (size_t i = 0; i < shape.size(); ++i) { @@ -323,11 +331,15 @@ Instruction Instruction::AllocTensor(RegName storage, const std::vector } Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, + RegName storage, + RegName offset, + RegName shape_register, Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; instr.alloc_tensor_reg.storage = storage; + instr.alloc_tensor_reg.offset = offset; instr.alloc_tensor_reg.shape_register = shape_register; instr.alloc_tensor_reg.dtype = dtype; return instr; @@ -515,12 +527,16 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { } case Opcode::AllocTensor: { os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " [" + << instr.alloc_tensor.storage << " " + << instr.alloc_tensor.offset << " [" << StrJoin(instr.alloc_tensor.shape, 0, 
instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); break; } case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" + << instr.alloc_tensor_reg.storage << " $" + << instr.alloc_tensor_reg.offset << " " << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -785,6 +801,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } + void VirtualMachine::Init(const std::vector& ctxs) { ctxs_ = ctxs; } inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { @@ -945,8 +962,9 @@ void VirtualMachine::RunLoop() { } auto storage_obj = ReadRegister(instr.alloc_tensor.storage); + auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); pc_++; @@ -969,7 +987,8 @@ void VirtualMachine::RunLoop() { auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); + auto offset = LoadScalarInt(instr.alloc_tensor.offset); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); WriteRegister(instr.dst, obj); pc_++; diff --git a/tests/python/relay/test_pass_memory_alloc.py b/tests/python/relay/test_memory_passes.py similarity index 64% rename from tests/python/relay/test_pass_memory_alloc.py rename to tests/python/relay/test_memory_passes.py index c3c53121d934..9de9b2e38c07 100644 --- a/tests/python/relay/test_pass_memory_alloc.py +++ b/tests/python/relay/test_memory_passes.py @@ -18,21 +18,38 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay.transform import memory_alloc +from tvm.relay import memory_alloc -def check_vm_alloc(func, check_fn): - mod = tvm.IRModule() - mod['main'] = func - ex = relay.create_executor('vm', mod) +def check_memory_plan(func, check_fn): + # Build Module + mod = tvm.IRModule().from_expr(func) + + # Convert arguments. args = [] for param in func.params: param = param.type_annotation sh = [int(sh) for sh in param.shape] data = np.random.rand(*sh).astype(param.dtype) args.append(tvm.nd.array(data)) - result = ex.evaluate(mod['main'])(*args) + + # Compute without memory planning. + ex = relay.create_executor('vm', mod) + no_plan_result = ex.evaluate(mod['main'])(*args) + + # Compute with memory planning. + with relay.build_config(opt_level=1, disabled_pass=["MemoryPlan"]): + plan_result = ex.evaluate(mod['main'])(*args) + + # Compute Python result. py_res = check_fn(*[arg.asnumpy() for arg in args]) - np.testing.assert_allclose(result.asnumpy(), py_res) + + # First check that the two VM results agree. + np.testing.assert_allclose( + no_plan_result.asnumpy(), + plan_result.asnumpy()) + + # Finally check that the results match the Python result. 
+ np.testing.assert_allclose(plan_result.asnumpy(), py_res) def storage_type(mod): return relay.TypeCall(mod.get_global_type_var("Storage"), []) @@ -58,20 +75,34 @@ def test_add(): x = relay.var('x', shape=(2,)) z = x + x func = relay.Function([x,], z) - check_vm_alloc(func, check_add) + check_memory_plan(func, check_add) def check_add_sub(x, y): z = x + x return z - y + def test_add_sub(): x = relay.var('x', shape=(10,)) y = relay.var('y', shape=(10,)) z = x + x z = z - y func = relay.Function([x, y], z) - check_vm_alloc(func, check_add_sub) + check_memory_plan(func, check_add_sub) + +def check_no_fuse(x, y, w): + z = x + y + return np.matmul(z, np.transpose(w)) + +def test_no_fuse(): + x = relay.var('x', shape=(5, 1)) + y = relay.var('y', shape=(5, 1)) + w = relay.var('w', shape=(5, 1)) + z = x + y + out = relay.op.nn.dense(z, w) + func = relay.Function([x, y, w], out) + check_memory_plan(func, check_no_fuse) if __name__ == "__main__": test_tyck_alloc_tensor() From 1da4e9b926f5883b2ab96c3e0cb9ef28d28b7bdc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 21 Apr 2020 20:05:49 -0700 Subject: [PATCH 02/31] Fix docs --- include/tvm/runtime/vm.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 5aa6ce7ddf74..30cde8ddca49 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -271,6 +271,7 @@ struct Instruction { /*! * \brief Construct an allocate tensor instruction with constant shape. * \param storage The storage to allocate out of. + * \param offset The offset to allocate at. * \param shape The shape of the tensor. * \param dtype The dtype of the tensor. * \param dst The destination register. From 033e8c80465bad49d25df1e5e5910e6c7f190659 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 Apr 2020 13:30:29 -0700 Subject: [PATCH 03/31] Disable aggressive constant eval --- src/relay/backend/vm/compiler.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e3a47e76590c..dd078a18e7f4 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -860,10 +860,11 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { // Manifest the allocations needed for the shape functions. pass_seqs.push_back(transform::ManifestAlloc(host_target)); + // Compute away constant computation introduced by manifesting allocations. + pass_seqs.push_back(transform::FoldConstant()); + // Perform memory planning in order to coalesce/reduce allocations. pass_seqs.push_back(transform::MemoryPlan()); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); return transform::Sequential(pass_seqs); } From 7c8cd9b37abe2cfc0f62919fea6b538a6b841372 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Apr 2020 12:13:15 -0700 Subject: [PATCH 04/31] It works --- python/tvm/relay/transform/memory_plan.py | 64 ++++++++++++++++++----- 1 file changed, 50 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index e3625f8241df..237cc557ece9 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -18,7 +18,7 @@ """ A pass for manifesting explicit memory allocations. 
""" -from typing import Optional, Dict +from typing import Optional, Dict, List, Tuple import attr from ..expr_functor import ExprMutator @@ -26,6 +26,8 @@ from ..function import Function from ... import register_func, ir, cpu from ..._ffi.runtime_ctypes import TVMContext +from ... import IRModule +from .. import transform from . import function_pass @@ -51,7 +53,7 @@ class Region: alignment: Optional[expr.Expr] dtype: Optional[str] ctx: TVMContext - offsets: Dict[expr.Var, expr.Expr] = {} + offsets: Dict[expr.Var, Tuple[expr.Expr, expr.Expr]] def grow( self, old_storage: expr.Var, @@ -81,15 +83,49 @@ def grow( self.ctx = ctx # Record the offset at which we allocate the storage. - self.offsets[old_storage] = self.size + offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}") + self.offsets[old_storage] = (offset_var, self.size) self.size = self.size + size - def to_expr(self) -> expr.Expr: + def offset_for(self, alloc: expr.Expr) -> expr.Expr: + return self.offsets[alloc][0] + + def to_expr(self, body: expr.Expr) -> expr.Expr: + """ + Generate the prelude code for a region, wrapping the body in it. + + The prelude contains the single allocation for a region, and + all offset computations. + """ + if self.ctx is None: self.ctx = cpu(0) - return op.memory.alloc_storage(self.size, self.alignment, self.ctx, self.dtype) + # Generate bindings for each and every size computation + # we must do this to maintain ANF. + bindings: List[Tuple[expr.Expr, expr.Expr]] = [] + + # First compute the total size. + total_size = expr.var("total_size") + bindings.append((total_size, const_eval(self.size))) + + # Allocate the entire region with a single call. + alloc = op.memory.alloc_storage(total_size, self.alignment, self.ctx, self.dtype) + bindings.append((self.var, alloc)) + + # Generate variables which contain all of the offset math. + # Ensure we constant evaluate away all the math here. + # + # In theory we can support dynamic offsets but this + # requires another round of memory planning and + # potentially colaescing. 
+ for alloc in self.offsets: + (var, offset) = self.offsets[alloc] + offset = const_eval(offset) + bindings.append((var, offset)) + + return mk_let(bindings, body) def iterative_let(let, each_binding, kont): @@ -111,6 +147,10 @@ def mk_let(bindings, body): body = expr.Let(var, value, body) return body +def const_eval(expr): + mod = IRModule.from_expr(expr) + mod = transform.FoldConstant()(mod) + return mod["main"].body class StorageCoalesce(ExprMutator): """ @@ -129,10 +169,10 @@ def __init__(self): super().__init__() self.regions = [] - def enter_scope(self): + def enter_scope(self) -> None: zero = expr.const(0, dtype="int64") region_var = expr.var(f"region{len(self.regions)}") - region = Region(region_var, zero, None, None, None) + region = Region(region_var, zero, None, None, None, {}) self.regions.append(region) def exit_scope(self, body: expr.Expr) -> expr.Expr: @@ -141,12 +181,7 @@ def exit_scope(self, body: expr.Expr) -> expr.Expr: if len(region.offsets) == 0: return body else: - storage_expr = region.to_expr() - assert storage_expr, "can not be None" - assert region.var - assert storage_expr - assert body - return expr.Let(region.var, storage_expr, body) + return region.to_expr(body) def current_region(self) -> Region: return self.regions[-1] @@ -207,7 +242,7 @@ def process_alloc_storage(self, lhs, call): def process_alloc_tensor(self, lhs, call): region = self.current_region() storage, old_offset, shape = call.args - offset = region.offsets[storage] + offset = region.offset_for(storage) assert ( old_offset.data.asnumpy().item() == 0 ), "no offsets should yet be allocated" @@ -225,6 +260,7 @@ def transform_function(self, func, mod, _): sc = StorageCoalesce() before = func func = sc.visit(func) + print(func) return before From 36d75c5760c9a05d73138b9ed2aaf6ea104cb804 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Apr 2020 12:21:33 -0700 Subject: [PATCH 05/31] Fix lint --- python/tvm/relay/transform/memory_plan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 237cc557ece9..cd0c57684dcb 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -147,8 +147,8 @@ def mk_let(bindings, body): body = expr.Let(var, value, body) return body -def const_eval(expr): - mod = IRModule.from_expr(expr) +def const_eval(exp): + mod = IRModule.from_expr(exp) mod = transform.FoldConstant()(mod) return mod["main"].body From 9b636c30725daaa11cdc07df978018328af72111 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 24 Apr 2020 16:39:11 -0700 Subject: [PATCH 06/31] Found issue with dynamic --- include/tvm/relay/transform.h | 4 +- python/tvm/relay/transform/memory_alloc.py | 12 +++- python/tvm/relay/transform/memory_plan.py | 74 ++++++++++++++++------ src/relay/backend/vm/compiler.cc | 10 +-- src/relay/transforms/fold_constant.cc | 14 ++-- 5 files changed, 78 insertions(+), 36 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 461276b79541..48948ee705eb 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -92,10 +92,10 @@ TVM_DLL Pass LazyGradientInit(); /*! * \brief Fold constant expressions. - * + * \param preserve_anf Controls the inlining of let bindings. * \return The pass. */ -TVM_DLL Pass FoldConstant(); +TVM_DLL Pass FoldConstant(bool preserve_anf=false); /*! * \brief Fuse operations into expr into seperate functions. 
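The new `preserve_anf` parameter exists because the memory planner builds let bindings for region sizes and offsets, and the later rewrites in this series match on that binding structure; when constant folding substitutes let-bound constants into the body, that structure disappears. Below is a minimal sketch of the default behaviour the flag guards against (illustrative only, not part of this patch; the variable names are made up):

import tvm
from tvm import relay

size = relay.const(64, dtype="int64")
total = relay.var("total_size", shape=(), dtype="int64")
# let %total_size = 64 in %total_size + %total_size
body = relay.Let(total, size, relay.add(total, total))
mod = tvm.IRModule.from_expr(relay.Function([], body))
folded = relay.transform.FoldConstant()(mod)
print(folded["main"])  # the let binding is folded away into a single constant

With preserve_anf set to true, the ConstantFolder keeps the Let node instead of substituting the constant into the body, which is the shape the planner needs.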
diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 5fb4909b57c7..613963ef331f 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -174,7 +174,7 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): size = self.compute_storage_in_relay( out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) - sto = scope.let("storage_{i}".format(i=i), self.alloc_storage( + sto = scope.let("storage_{i}".format(i=i), alloc_storage( size, alignment, self.default_context, out_type.dtype)) storages.append(sto) @@ -206,6 +206,13 @@ def visit_call(self, call): # Because we are in ANF we do not need to visit the arguments. scope = self.current_scope() new_args = [self.visit(arg) for arg in call.args] + new_args = [] + for i, arg in enumerate(call.args): + if isinstance(arg, expr.Constant): + new_args.append(scope.let(f"const{i}", arg)) + else: + new_args.append(arg) + ins = expr.Tuple(new_args) ret_type = call.checked_type out_types = flatten_tuple_type(ret_type) @@ -235,11 +242,12 @@ def __init__(self, target_host): self.target_host = target_host def transform_function(self, func, mod, _): - # TODO(@jroesch): Is there a way to do one shot initilization? + # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") ea = ManifestAllocPass(self.target_host) func = ea.visit(func) + print(func) return func diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index cd0c57684dcb..7cda8456aa8c 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -20,6 +20,7 @@ """ from typing import Optional, Dict, List, Tuple import attr +from collections import defaultdict from ..expr_functor import ExprMutator from .. import op, expr @@ -55,6 +56,12 @@ class Region: ctx: TVMContext offsets: Dict[expr.Var, Tuple[expr.Expr, expr.Expr]] + @staticmethod + def empty(region_no): + zero = expr.const(0, dtype="int64") + region_var = expr.var(f"region{region_no}") + return Region(region_var, zero, None, None, None, {}) + def grow( self, old_storage: expr.Var, size: expr.Expr, alignment: expr.Expr, @@ -108,7 +115,7 @@ def to_expr(self, body: expr.Expr) -> expr.Expr: # First compute the total size. total_size = expr.var("total_size") - bindings.append((total_size, const_eval(self.size))) + bindings.append((total_size, self.size)) # Allocate the entire region with a single call. alloc = op.memory.alloc_storage(total_size, self.alignment, self.ctx, self.dtype) @@ -122,10 +129,10 @@ def to_expr(self, body: expr.Expr) -> expr.Expr: # potentially colaescing. 
for alloc in self.offsets: (var, offset) = self.offsets[alloc] - offset = const_eval(offset) bindings.append((var, offset)) - return mk_let(bindings, body) + body = mk_let(bindings, body) + return body def iterative_let(let, each_binding, kont): @@ -171,20 +178,23 @@ def __init__(self): def enter_scope(self) -> None: zero = expr.const(0, dtype="int64") - region_var = expr.var(f"region{len(self.regions)}") - region = Region(region_var, zero, None, None, None, {}) - self.regions.append(region) + region_no = len(self.regions) + self.regions.append(defaultdict(lambda: Region.empty(region_no))) def exit_scope(self, body: expr.Expr) -> expr.Expr: """When leaving a scope build a region allocation for the scope.""" - region = self.regions.pop() - if len(region.offsets) == 0: - return body - else: - return region.to_expr(body) + dtype_region = self.regions.pop() + for dtype, region in reversed(list(dtype_region.items())): + if len(region.offsets) == 0: + continue + else: + body = region.to_expr(body) - def current_region(self) -> Region: - return self.regions[-1] + return body + + def current_region(self, dtype) -> Region: + current_scope = self.regions[-1] + return current_scope[dtype] def visit_function(self, fn): """Transform the function body to use region allocation scheme.""" @@ -235,12 +245,12 @@ def process_alloc_storage(self, lhs, call): size, alignment = call.args dtype = call.attrs.dtype ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) - region = self.current_region() + region = self.current_region(call.attrs.dtype) region.grow(lhs, size, alignment, ctx, dtype) return lhs, region.var def process_alloc_tensor(self, lhs, call): - region = self.current_region() + region = self.current_region(call.attrs.dtype) storage, old_offset, shape = call.args offset = region.offset_for(storage) assert ( @@ -251,6 +261,32 @@ def process_alloc_tensor(self, lhs, call): expr.Call(call.op, [region.var, offset, shape], call.attrs), ) +class LiftConstants(ExprMutator): + def __init__(self): + self.i = 0 + self.constants = [] + self.top_level = True + super().__init__() + + def visit_constant(self, const): + var = expr.var(f"const{self.i}") + self.i += 1 + self.constants.append((var, const)) + return var + + def visit_function(self, function): + if self.top_level: + self.top_level = False + body = mk_let(self.constants, self.visit(function.body)) + return Function( + function.params, + body, + function.ret_type, + function.type_params, + function.attrs) + else: + return super().visit_function(function) + @function_pass(opt_level=0) class MemoryPlan: """An explicit pass wrapper around ManifestAlloc.""" @@ -258,10 +294,12 @@ class MemoryPlan: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") sc = StorageCoalesce() - before = func + import pdb; pdb.set_trace() func = sc.visit(func) - print(func) - return before + import pdb; pdb.set_trace() + func = LiftConstants().visit(func) + func = const_eval(func) + return func register_func("relay.transform.MemoryPlan", MemoryPlan) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index dd078a18e7f4..ade36d6bddeb 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -847,20 +847,14 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { pass_seqs.push_back(transform::ManifestAlloc(host_target)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); + // Fuse the shape functions. 
pass_seqs.push_back(transform::FuseOps()); - // Inline the functions that are lifted to the module scope. We perform this - // pass after all other optimization passes but before the memory allocation - // pass. This is because memory allocation pass will insert `invoke_tvm_op` - // and we use these ops to invoke the symbols in the module generated by - // external codegen. - pass_seqs.push_back(transform::Inline()); - // Manifest the allocations needed for the shape functions. pass_seqs.push_back(transform::ManifestAlloc(host_target)); - // Compute away constant computation introduced by manifesting allocations. + // // Compute away constant computation introduced by manifesting allocations. pass_seqs.push_back(transform::FoldConstant()); // Perform memory planning in order to coalesce/reduce allocations. diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 70df0ed8c2b4..f3e15e71f6ca 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -77,7 +77,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec // or make a more powerful partial evaluator. class ConstantFolder : public ExprMutator { public: - explicit ConstantFolder(FInterpreter executor, IRModule module) + explicit ConstantFolder(FInterpreter executor, IRModule module, bool preserve_anf) : executor_(executor), module_(module), shape_of_op_(Op::Get("shape_of")), @@ -85,11 +85,12 @@ class ConstantFolder : public ExprMutator { shape_func_op_(Op::Get("memory.shape_func")), alloc_tensor_op_(Op::Get("memory.alloc_tensor")), alloc_storage_op_(Op::Get("memory.alloc_storage")), - cast_op_(Op::Get("cast")) {} + cast_op_(Op::Get("cast")), + preserve_anf(preserve_anf) {} Expr VisitExpr_(const LetNode* op) final { Expr value = this->Mutate(op->value); - if (value.as()) { + if (!preserve_anf && value.as()) { memo_[op->var] = value; return this->Mutate(op->body); } else { @@ -171,6 +172,7 @@ class ConstantFolder : public ExprMutator { const Op& alloc_tensor_op_; const Op& alloc_storage_op_; const Op& cast_op_; + bool preserve_anf; // Convert value to expression. Expr ObjectToExpr(const ObjectRef& value) { @@ -267,7 +269,7 @@ class ConstantFolder : public ExprMutator { } }; -Expr FoldConstant(const Expr& expr, const IRModule& mod) { +Expr FoldConstant(const Expr& expr, const IRModule& mod, bool preserve_anf) { DLContext ctx; ctx.device_type = kDLCPU; ctx.device_id = 0; @@ -276,12 +278,12 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod) { // in case we are already in a build context. 
With fresh_build_ctx(BuildConfig::Create()); - return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); + return ConstantFolder(CreateInterpreter(mod, ctx, target), mod, preserve_anf).Mutate(expr); } namespace transform { -Pass FoldConstant() { +Pass FoldConstant(bool preserve_anf) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(FoldConstant(f, m)); From 26d5490632802e3c54822e71f0af8db60f0a8377 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 7 May 2020 01:48:36 -0700 Subject: [PATCH 07/31] Fix the pass, but runtime segfaults --- python/tvm/relay/transform/memory_alloc.py | 7 -- python/tvm/relay/transform/memory_plan.py | 79 ++++++++++++++++++---- python/tvm/relay/transform/transform.py | 2 +- src/relay/analysis/well_formed.cc | 5 ++ src/relay/backend/vm/compiler.cc | 32 ++++++++- src/runtime/vm/vm.cc | 1 + tests/python/relay/test_any.py | 59 ++++++++-------- 7 files changed, 131 insertions(+), 54 deletions(-) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 613963ef331f..8b831e47d8c0 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -206,12 +206,6 @@ def visit_call(self, call): # Because we are in ANF we do not need to visit the arguments. scope = self.current_scope() new_args = [self.visit(arg) for arg in call.args] - new_args = [] - for i, arg in enumerate(call.args): - if isinstance(arg, expr.Constant): - new_args.append(scope.let(f"const{i}", arg)) - else: - new_args.append(arg) ins = expr.Tuple(new_args) ret_type = call.checked_type @@ -247,7 +241,6 @@ def transform_function(self, func, mod, _): mod.import_from_std("core.rly") ea = ManifestAllocPass(self.target_host) func = ea.visit(func) - print(func) return func diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 7cda8456aa8c..667119bd1f33 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -114,7 +114,7 @@ def to_expr(self, body: expr.Expr) -> expr.Expr: bindings: List[Tuple[expr.Expr, expr.Expr]] = [] # First compute the total size. - total_size = expr.var("total_size") + total_size = expr.var(f"total_size{hash(body)}") bindings.append((total_size, self.size)) # Allocate the entire region with a single call. 
@@ -146,18 +146,20 @@ def iterative_let(let, each_binding, kont): return kont(bindings, let) + def mk_let(bindings, body): for var, value in reversed(bindings): assert var assert value assert body body = expr.Let(var, value, body) + return body -def const_eval(exp): - mod = IRModule.from_expr(exp) +def const_eval(mod, exp): + mod = IRModule.from_expr(exp, type_defs=mod.type_definitions) mod = transform.FoldConstant()(mod) - return mod["main"].body + return mod["main"] class StorageCoalesce(ExprMutator): """ @@ -199,8 +201,8 @@ def current_region(self, dtype) -> Region: def visit_function(self, fn): """Transform the function body to use region allocation scheme.""" func = fn - if func.attrs and getattr(func.attrs, "Primitive", 0) == 1: - return func + if getattr(func.attrs, "Primitive", 0) == 1: + return super().visit_function(func) else: self.enter_scope() body = self.visit(func.body) @@ -224,12 +226,28 @@ def visit_if(self, ite): return expr.If(ite.cond, true_branch, false_branch) + + def mk_let(self, dynamic_regions): + def _mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value + assert body + body = expr.Let(var, value, body) + if var in dynamic_regions: + body = self.exit_scope(body) + + return body + + return _mk_let + def visit_let(self, let): + dynamic_regions = [] def _each_binding(lhs, rhs): if isinstance(rhs, expr.Call) and rhs.op == op.op.get( "memory.alloc_storage" ): - return self.process_alloc_storage(lhs, rhs) + return self.process_alloc_storage(dynamic_regions, lhs, rhs) elif isinstance(rhs, expr.Call) and rhs.op == op.op.get( "memory.alloc_tensor" ): @@ -237,15 +255,20 @@ def _each_binding(lhs, rhs): else: return lhs, rhs - result = iterative_let(let, _each_binding, mk_let) + result = iterative_let(let, _each_binding, self.mk_let(dynamic_regions)) assert result return result - def process_alloc_storage(self, lhs, call): + def process_alloc_storage(self, dynamic_regions, lhs, call): size, alignment = call.args dtype = call.attrs.dtype ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) - region = self.current_region(call.attrs.dtype) + + if not isinstance(size, expr.Constant): + self.enter_scope() + dynamic_regions.append(lhs) + + region = self.current_region(dtype) region.grow(lhs, size, alignment, ctx, dtype) return lhs, region.var @@ -261,7 +284,7 @@ def process_alloc_tensor(self, lhs, call): expr.Call(call.op, [region.var, offset, shape], call.attrs), ) -class LiftConstants(ExprMutator): +class LiftConst(ExprMutator): def __init__(self): self.i = 0 self.constants = [] @@ -275,6 +298,9 @@ def visit_constant(self, const): return var def visit_function(self, function): + if int(getattr(function.attrs, "Primitive", 0)) == 1: + return function + if self.top_level: self.top_level = False body = mk_let(self.constants, self.visit(function.body)) @@ -294,12 +320,35 @@ class MemoryPlan: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") sc = StorageCoalesce() - import pdb; pdb.set_trace() func = sc.visit(func) - import pdb; pdb.set_trace() - func = LiftConstants().visit(func) - func = const_eval(func) + # func = Uniq().visit(func) return func +class Uniq(ExprMutator): + def __init__(self): + self.var_map = {} + self.i = 0 + super().__init__() + + def visit_var(self, var): + if var in self.var_map: + return self.var_map[var] + else: + new_var = expr.Var(f"var{self.i}", type_annotation=var.type_annotation) + self.i += 1 + self.var_map[var] = new_var + return new_var register_func("relay.transform.MemoryPlan", 
MemoryPlan) + +@function_pass(opt_level=0) +class LiftConstants: + """An explicit pass wrapper around LiftConstants.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + func = LiftConst().visit(func) + return func + + +register_func("relay.transform.LiftConstants", LiftConstants) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 647e999f647a..d68750bfbb08 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -241,7 +241,7 @@ def FoldConstant(): ret : tvm.transform.Pass The registered pass for constant folding. """ - return _ffi_api.FoldConstant() + return _ffi_api.FoldConstant(False) def FuseOps(fuse_opt_level=-1): diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 33f52c9a8397..516a10e0e7a4 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -53,7 +53,12 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { }; void Bound(const Var& v) { + // std::cout << "HERE " << v << std::endl; if (current_bound.count(v) != 0 || total_bound.count(v) != 0 || free.count(v) != 0) { + // std::cout << "WELL FORMED: " << v << std::endl; + // std::cout << "current bindings :" << current_bound.count(v) << std::endl; + // std::cout << "total bindings :" << total_bound.count(v) << std::endl; + // std::cout << "free bindings :" << free.count(v) << std::endl; well_formed = false; } CHECK_GE(scope.size(), 0); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index ade36d6bddeb..216e54f7c014 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -59,14 +59,20 @@ Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); CHECK(f != nullptr) << "unable to load allocation manifestation pass"; return (*f)(target_host); -Pass MemoryPlan() { } +Pass MemoryPlan() { auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); CHECK(f != nullptr) << "unable to load the memory planning pass"; return (*f)(); } +Pass LiftConstants() { + auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants"); + CHECK(f != nullptr) << "unable to load the memory planning pass"; + return (*f)(); +} + } // namespace transform namespace vm { @@ -518,6 +524,10 @@ class VMFunctionCompiler : ExprFunctor { CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; auto dtype = alloc_attrs->dtype; + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + // The storage will be passed dynamically. this->VisitExpr(args[1]); auto offset_register = last_register_; @@ -845,6 +855,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { Array pass_seqs; // Manifest the allocations. pass_seqs.push_back(transform::ManifestAlloc(host_target)); + // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -854,12 +865,27 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { // Manifest the allocations needed for the shape functions. pass_seqs.push_back(transform::ManifestAlloc(host_target)); - // // Compute away constant computation introduced by manifesting allocations. - pass_seqs.push_back(transform::FoldConstant()); + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); // Perform memory planning in order to coalesce/reduce allocations. 
pass_seqs.push_back(transform::MemoryPlan()); + // Compute away constant computation introduced by coalescing allocations. + pass_seqs.push_back(transform::FoldConstant()); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Create allocations for math introduced by dynamic region math. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + // Lift constants to the top-level of the block to simplify VM code generation. + pass_seqs.push_back(transform::LiftConstants()); + return transform::Sequential(pass_seqs); } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index e1326771aa30..f142bb9c27b4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -628,6 +628,7 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { + std::cout << name << std::endl; if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK(exec_) << "The executable is not created yet."; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 6ce59bbf1c36..10137d417437 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -66,6 +66,7 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): x = relay.var('x', shape=x_shape, dtype=dtype) mod = tvm.IRModule() mod["main"] = relay.Function([x], op(x)) + print(mod["main"]) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) for kind in ["debug", "vm"]: @@ -75,8 +76,8 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): def test_any_elemwise(): verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) - verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) - verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) + # verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) + # verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) def test_any_broadcast_fail(): # Test broadcast with incompatible values at runtime @@ -130,6 +131,7 @@ def test_any_concat(): z = relay.op.concatenate([xx, yy], axis=0) mod = tvm.IRModule() mod["main"] = relay.Function([x, y], z) + print(mod["main"]) x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) @@ -668,30 +670,31 @@ def _body(i, st): assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) if __name__ == "__main__": - test_any_full() - test_any_broadcast() + # test_any_concat() + # test_any_full() + # test_any_broadcast() test_any_elemwise() - test_any_broadcast_fail() - test_any_concat() - test_any_reshape() - test_any_take() - test_any_tile() - test_any_split() - test_any_shape_of() - test_any_reduce() - test_any_layout_transform() - test_any_expand_dims() - test_any_transpose() - test_any_squeeze() - test_any_reshape_like() - test_any_conv2d_NCHWc() - test_any_pool2d() - test_any_global_pool2d() - test_any_batch_flatten() - test_any_dense() - test_any_pad() - test_any_softmax() - test_fused_ops() - test_arange_with_dynamic_shape() - test_recursive_concat() - test_recursive_concat_with_wrong_annotation() + # test_any_broadcast_fail() + # test_any_concat() + # test_any_reshape() + # test_any_take() + # test_any_tile() + # 
test_any_split() + # test_any_shape_of() + # test_any_reduce() + # test_any_layout_transform() + # test_any_expand_dims() + # test_any_transpose() + # test_any_squeeze() + # test_any_reshape_like() + # test_any_conv2d_NCHWc() + # test_any_pool2d() + # test_any_global_pool2d() + # test_any_batch_flatten() + # test_any_dense() + # test_any_pad() + # test_any_softmax() + # test_fused_ops() + # test_arange_with_dynamic_shape() + # test_recursive_concat() + # test_recursive_concat_with_wrong_annotation() From b890ec0081a4035e88da8042f26bb1462c287f0e Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Thu, 7 May 2020 22:00:09 +0000 Subject: [PATCH 08/31] fix scalar tensor, test_any_elemwise passes --- src/relay/backend/vm/compiler.cc | 8 ++++---- src/runtime/vm/vm.cc | 13 +++++++++---- tests/python/relay/test_any.py | 4 ++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 216e54f7c014..7cd9c54bb75d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -546,10 +546,10 @@ class VMFunctionCompiler : ExprFunctor { offset_register, raw_shape, dtype, NewRegister())); - } else { - this->VisitExpr(args[1]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg(storage_register, shape_register, offset_register, dtype, + } else { + this->VisitExpr(args[2]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg( NewRegister())); } }) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index f142bb9c27b4..d48e1379b713 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -981,10 +981,15 @@ void VirtualMachine::RunLoop() { const DLTensor* dl_tensor = shape_tensor.operator->(); CHECK_EQ(dl_tensor->dtype.code, 0u); CHECK_LE(dl_tensor->dtype.bits, 64); - int64_t* dims = reinterpret_cast(dl_tensor->data); - auto num_dims = shape_tensor->shape[0]; - auto shape = std::vector(num_dims); - shape.assign(dims, dims + num_dims); + std::vector shape; + if (dl_tensor->ndim) { + int64_t num_dims = shape_tensor->shape[0]; + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.resize(num_dims); + shape.assign(dims, dims + num_dims); + } else { + shape.push_back(1); + } auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 10137d417437..998e9cc62c88 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -673,11 +673,11 @@ def _body(i, st): # test_any_concat() # test_any_full() # test_any_broadcast() - test_any_elemwise() + # test_any_elemwise() # test_any_broadcast_fail() # test_any_concat() # test_any_reshape() - # test_any_take() + test_any_take() # test_any_tile() # test_any_split() # test_any_shape_of() From 5eb23373da22b5694e8a898bd8851656b528935e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 7 May 2020 16:36:09 -0700 Subject: [PATCH 09/31] Fix split pass --- python/tvm/relay/transform/memory_plan.py | 17 ++++++++++++++--- tests/python/relay/test_any.py | 8 ++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 667119bd1f33..89f8c3178d07 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -96,7 +96,7 @@ def grow( self.size = self.size + size def offset_for(self, alloc: expr.Expr) -> expr.Expr: - return 
self.offsets[alloc][0] + return self.offsets.get(alloc, [None])[0] def to_expr(self, body: expr.Expr) -> expr.Expr: """ @@ -198,6 +198,16 @@ def current_region(self, dtype) -> Region: current_scope = self.regions[-1] return current_scope[dtype] + def new_region_and_offset(self, old_storage): + for dtype_region in reversed(self.regions): + for dtype in dtype_region: + region = dtype_region[dtype] + offset = region.offset_for(old_storage) + if offset: + return region, offset + + raise Exception("could not find offset in any valid region") + def visit_function(self, fn): """Transform the function body to use region allocation scheme.""" func = fn @@ -273,9 +283,9 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): return lhs, region.var def process_alloc_tensor(self, lhs, call): - region = self.current_region(call.attrs.dtype) storage, old_offset, shape = call.args - offset = region.offset_for(storage) + region, offset = self.new_region_and_offset(storage) + assert ( old_offset.data.asnumpy().item() == 0 ), "no offsets should yet be allocated" @@ -320,6 +330,7 @@ class MemoryPlan: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") sc = StorageCoalesce() + # func = Uniq().visit(func) func = sc.visit(func) # func = Uniq().visit(func) return func diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 998e9cc62c88..b3124a87243a 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -76,8 +76,8 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): def test_any_elemwise(): verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) - # verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) - # verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) + verify_any_elemwise((relay.Any(), 2), (5, 2), relay.negative, np.negative) + verify_any_elemwise((relay.Any(), relay.Any()), (5, 4), relay.exp, np.exp) def test_any_broadcast_fail(): # Test broadcast with incompatible values at runtime @@ -677,9 +677,9 @@ def _body(i, st): # test_any_broadcast_fail() # test_any_concat() # test_any_reshape() - test_any_take() + # test_any_take() # test_any_tile() - # test_any_split() + test_any_split() # test_any_shape_of() # test_any_reduce() # test_any_layout_transform() From a0e4f973f0528bb4b60d24fdf91b8bf23b339251 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 8 May 2020 14:29:32 -0700 Subject: [PATCH 10/31] Fix 0-rank issues --- python/tvm/relay/expr.py | 4 ++++ python/tvm/relay/transform/memory_alloc.py | 3 +-- python/tvm/relay/transform/memory_plan.py | 3 +++ src/relay/op/memory/memory.cc | 11 ++++++++--- tests/python/relay/test_any.py | 4 ++-- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ff1368394917..f69414db6952 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -504,6 +504,10 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") + + for dim in value.shape: + assert dim != 0, "Relay constants can not contain a 0 dimension." 
+ return Constant(value) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 8b831e47d8c0..6c081cbac0de 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -96,8 +96,7 @@ def make_static_allocation(self, scope, tensor_type, i): """Allocate a tensor with a statically known shape.""" shape = [int(sh) for sh in tensor_type.shape] if len(shape) == 0: - shape = expr.const(np.array([]).astype( - self.compute_dtype), dtype=self.compute_dtype) + shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype) else: shape = expr.const(np.array(shape), dtype=self.compute_dtype) size = self.compute_storage(tensor_type) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 89f8c3178d07..30133b19bd5c 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -59,6 +59,7 @@ class Region: @staticmethod def empty(region_no): zero = expr.const(0, dtype="int64") + assert len(zero.data.shape) == 0 region_var = expr.var(f"region{region_no}") return Region(region_var, zero, None, None, None, {}) @@ -358,7 +359,9 @@ class LiftConstants: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") + print(func) func = LiftConst().visit(func) + print(func) return func diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 170c087cc27d..0f4f2360015c 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -149,9 +149,14 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs // Third argument should be shape tensor. auto tt = types[2].as(); CHECK(tt != nullptr) << "must be tensor type"; - auto rank = tt->shape[0].as(); - CHECK(rank != nullptr); - auto dims = rank->value; + + // Be careful about having to allocate scalars. + int64_t dims = 0; + if (tt->shape.size() != 0) { + auto rank = tt->shape[0].as(); + CHECK(rank != nullptr); + dims = rank->value; + } // Constant node case. 
Type alloc_type; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index b3124a87243a..49af43a159cf 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -678,8 +678,8 @@ def _body(i, st): # test_any_concat() # test_any_reshape() # test_any_take() - # test_any_tile() - test_any_split() + test_any_tile() + # test_any_split() # test_any_shape_of() # test_any_reduce() # test_any_layout_transform() From e0371e37d4cf729cfb33d3ca53f0af7f037fdc5e Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 8 May 2020 14:52:04 -0700 Subject: [PATCH 11/31] Fix --- src/relay/op/memory/memory.cc | 1 + src/runtime/vm/vm.cc | 8 +++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 0f4f2360015c..823bcaffd9e5 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -163,6 +163,7 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs if (alloc_attrs->const_shape.defined()) { auto con = alloc_attrs->const_shape; auto sh = FromConstShape(con); + CHECK_EQ(sh.size(), dims); Array out_shape; for (auto i = 0u; i < dims; i++) { out_shape.push_back(tvm::Integer(sh[i])); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index d48e1379b713..48a30b2d6fb8 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -987,15 +987,17 @@ void VirtualMachine::RunLoop() { int64_t* dims = reinterpret_cast(dl_tensor->data); shape.resize(num_dims); shape.assign(dims, dims + num_dims); - } else { - shape.push_back(1); } auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); - + std::cout << "shape = ("; + for (auto sh : obj.Shape()) { + std::cout << sh << ","; + } + std::cout << ")"; WriteRegister(instr.dst, obj); pc_++; goto main_loop; From 8d584700e03686193e61510c810f2e82e0e5f408 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 8 May 2020 22:09:47 +0000 Subject: [PATCH 12/31] debug --- src/runtime/vm/vm.cc | 50 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 48a30b2d6fb8..881fc4c1f519 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -773,12 +773,37 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In } } else { auto nd_array = Downcast(args[i]); + const DLTensor* tensor = nd_array.operator->(); + LOG(INFO) << "argsss:: " << tensor->ndim << " size:: " << runtime::GetDataSize(*tensor); + + int64_t* shapes = reinterpret_cast(tensor->shape); + for (auto i = 0; i < tensor->ndim; i++) { + std::cout << shapes[i] << " "; + } + + std::cout << std::endl << std::endl; + + if (tensor->dtype.bits == 32) { + float* data = reinterpret_cast(tensor->data); + for (uint64_t i = 0; i < GetDataSize(*tensor) / (tensor->dtype.bits / 8); i++) { + std::cout << data[i] << " "; + } + std::cout << std::endl; + } else { + int64_t* data = reinterpret_cast(tensor->data); + for (uint64_t i = 0; i < GetDataSize(*tensor) / (tensor->dtype.bits / 8); i++) { + std::cout << data[i] << " "; + } + std::cout << std::endl; + } setter(idx++, nd_array); } } TVMRetValue rv; + LOG(INFO) << "calling::: " << packed_index; func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); + LOG(INFO) << "calling::: " << packed_index << "donnneee"; } 
void VirtualMachine::LoadExecutable(const Executable* exec) { @@ -798,6 +823,7 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true); CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name; + LOG(INFO) << "----:: " << packed_index << " " << packed_name; packed_funcs_[packed_index] = pf; } } @@ -837,7 +863,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = code_[this->pc_]; - DLOG(INFO) << "Executing(" << pc_ << "): " << instr; + LOG(INFO) << "Executing(" << pc_ << "): " << instr; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG @@ -855,6 +881,18 @@ void VirtualMachine::RunLoop() { } case Opcode::LoadConst: { auto constant_obj = exec_->constants[instr.const_index]; + auto arr = Downcast(constant_obj); + const DLTensor* tensor = arr.operator->(); + if (tensor->ndim == 0) { + LOG(INFO) << "const:: " << reinterpret_cast(tensor->data)[0]; + } else { + LOG(INFO) << "const:: " << tensor->ndim << " " + << reinterpret_cast(tensor->shape)[0]; + int64_t* data = reinterpret_cast(tensor->data); + for (auto i = 0; i < reinterpret_cast(tensor->shape)[0]; i++) { + std::cout << data[i] << " "; + } + } // We cache the allocated object in the constant pool. To measure, the // first iteration will set the pool up. The other iterations will // directly reuse the allocated objects. @@ -989,15 +1027,25 @@ void VirtualMachine::RunLoop() { shape.assign(dims, dims + num_dims); } + LOG(INFO) << "input:: " << instr.dst << " " << shape.size() << " " << dl_tensor->ndim; + for (auto i = 0; i < shape.size(); i++) { + LOG(INFO) << shape[i]; + } + auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); + + const DLTensor* tensor = obj.operator->(); + LOG(INFO) << "output:: " << tensor->ndim << " size:: " << runtime::GetDataSize(*tensor); + std::cout << "shape = ("; for (auto sh : obj.Shape()) { std::cout << sh << ","; } std::cout << ")"; + WriteRegister(instr.dst, obj); pc_++; goto main_loop; From 789b18ba5fabb4eb8da8acacdd8dd14fcf899468 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 8 May 2020 22:57:00 +0000 Subject: [PATCH 13/31] apply Haichen's patch and clean up --- include/tvm/relay/transform.h | 2 +- python/tvm/relay/transform/memory_plan.py | 12 ++--- src/relay/backend/vm/compiler.cc | 9 ++++ src/runtime/vm/vm.cc | 56 +---------------------- tests/python/relay/test_any.py | 55 +++++++++++----------- 5 files changed, 44 insertions(+), 90 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 48948ee705eb..a96fe4c79c69 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -95,7 +95,7 @@ TVM_DLL Pass LazyGradientInit(); * \param preserve_anf Controls the inlining of let bindings. * \return The pass. */ -TVM_DLL Pass FoldConstant(bool preserve_anf=false); +TVM_DLL Pass FoldConstant(bool preserve_anf = false); /*! * \brief Fuse operations into expr into seperate functions. diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 30133b19bd5c..64f7435b846a 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -19,8 +19,8 @@ A pass for manifesting explicit memory allocations. 
""" from typing import Optional, Dict, List, Tuple -import attr from collections import defaultdict +import attr from ..expr_functor import ExprMutator from .. import op, expr @@ -90,11 +90,14 @@ def grow( assert ctx self.ctx = ctx + new_size = (size + self.alignment - expr.const(1, "int64")) \ + / self.alignment * self.alignment + # Record the offset at which we allocate the storage. offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}") self.offsets[old_storage] = (offset_var, self.size) - self.size = self.size + size + self.size = self.size + new_size def offset_for(self, alloc: expr.Expr) -> expr.Expr: return self.offsets.get(alloc, [None])[0] @@ -180,14 +183,13 @@ def __init__(self): self.regions = [] def enter_scope(self) -> None: - zero = expr.const(0, dtype="int64") region_no = len(self.regions) self.regions.append(defaultdict(lambda: Region.empty(region_no))) def exit_scope(self, body: expr.Expr) -> expr.Expr: """When leaving a scope build a region allocation for the scope.""" dtype_region = self.regions.pop() - for dtype, region in reversed(list(dtype_region.items())): + for _, region in reversed(list(dtype_region.items())): if len(region.offsets) == 0: continue else: @@ -359,9 +361,7 @@ class LiftConstants: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") - print(func) func = LiftConst().visit(func) - print(func) return func diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7cd9c54bb75d..bec6e7a9d98c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -294,6 +294,15 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const ConstantNode* const_node) { + // Check the shape is valid + NDArray data = const_node->data; + const DLTensor* tensor = data.operator->(); + if (tensor->ndim > 0) { + int64_t* shapes = reinterpret_cast(tensor->shape); + for (auto i = 0; i < tensor->ndim; i++) { + CHECK_GT(shapes[i], 0U); + } + } size_t konst_idx = context_->constants.size(); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 881fc4c1f519..272f565fe37c 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -536,7 +536,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.storage << " $" - << instr.alloc_tensor_reg.offset << " " + << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -628,7 +628,6 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { - std::cout << name << std::endl; if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { CHECK(exec_) << "The executable is not created yet."; @@ -773,37 +772,12 @@ void VirtualMachine::InvokePacked(Index packed_index, const PackedFunc& func, In } } else { auto nd_array = Downcast(args[i]); - const DLTensor* tensor = nd_array.operator->(); - LOG(INFO) << "argsss:: " << tensor->ndim << " size:: " << runtime::GetDataSize(*tensor); - - int64_t* shapes = reinterpret_cast(tensor->shape); - for (auto i = 0; i < tensor->ndim; i++) { - std::cout << shapes[i] << " "; - } - - std::cout << std::endl << 
std::endl; - - if (tensor->dtype.bits == 32) { - float* data = reinterpret_cast(tensor->data); - for (uint64_t i = 0; i < GetDataSize(*tensor) / (tensor->dtype.bits / 8); i++) { - std::cout << data[i] << " "; - } - std::cout << std::endl; - } else { - int64_t* data = reinterpret_cast(tensor->data); - for (uint64_t i = 0; i < GetDataSize(*tensor) / (tensor->dtype.bits / 8); i++) { - std::cout << data[i] << " "; - } - std::cout << std::endl; - } setter(idx++, nd_array); } } TVMRetValue rv; - LOG(INFO) << "calling::: " << packed_index; func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); - LOG(INFO) << "calling::: " << packed_index << "donnneee"; } void VirtualMachine::LoadExecutable(const Executable* exec) { @@ -823,7 +797,6 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } tvm::runtime::PackedFunc pf = lib.GetFunction(packed_name, true); CHECK(pf != nullptr) << "Cannot find function in module: " << packed_name; - LOG(INFO) << "----:: " << packed_index << " " << packed_name; packed_funcs_[packed_index] = pf; } } @@ -863,7 +836,7 @@ void VirtualMachine::RunLoop() { while (true) { main_loop: auto const& instr = code_[this->pc_]; - LOG(INFO) << "Executing(" << pc_ << "): " << instr; + DLOG(INFO) << "Executing(" << pc_ << "): " << instr; #if USE_RELAY_DEBUG InstructionPrint(std::cout, instr); #endif // USE_RELAY_DEBUG @@ -882,17 +855,6 @@ void VirtualMachine::RunLoop() { case Opcode::LoadConst: { auto constant_obj = exec_->constants[instr.const_index]; auto arr = Downcast(constant_obj); - const DLTensor* tensor = arr.operator->(); - if (tensor->ndim == 0) { - LOG(INFO) << "const:: " << reinterpret_cast(tensor->data)[0]; - } else { - LOG(INFO) << "const:: " << tensor->ndim << " " - << reinterpret_cast(tensor->shape)[0]; - int64_t* data = reinterpret_cast(tensor->data); - for (auto i = 0; i < reinterpret_cast(tensor->shape)[0]; i++) { - std::cout << data[i] << " "; - } - } // We cache the allocated object in the constant pool. To measure, the // first iteration will set the pool up. The other iterations will // directly reuse the allocated objects. 
@@ -1027,25 +989,11 @@ void VirtualMachine::RunLoop() { shape.assign(dims, dims + num_dims); } - LOG(INFO) << "input:: " << instr.dst << " " << shape.size() << " " << dl_tensor->ndim; - for (auto i = 0; i < shape.size(); i++) { - LOG(INFO) << shape[i]; - } - auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); - const DLTensor* tensor = obj.operator->(); - LOG(INFO) << "output:: " << tensor->ndim << " size:: " << runtime::GetDataSize(*tensor); - - std::cout << "shape = ("; - for (auto sh : obj.Shape()) { - std::cout << sh << ","; - } - std::cout << ")"; - WriteRegister(instr.dst, obj); pc_++; goto main_loop; diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 49af43a159cf..6ce59bbf1c36 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -66,7 +66,6 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): x = relay.var('x', shape=x_shape, dtype=dtype) mod = tvm.IRModule() mod["main"] = relay.Function([x], op(x)) - print(mod["main"]) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) for kind in ["debug", "vm"]: @@ -131,7 +130,6 @@ def test_any_concat(): z = relay.op.concatenate([xx, yy], axis=0) mod = tvm.IRModule() mod["main"] = relay.Function([x, y], z) - print(mod["main"]) x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) @@ -670,31 +668,30 @@ def _body(i, st): assert "in particular dimension 0 conflicts 2 does not match 1" in str(e) if __name__ == "__main__": - # test_any_concat() - # test_any_full() - # test_any_broadcast() - # test_any_elemwise() - # test_any_broadcast_fail() - # test_any_concat() - # test_any_reshape() - # test_any_take() + test_any_full() + test_any_broadcast() + test_any_elemwise() + test_any_broadcast_fail() + test_any_concat() + test_any_reshape() + test_any_take() test_any_tile() - # test_any_split() - # test_any_shape_of() - # test_any_reduce() - # test_any_layout_transform() - # test_any_expand_dims() - # test_any_transpose() - # test_any_squeeze() - # test_any_reshape_like() - # test_any_conv2d_NCHWc() - # test_any_pool2d() - # test_any_global_pool2d() - # test_any_batch_flatten() - # test_any_dense() - # test_any_pad() - # test_any_softmax() - # test_fused_ops() - # test_arange_with_dynamic_shape() - # test_recursive_concat() - # test_recursive_concat_with_wrong_annotation() + test_any_split() + test_any_shape_of() + test_any_reduce() + test_any_layout_transform() + test_any_expand_dims() + test_any_transpose() + test_any_squeeze() + test_any_reshape_like() + test_any_conv2d_NCHWc() + test_any_pool2d() + test_any_global_pool2d() + test_any_batch_flatten() + test_any_dense() + test_any_pad() + test_any_softmax() + test_fused_ops() + test_arange_with_dynamic_shape() + test_recursive_concat() + test_recursive_concat_with_wrong_annotation() From 5de25c88d4c5130be7068420ab56674d1a383fc1 Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Fri, 8 May 2020 23:32:27 +0000 Subject: [PATCH 14/31] lintgit add . 
--- python/tvm/relay/transform/memory_plan.py | 43 ++++++++--------------- src/relay/analysis/well_formed.cc | 5 --- src/runtime/vm/vm.cc | 5 ++- 3 files changed, 16 insertions(+), 37 deletions(-) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 64f7435b846a..9bd366d43c23 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -190,9 +190,7 @@ def exit_scope(self, body: expr.Expr) -> expr.Expr: """When leaving a scope build a region allocation for the scope.""" dtype_region = self.regions.pop() for _, region in reversed(list(dtype_region.items())): - if len(region.offsets) == 0: - continue - else: + if len(region.offsets) != 0: body = region.to_expr(body) return body @@ -241,6 +239,7 @@ def visit_if(self, ite): def mk_let(self, dynamic_regions): + """Let bind the dynamic regions""" def _mk_let(bindings, body): for var, value in reversed(bindings): assert var @@ -273,6 +272,7 @@ def _each_binding(lhs, rhs): return result def process_alloc_storage(self, dynamic_regions, lhs, call): + """Process alloc_storage""" size, alignment = call.args dtype = call.attrs.dtype ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) @@ -286,6 +286,7 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): return lhs, region.var def process_alloc_tensor(self, lhs, call): + """Process alloc tensor. Region and offset are computed""" storage, old_offset, shape = call.args region, offset = self.new_region_and_offset(storage) @@ -298,6 +299,7 @@ def process_alloc_tensor(self, lhs, call): ) class LiftConst(ExprMutator): + """A internal pass to lift constants to the top level of function.""" def __init__(self): self.i = 0 self.constants = [] @@ -310,21 +312,21 @@ def visit_constant(self, const): self.constants.append((var, const)) return var - def visit_function(self, function): - if int(getattr(function.attrs, "Primitive", 0)) == 1: - return function + def visit_function(self, fn): + if int(getattr(fn.attrs, "Primitive", 0)) == 1: + return fn if self.top_level: self.top_level = False - body = mk_let(self.constants, self.visit(function.body)) + body = mk_let(self.constants, self.visit(fn.body)) return Function( - function.params, + fn.params, body, - function.ret_type, - function.type_params, - function.attrs) + fn.ret_type, + fn.type_params, + fn.attrs) else: - return super().visit_function(function) + return super().visit_function(fn) @function_pass(opt_level=0) class MemoryPlan: @@ -333,26 +335,9 @@ class MemoryPlan: def transform_function(self, func, mod, _): mod.import_from_std("core.rly") sc = StorageCoalesce() - # func = Uniq().visit(func) func = sc.visit(func) - # func = Uniq().visit(func) return func -class Uniq(ExprMutator): - def __init__(self): - self.var_map = {} - self.i = 0 - super().__init__() - - def visit_var(self, var): - if var in self.var_map: - return self.var_map[var] - else: - new_var = expr.Var(f"var{self.i}", type_annotation=var.type_annotation) - self.i += 1 - self.var_map[var] = new_var - return new_var - register_func("relay.transform.MemoryPlan", MemoryPlan) @function_pass(opt_level=0) diff --git a/src/relay/analysis/well_formed.cc b/src/relay/analysis/well_formed.cc index 516a10e0e7a4..33f52c9a8397 100644 --- a/src/relay/analysis/well_formed.cc +++ b/src/relay/analysis/well_formed.cc @@ -53,12 +53,7 @@ class WellFormedChecker : private ExprVisitor, PatternVisitor { }; void Bound(const Var& v) { - // std::cout << "HERE " << v << std::endl; if (current_bound.count(v) 
!= 0 || total_bound.count(v) != 0 || free.count(v) != 0) {
-      // std::cout << "WELL FORMED: " << v << std::endl;
-      // std::cout << "current bindings :" << current_bound.count(v) << std::endl;
-      // std::cout << "total bindings :" << total_bound.count(v) << std::endl;
-      // std::cout << "free bindings :" << free.count(v) << std::endl;
       well_formed = false;
     }
     CHECK_GE(scope.size(), 0);
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 272f565fe37c..dff53c8787e8 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -526,8 +526,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
       break;
     }
     case Opcode::AllocTensor: {
-      os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " ["
-         << instr.alloc_tensor.storage << " "
+      os << "alloc_tensor $" << instr.dst << " $"
+         << instr.alloc_tensor.storage << " $" << instr.alloc_tensor.offset << " ["
          << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] ";
       DLDatatypePrint(os, instr.alloc_tensor.dtype);
@@ -854,7 +854,6 @@ void VirtualMachine::RunLoop() {
       }
       case Opcode::LoadConst: {
         auto constant_obj = exec_->constants[instr.const_index];
-        auto arr = Downcast<NDArray>(constant_obj);
         // We cache the allocated object in the constant pool. To measure, the
         // first iteration will set the pool up. The other iterations will
         // directly reuse the allocated objects.
From cf4b82c3b94c9a32f874a576426ba25743c03d54 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Sat, 9 May 2020 00:17:13 +0000
Subject: [PATCH 15/31] fix serializer and test_tyck_alloc_tensor test

---
 src/runtime/vm/executable.cc             | 6 +++---
 tests/python/relay/test_memory_passes.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index 6576ef5cb81d..b037d9f56134 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -307,7 +307,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) {
       break;
     }
     case Opcode::AllocTensor: {
-      // Number of fields = 6 + instr.alloc_tensor.ndim
+      // Number of fields = 7 + instr.alloc_tensor.ndim
       fields.push_back(instr.alloc_tensor.storage);
       fields.push_back(instr.alloc_tensor.offset);
       // Save `DLDataType` and the dst register.
@@ -550,7 +550,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       return Instruction::InvokePacked(packed_index, arity, output_size, args);
     }
     case Opcode::AllocTensor: {
-      // Number of fields = 6 + instr.alloc_tensor.ndim
+      // Number of fields = 7 + instr.alloc_tensor.ndim
       DCHECK_GE(instr.fields.size(), 7U);
       DCHECK_EQ(instr.fields.size(), 7U + static_cast<size_t>(instr.fields[4]));
@@ -565,7 +565,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       Index ndim = instr.fields[5];
       RegName dst = instr.fields[6];

-      std::vector<int64_t> shape = ExtractFields(instr.fields, 6, ndim);
+      std::vector<int64_t> shape = ExtractFields(instr.fields, 7, ndim);

       return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst);
     }
diff --git a/tests/python/relay/test_memory_passes.py b/tests/python/relay/test_memory_passes.py
index 9de9b2e38c07..70e7086cef4d 100644
--- a/tests/python/relay/test_memory_passes.py
+++ b/tests/python/relay/test_memory_passes.py
@@ -63,7 +63,7 @@ def test_tyck_alloc_tensor():
     mod.import_from_std("core.rly")
     sto = relay.Var("x", storage_type(mod))
     sh = relay.const(np.array([1, 2]), dtype="int64")
-    at = relay.op.memory.alloc_tensor(sto, sh)
+    at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh)
     mod['main'] = relay.Function([sto], at)
     relay.transform.InferType()(mod)

From bb82251bc74943f78092bed08777dbc4b5bf678d Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Mon, 11 May 2020 16:29:38 -0700
Subject: [PATCH 16/31] Fix the constant lift pass in presence of closures

---
 python/tvm/relay/transform/memory_plan.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py
index 9bd366d43c23..f3d994701b92 100644
--- a/python/tvm/relay/transform/memory_plan.py
+++ b/python/tvm/relay/transform/memory_plan.py
@@ -316,17 +316,17 @@ def visit_function(self, fn):
         if int(getattr(fn.attrs, "Primitive", 0)) == 1:
             return fn

-        if self.top_level:
-            self.top_level = False
-            body = mk_let(self.constants, self.visit(fn.body))
-            return Function(
-                fn.params,
-                body,
-                fn.ret_type,
-                fn.type_params,
-                fn.attrs)
-        else:
-            return super().visit_function(fn)
+        outer_constant = self.constants
+        self.constants = []
+        body = mk_let(self.constants, self.visit(fn.body))
+        self.constants = outer_constant
+
+        return Function(
+            fn.params,
+            body,
+            fn.ret_type,
+            fn.type_params,
+            fn.attrs)
From 7454618a6fe5ca1621e052aafdeb30c449c0fd90 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Mon, 11 May 2020 16:43:22 -0700
Subject: [PATCH 17/31] Restore old finder

---
 include/tvm/relay/transform.h         |  4 ++--
 src/relay/transforms/fold_constant.cc | 14 ++++++--------
 2 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index a96fe4c79c69..461276b79541 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -92,10 +92,10 @@ TVM_DLL Pass LazyGradientInit();

 /*!
  * \brief Fold constant expressions.
- * \param preserve_anf Controls the inlining of let bindings.
+ *
  * \return The pass.
  */
-TVM_DLL Pass FoldConstant(bool preserve_anf = false);
+TVM_DLL Pass FoldConstant();

 /*!
  * \brief Fuse operations into expr into seperate functions.
diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc
index f3e15e71f6ca..70df0ed8c2b4 100644
--- a/src/relay/transforms/fold_constant.cc
+++ b/src/relay/transforms/fold_constant.cc
@@ -77,7 +77,7 @@ TVM_REGISTER_GLOBAL("relay.analysis.check_constant").set_body_typed(ConstantChec
 // or make a more powerful partial evaluator.
 class ConstantFolder : public ExprMutator {
  public:
-  explicit ConstantFolder(FInterpreter executor, IRModule module, bool preserve_anf)
+  explicit ConstantFolder(FInterpreter executor, IRModule module)
       : executor_(executor),
         module_(module),
         shape_of_op_(Op::Get("shape_of")),
@@ -85,12 +85,11 @@ class ConstantFolder : public ExprMutator {
         shape_func_op_(Op::Get("memory.shape_func")),
         alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
         alloc_storage_op_(Op::Get("memory.alloc_storage")),
-        cast_op_(Op::Get("cast")),
-        preserve_anf(preserve_anf) {}
+        cast_op_(Op::Get("cast")) {}

   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
-    if (!preserve_anf && value.as<ConstantNode>()) {
+    if (value.as<ConstantNode>()) {
       memo_[op->var] = value;
       return this->Mutate(op->body);
     } else {
@@ -172,7 +171,6 @@ class ConstantFolder : public ExprMutator {
   const Op& alloc_tensor_op_;
   const Op& alloc_storage_op_;
   const Op& cast_op_;
-  bool preserve_anf;

   // Convert value to expression.
   Expr ObjectToExpr(const ObjectRef& value) {
@@ -269,7 +267,7 @@ class ConstantFolder : public ExprMutator {
   }
 };

-Expr FoldConstant(const Expr& expr, const IRModule& mod, bool preserve_anf) {
+Expr FoldConstant(const Expr& expr, const IRModule& mod) {
   DLContext ctx;
   ctx.device_type = kDLCPU;
   ctx.device_id = 0;
@@ -278,12 +276,12 @@ Expr FoldConstant(const Expr& expr, const IRModule& mod, bool preserve_anf) {
   // in case we are already in a build context.
With fresh_build_ctx(BuildConfig::Create()); - return ConstantFolder(CreateInterpreter(mod, ctx, target), mod, preserve_anf).Mutate(expr); + return ConstantFolder(CreateInterpreter(mod, ctx, target), mod).Mutate(expr); } namespace transform { -Pass FoldConstant(bool preserve_anf) { +Pass FoldConstant() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast(FoldConstant(f, m)); From add6b0ed90f0e520b5d72ca625deb1bb14205574 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 11 May 2020 17:32:31 -0700 Subject: [PATCH 18/31] Fix rebase issues --- src/relay/backend/vm/compiler.cc | 6 +++--- src/runtime/vm/vm.cc | 22 +++++++--------------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index bec6e7a9d98c..55dac20b22e2 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -558,10 +558,10 @@ class VMFunctionCompiler : ExprFunctor { } else { this->VisitExpr(args[2]); auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg( + Emit(Instruction::AllocTensorReg(storage_register, offset_register, shape_register, dtype, NewRegister())); - } - }) + } + }) .Match("memory.alloc_storage", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { CHECK_EQ(args.size(), 2); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index dff53c8787e8..fd4cf5edb399 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -311,11 +311,9 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out return instr; } -Instruction Instruction::AllocTensor(RegName storage, const std::vector& shape, - RegName storage, - RegName offset, - const std::vector& shape, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(RegName storage, RegName offset, + const std::vector& shape, DLDataType dtype, + Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; @@ -330,11 +328,8 @@ Instruction Instruction::AllocTensor(RegName storage, const std::vector return instr; } -Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, - RegName storage, - RegName offset, - RegName shape_register, - Index dst) { +Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register, + DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; @@ -526,8 +521,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocTensor: { - os << "alloc_tensor $" << instr.dst << " $" - << instr.alloc_tensor.storage << " $" + os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " $" << instr.alloc_tensor.offset << " [" << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); @@ -535,8 +529,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { } case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" - << instr.alloc_tensor_reg.storage << " $" - << instr.alloc_tensor_reg.offset << " $" + << instr.alloc_tensor_reg.storage << " $" << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; @@ -801,7 +794,6 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { } } - void 
VirtualMachine::Init(const std::vector& ctxs) { ctxs_ = ctxs; } inline void VirtualMachine::WriteRegister(Index r, const ObjectRef& val) { From 0d55d7b3fdae3e38f7771ac866e9747f22ef198d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 11 May 2020 17:35:39 -0700 Subject: [PATCH 19/31] Fix --- src/relay/backend/vm/compiler.cc | 79 +++++++++++++++----------------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 55dac20b22e2..fcefa4b98648 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -42,10 +42,10 @@ #include #include "../../backend/compile_engine.h" -#include "../../transforms/pass_util.h" #include "../../op/op_common.h" -#include "compiler.h" +#include "../../transforms/pass_util.h" #include "../utils.h" +#include "compiler.h" namespace tvm { namespace relay { @@ -523,45 +523,41 @@ class VMFunctionCompiler : ExprFunctor { CHECK_EQ(args.size(), 3); EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); }) - .Match( - "memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 3); - - // Get the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; - - // The storage will be passed dynamically. - this->VisitExpr(args[1]); - auto offset_register = last_register_; - - // If the shape is constant then we will emit a static tensor allocation - // instruction. - auto const_shape = args[2].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - // TODO(@jroesch): we need to get an RFC done to standarize shape dtype - std::vector raw_shape = ToAllocTensorShape(shape); - // Add context field. - Emit(Instruction::AllocTensor( - storage_register, - offset_register, - raw_shape, - dtype, NewRegister())); - } else { - this->VisitExpr(args[2]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg(storage_register, offset_register, shape_register, dtype, - NewRegister())); - } - }) + .Match("memory.alloc_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + + // Get the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + + // The storage will be passed dynamically. + this->VisitExpr(args[1]); + auto offset_register = last_register_; + + // If the shape is constant then we will emit a static tensor allocation + // instruction. + auto const_shape = args[2].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. 
+ Emit(Instruction::AllocTensor(storage_register, offset_register, raw_shape, + dtype, NewRegister())); + } else { + this->VisitExpr(args[2]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg(storage_register, offset_register, + shape_register, dtype, NewRegister())); + } + }) .Match("memory.alloc_storage", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { CHECK_EQ(args.size(), 2); @@ -958,7 +954,6 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(this->target_host_)); transform::Sequential seq(pass_seqs); From a90faaadc8e2f3e675a986feb53168fb8597a495 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 11 May 2020 18:03:54 -0700 Subject: [PATCH 20/31] Fix --- python/tvm/relay/transform/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index d68750bfbb08..647e999f647a 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -241,7 +241,7 @@ def FoldConstant(): ret : tvm.transform.Pass The registered pass for constant folding. """ - return _ffi_api.FoldConstant(False) + return _ffi_api.FoldConstant() def FuseOps(fuse_opt_level=-1): From 4426b223bac0e77a58581878cec91f9966adbe70 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 13:17:44 -0700 Subject: [PATCH 21/31] Fix issue coercing the shapes incorrectly from i64 to i32 --- include/tvm/runtime/ndarray.h | 2 ++ src/runtime/ndarray.cc | 1 + src/runtime/vm/vm.cc | 36 ++++++++++++++++++++++++----------- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 8db93b46e934..d1bc528aed5b 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -160,6 +161,7 @@ class NDArray : public ObjectRef { TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; + TVM_DLL runtime::DataType DataType() const; // internal namespace struct Internal; diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index d97d01b0feab..30c2990691e6 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -236,6 +236,7 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str } std::vector NDArray::Shape() const { return get_mutable()->shape_; } +runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); } TVM_REGISTER_OBJECT_TYPE(NDArray::Container); diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index fd4cf5edb399..871551cc93d6 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -619,6 +619,30 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { return src; } +std::vector ToShape(NDArray shape_tensor) { + std::vector shape; + auto dtype = shape_tensor.DataType(); + CHECK(shape_tensor.Shape().size() == 1) + << "shape tensor should be a k-length vector."; + + int64_t ndim = shape_tensor.Shape().at(0); + + if (ndim == 0) { return shape; } else { shape.resize(ndim); } + + const DLTensor* dl_tensor = shape_tensor.operator->(); + if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { + int32_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else if (dtype.is_int() && dtype.bits() == 64 && 
dtype.lanes() == 1) { + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else { + LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + } + + return shape; +} + PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "invoke") { @@ -969,17 +993,7 @@ void VirtualMachine::RunLoop() { auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); const auto shape_arr = Downcast(shape_tensor_obj); NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx); - const DLTensor* dl_tensor = shape_tensor.operator->(); - CHECK_EQ(dl_tensor->dtype.code, 0u); - CHECK_LE(dl_tensor->dtype.bits, 64); - std::vector shape; - if (dl_tensor->ndim) { - int64_t num_dims = shape_tensor->shape[0]; - int64_t* dims = reinterpret_cast(dl_tensor->data); - shape.resize(num_dims); - shape.assign(dims, dims + num_dims); - } - + auto shape = ToShape(shape_tensor); auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); auto offset = LoadScalarInt(instr.alloc_tensor.offset); From 0924927d4724fa45113bd3e53bc3f9d61457c69b Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 13:34:57 -0700 Subject: [PATCH 22/31] Fix linting --- src/runtime/ndarray.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 30c2990691e6..800a9167dadc 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -236,7 +236,9 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str } std::vector NDArray::Shape() const { return get_mutable()->shape_; } -runtime::DataType NDArray::DataType() const { return runtime::DataType(get_mutable()->dl_tensor.dtype); } +runtime::DataType NDArray::DataType() const { + return runtime::DataType(get_mutable()->dl_tensor.dtype); +} TVM_REGISTER_OBJECT_TYPE(NDArray::Container); From 98726edfbb6f62a7118bed1ddf1335acd4cb9bc0 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 13:37:30 -0700 Subject: [PATCH 23/31] Fix clang format --- include/tvm/runtime/ndarray.h | 2 +- include/tvm/runtime/vm.h | 10 +++++----- src/runtime/vm/memory_manager.cc | 6 ++---- src/runtime/vm/vm.cc | 11 +++++++---- 4 files changed, 15 insertions(+), 14 deletions(-) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index d1bc528aed5b..0171d8a999e8 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,9 +25,9 @@ #define TVM_RUNTIME_NDARRAY_H_ #include +#include #include #include -#include #include #include diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 30cde8ddca49..552edc5f19db 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -24,8 +24,8 @@ #ifndef TVM_RUNTIME_VM_H_ #define TVM_RUNTIME_VM_H_ -#include #include +#include #include #include @@ -277,8 +277,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, Index offset, - const std::vector& shape, DLDataType dtype, RegName dst); + static Instruction AllocTensor(RegName storage, Index offset, const std::vector& shape, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. @@ -288,8 +288,8 @@ struct Instruction { * \param dst The destination register. * \return The allocate tensor instruction. 
*/ - static Instruction AllocTensorReg(RegName storage, Index offset, - RegName shape_register, DLDataType dtype, RegName dst); + static Instruction AllocTensorReg(RegName storage, Index offset, RegName shape_register, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index 1508e0311db0..4c220bbe61c8 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -103,10 +103,8 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa // RAII in effect, now run the check. CHECK(offset + needed_size <= this->buffer.size) - << "storage allocation failure, attempted to allocate " - << needed_size << " at offset " - << offset << " in region that is " - << this->buffer.size << "bytes"; + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << "bytes"; return ret; } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 871551cc93d6..b3f968ea9443 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -622,16 +622,19 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { std::vector ToShape(NDArray shape_tensor) { std::vector shape; auto dtype = shape_tensor.DataType(); - CHECK(shape_tensor.Shape().size() == 1) - << "shape tensor should be a k-length vector."; + CHECK(shape_tensor.Shape().size() == 1) << "shape tensor should be a k-length vector."; int64_t ndim = shape_tensor.Shape().at(0); - if (ndim == 0) { return shape; } else { shape.resize(ndim); } + if (ndim == 0) { + return shape; + } else { + shape.resize(ndim); + } const DLTensor* dl_tensor = shape_tensor.operator->(); if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { - int32_t* dims = reinterpret_cast(dl_tensor->data); + int32_t* dims = reinterpret_cast(dl_tensor->data); shape.assign(dims, dims + ndim); } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) { int64_t* dims = reinterpret_cast(dl_tensor->data); From 443c6af4f5dcb805a1937cac16aab5dd0584a795 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 13:39:02 -0700 Subject: [PATCH 24/31] Format memory.cc --- src/relay/op/memory/memory.cc | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 823bcaffd9e5..76a3315dbb03 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -92,19 +92,18 @@ RELAY_REGISTER_OP("memory.alloc_storage") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") - .set_body_typed( - [](Expr storage, Expr offset, tvm::relay::Expr shape, - DataType dtype, Array assert_shape) { - auto attrs = make_object(); - attrs->dtype = dtype; - if (assert_shape.defined()) { - attrs->assert_shape = assert_shape; - } else { - attrs->const_shape = Downcast(shape); - } - static const Op& op = Op::Get("memory.alloc_tensor"); - return Call(op, {storage, offset, shape}, Attrs(attrs), {}); - }); + .set_body_typed([](Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, + Array assert_shape) { + auto attrs = make_object(); + attrs->dtype = dtype; + if (assert_shape.defined()) { + attrs->assert_shape = assert_shape; + } else { + attrs->const_shape = Downcast(shape); + } + static const Op& op = Op::Get("memory.alloc_tensor"); + return Call(op, {storage, offset, shape}, 
Attrs(attrs), {}); + }); std::vector FromConstShape(Constant konst) { runtime::NDArray shape = konst->data; From 483f846fb7510ab5e7ec60d3d1292351f9a9a1b9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 14:22:43 -0700 Subject: [PATCH 25/31] Fix 0-rank case --- src/runtime/vm/vm.cc | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b3f968ea9443..8440e60824b6 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -621,17 +621,21 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { std::vector ToShape(NDArray shape_tensor) { std::vector shape; + auto rank = shape_tensor.Shape().size(); auto dtype = shape_tensor.DataType(); - CHECK(shape_tensor.Shape().size() == 1) << "shape tensor should be a k-length vector."; - int64_t ndim = shape_tensor.Shape().at(0); - - if (ndim == 0) { + // For 0-rank shapes we need to allocate a single scalar. + if (rank == 0) { return shape; - } else { - shape.resize(ndim); } + // Otherwise we should be rank-1, and we will extract the number of dimensions + // for the output vector. + CHECK(shape_tensor.Shape().size() == 1) + << "shape tensor should be a k-length vector, found " << shape_tensor.Shape().size(); + int64_t ndim = shape_tensor.Shape().at(0); + shape.resize(ndim); + const DLTensor* dl_tensor = shape_tensor.operator->(); if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { int32_t* dims = reinterpret_cast(dl_tensor->data); From f5cc10b16a56837a809576c74da362465b6b70f4 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 12 May 2020 20:37:14 -0700 Subject: [PATCH 26/31] Add fix for (0,) shape --- tests/python/frontend/onnx/test_forward.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 614041401026..22b975da8149 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -64,6 +64,15 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) + + # Normalize any parameters with (0,) shape. + # + # Currently Relay does not support (0,) shaped tensors. 
+    for param_key in params:
+        param = params[param_key]
+        if len(param.shape) == 1 and param.shape[0] == 0:
+            params[param_key] = tvm.nd.array(np.empty(()))
+
     with relay.build_config(opt_level=1):
         graph, lib, params = relay.build(mod,
                                          target,
@@ -1667,7 +1676,7 @@ def verify_prelu(x_shape, a_shape):
     onnx_out = get_onnxruntime_output(model, [indata, slopedata])

     for target, ctx in [('llvm', tvm.cpu())]:
-        tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), 
+        tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape),
                                  output_dtype='float32')
         tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05)
@@ -2572,7 +2581,7 @@ def verify_topk(input_dims, K, axis=-1):
                               inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)),
                                       helper.make_tensor_value_info("K", TensorProto.INT64, [1,])],
                               initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])],
-                              outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), 
+                              outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims),
                                        helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)])
     model = helper.make_model(graph, producer_name='topk_test')
@@ -2581,10 +2590,10 @@ def verify_topk(input_dims, K, axis=-1):
     onnx_out = get_onnxruntime_output(model, [indata, k])

     for target, ctx in [('llvm', tvm.cpu())]:
-        tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], 
+        tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims],
                                  output_dtype=['float32', 'int64'])
         tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)
-    
+
     for n in [12, 32]:
         for shape in [[n], [n, n], [n, n, n]]:
             for k in [1, 5, 10]:
@@ -2593,7 +2602,7 @@ def verify_topk(input_dims, K, axis=-1):
     verify_topk([n, n, n], 5, 0)
     verify_topk([n, n, n], 5, 1)
     verify_topk([n, n, n], 5, 2)
-    
+

 def test_roi_align():
     def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0):

From 72d2a622f7e629be2306988db3776402e5d2e5e6 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Wed, 13 May 2020 09:15:05 -0700
Subject: [PATCH 27/31] Ignore shapes for now

---
 python/tvm/relay/expr.py                   | 3 ---
 tests/python/frontend/onnx/test_forward.py | 8 --------
 2 files changed, 11 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index f69414db6952..3e98e52af714 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -505,9 +505,6 @@ def const(value, dtype=None):
     if not isinstance(value, _nd.NDArray):
         raise ValueError("value has to be scalar or NDArray")

-    for dim in value.shape:
-        assert dim != 0, "Relay constants can not contain a 0 dimension."
-
     return Constant(value)

diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 22b975da8149..2e61b4c62c73 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -65,14 +65,6 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output
     mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)

-    # Normalize any parameters with (0,) shape.
-    #
-    # Currently Relay does not support (0,) shaped tensors.
-    for param_key in params:
-        param = params[param_key]
-        if len(param.shape) == 1 and param.shape[0] == 0:
-            params[param_key] = tvm.nd.array(np.empty(()))
-
     with relay.build_config(opt_level=1):
         graph, lib, params = relay.build(mod,
                                          target,

From 3d951b56ba7cd92b612f481342db961f617d2a11 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Wed, 13 May 2020 18:30:11 -0700
Subject: [PATCH 28/31] Apply suggestions from code review

Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
---
 python/tvm/relay/transform/memory_plan.py | 6 +++---
 src/relay/backend/vm/compiler.cc          | 2 +-
 src/runtime/vm/vm.cc                      | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py
index f3d994701b92..5bf2223d39d1 100644
--- a/python/tvm/relay/transform/memory_plan.py
+++ b/python/tvm/relay/transform/memory_plan.py
@@ -299,7 +299,7 @@ def process_alloc_tensor(self, lhs, call):
         )

 class LiftConst(ExprMutator):
-    """A internal pass to lift constants to the top level of function."""
+    """An internal pass to lift constants to the top level of function."""
     def __init__(self):
         self.i = 0
         self.constants = []
@@ -330,7 +330,7 @@ def visit_function(self, fn):

 @function_pass(opt_level=0)
 class MemoryPlan:
-    """An explicit pass wrapper around ManifestAlloc."""
+    """An explicit pass wrapper around StorageCoalesce."""

     def transform_function(self, func, mod, _):
         mod.import_from_std("core.rly")
@@ -342,7 +342,7 @@ def transform_function(self, func, mod, _):

 @function_pass(opt_level=0)
 class LiftConstants:
-    """An explicit pass wrapper around LiftConstants."""
+    """An explicit pass wrapper around LiftConst."""

     def transform_function(self, func, mod, _):
         mod.import_from_std("core.rly")
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index fcefa4b98648..810664e58c93 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -69,7 +69,7 @@ Pass MemoryPlan() {

 Pass LiftConstants() {
   auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants");
-  CHECK(f != nullptr) << "unable to load the memory planning pass";
+  CHECK(f != nullptr) << "unable to load the constant lifting pass";
   return (*f)();
 }

diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 8440e60824b6..e775064df64f 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -632,7 +632,7 @@ std::vector<int64_t> ToShape(NDArray shape_tensor) {
   // Otherwise we should be rank-1, and we will extract the number of dimensions
   // for the output vector.
   CHECK(shape_tensor.Shape().size() == 1)
-      << "shape tensor should be a k-length vector, found " << shape_tensor.Shape().size();
+      << "shape tensor should be a k-length vector, found " << rank;
   int64_t ndim = shape_tensor.Shape().at(0);
   shape.resize(ndim);

From 28150aaa2ac512a79789757610b537544ef2259c Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Wed, 13 May 2020 18:40:09 -0700
Subject: [PATCH 29/31] Update src/runtime/vm/executable.cc

Co-authored-by: Zhi <5145158+zhiics@users.noreply.github.com>
---
 src/runtime/vm/executable.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc
index b037d9f56134..47bdd1c705de 100644
--- a/src/runtime/vm/executable.cc
+++ b/src/runtime/vm/executable.cc
@@ -570,7 +570,7 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) {
       return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst);
     }
     case Opcode::AllocTensorReg: {
-      // Number of fields = 5
+      // Number of fields = 7
       DCHECK_EQ(instr.fields.size(), 7U);
       RegName storage_reg = instr.fields[0];

From 91c22ee6a9d3f26caadce6ae9c0bf09ce1bfff87 Mon Sep 17 00:00:00 2001
From: Jared Roesch
Date: Wed, 13 May 2020 18:42:23 -0700
Subject: [PATCH 30/31] Fix

---
 python/tvm/relay/transform/memory_plan.py | 4 +++-
 src/runtime/vm/vm.cc                      | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py
index 5bf2223d39d1..a6c2c11745ed 100644
--- a/python/tvm/relay/transform/memory_plan.py
+++ b/python/tvm/relay/transform/memory_plan.py
@@ -318,7 +318,9 @@ def visit_function(self, fn):
         outer_constant = self.constants
         self.constants = []
-        body = mk_let(self.constants, self.visit(fn.body))
+        # Populates self.constants.
+        body = self.visit(fn.body)
+        body = mk_let(self.constants, body)
         self.constants = outer_constant

         return Function(
diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index e775064df64f..4a3eed4a2889 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -631,7 +631,7 @@ std::vector<int64_t> ToShape(NDArray shape_tensor) {
   // Otherwise we should be rank-1, and we will extract the number of dimensions
   // for the output vector.
-  CHECK(shape_tensor.Shape().size() == 1)
+  CHECK_EQ(rank, 1U)
       << "shape tensor should be a k-length vector, found " << rank;
   int64_t ndim = shape_tensor.Shape().at(0);
   shape.resize(ndim);

From b778d4882f2ede2e93767619c80e94f8173f9c38 Mon Sep 17 00:00:00 2001
From: Zhi Chen
Date: Thu, 14 May 2020 03:42:27 +0000
Subject: [PATCH 31/31] lint

---
 src/runtime/vm/vm.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc
index 4a3eed4a2889..22102c93083b 100644
--- a/src/runtime/vm/vm.cc
+++ b/src/runtime/vm/vm.cc
@@ -631,8 +631,7 @@ std::vector<int64_t> ToShape(NDArray shape_tensor) {
   // Otherwise we should be rank-1, and we will extract the number of dimensions
   // for the output vector.
-  CHECK_EQ(rank, 1U)
-      << "shape tensor should be a k-length vector, found " << rank;
+  CHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank;
   int64_t ndim = shape_tensor.Shape().at(0);
   shape.resize(ndim);
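
The vm.cc hunks above iterate several times on how ToShape treats the rank of the shape tensor. As a reading aid, the following is a small self-contained C++ sketch, not part of any patch and not the TVM API (the names ToShapeSketch, shape_tensor_dims, and shape_data are hypothetical, and plain std::vector stands in for NDArray/DLTensor): a rank-0 shape tensor denotes a scalar and yields an empty shape, while a rank-1 shape tensor of length k is read as a k-dimensional shape.

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the rank convention: rank 0 -> scalar (empty shape),
// rank 1 of length k -> k-dimensional shape copied out of the data buffer.
std::vector<int64_t> ToShapeSketch(const std::vector<int64_t>& shape_tensor_dims,
                                   const int64_t* shape_data) {
  std::vector<int64_t> shape;
  auto rank = shape_tensor_dims.size();
  // Rank-0: a scalar allocation, nothing to copy.
  if (rank == 0) {
    return shape;
  }
  // Otherwise the shape tensor must be a k-length vector.
  assert(rank == 1 && "shape tensor should be a k-length vector");
  int64_t ndim = shape_tensor_dims[0];
  shape.assign(shape_data, shape_data + ndim);
  return shape;
}

int main() {
  // Rank-1 shape tensor of length 2 holding {2, 3} -> the 2-d shape {2, 3}.
  int64_t dims[] = {2, 3};
  std::vector<int64_t> s = ToShapeSketch({2}, dims);
  assert(s.size() == 2 && s[0] == 2 && s[1] == 3);
  // Rank-0 shape tensor -> scalar, so the shape vector stays empty.
  assert(ToShapeSketch({}, nullptr).empty());
  return 0;
}

Under that convention the CHECK_EQ(rank, 1U) kept by the final two patches only fires for genuinely malformed shape tensors, since the rank-0 case has already returned before the check.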