diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 8db93b46e934..0171d8a999e8 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_NDARRAY_H_ #include +#include #include #include @@ -160,6 +161,7 @@ class NDArray : public ObjectRef { TVMStreamHandle stream = nullptr); TVM_DLL std::vector Shape() const; + TVM_DLL runtime::DataType DataType() const; // internal namespace struct Internal; diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index 62534d9ca6a9..552edc5f19db 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -136,6 +136,8 @@ struct Instruction { struct /* AllocTensor Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The number of dimensions. */ uint32_t ndim; /*! \brief The shape of tensor. */ @@ -146,6 +148,8 @@ struct Instruction { struct /* AllocTensorReg Operands */ { /*! \brief The storage to allocate from. */ RegName storage; + /*! \brief The offset into the storage to allocate from. */ + Index offset; /*! \brief The register to read the shape out of. */ RegName shape_register; /*! \brief The datatype of tensor to be allocated. */ @@ -267,23 +271,25 @@ struct Instruction { /*! * \brief Construct an allocate tensor instruction with constant shape. * \param storage The storage to allocate out of. + * \param offset The offset to allocate at. * \param shape The shape of the tensor. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensor(RegName storage, const std::vector& shape, + static Instruction AllocTensor(RegName storage, Index offset, const std::vector& shape, DLDataType dtype, RegName dst); /*! * \brief Construct an allocate tensor instruction with register. * \param storage The storage to allocate out of. + * \param offset The offset into the storage to allocate from. * \param shape_register The register containing the shape. * \param dtype The dtype of the tensor. * \param dst The destination register. * \return The allocate tensor instruction. */ - static Instruction AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, - RegName dst); + static Instruction AllocTensorReg(RegName storage, Index offset, RegName shape_register, + DLDataType dtype, RegName dst); /*! * \brief Construct an allocate datatype instruction. * \param tag The datatype tag. diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 4663866b1452..8e48e509490b 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -58,6 +58,12 @@ # Dialects from . 
import qnn +from .scope_builder import ScopeBuilder + +# Load Memory Passes +from .transform import memory_alloc +from .transform import memory_plan + # Required to traverse large programs setrecursionlimit(10000) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index ff1368394917..3e98e52af714 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -504,6 +504,7 @@ def const(value, dtype=None): if not isinstance(value, _nd.NDArray): raise ValueError("value has to be scalar or NDArray") + return Constant(value) diff --git a/python/tvm/relay/op/memory/memory.py b/python/tvm/relay/op/memory/memory.py index 509db354b42c..4092545d552c 100644 --- a/python/tvm/relay/op/memory/memory.py +++ b/python/tvm/relay/op/memory/memory.py @@ -40,7 +40,7 @@ def invoke_tvm_op(func, inputs, outputs): """ return _make.invoke_tvm_op(func, inputs, outputs) -def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): +def alloc_tensor(storage, offset, shape, dtype='float32', assert_shape=None): """Allocate a tensor with the provided shape, and dtype. Parameters @@ -48,6 +48,9 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): storage : tvm.relay.Expr The storage to allocate from. + offset : tvm.relay.Expr + The offset to allocate from. + shape : tvm.relay.Expr The shape of the tensor to allocate. @@ -61,7 +64,7 @@ def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): result : tvm.relay.Expr The alloc_tensor expression. """ - return _make.alloc_tensor(storage, shape, dtype, assert_shape) + return _make.alloc_tensor(storage, offset, shape, dtype, assert_shape) def alloc_storage(size, alignment, ctx, dtype_hint='float32'): """Allocate a piece of tensor storage. diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 93d4341635a0..138a36611c6f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,5 +18,4 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * - from . 
import memory_alloc diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index 611fb1babf55..6c081cbac0de 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -28,19 +28,21 @@ from ..backend import compile_engine from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu +from ..op.memory import alloc_storage +def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): + offset = expr.const(0, dtype="int64") + return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 class ManifestAllocPass(ExprMutator): - """A pass for explictly manifesting all memory allocations in Relay.""" + """A pass for explicitly manifesting all memory allocations in Relay.""" def __init__(self, target_host): self.invoke_tvm = op.memory.invoke_tvm_op - self.alloc_storage = op.memory.alloc_storage - self.alloc_tensor = op.memory.alloc_tensor self.shape_func = op.memory.shape_func self.scopes = [ScopeBuilder()] self.target_host = target_host @@ -94,17 +96,16 @@ def make_static_allocation(self, scope, tensor_type, i): """Allocate a tensor with a statically known shape.""" shape = [int(sh) for sh in tensor_type.shape] if len(shape) == 0: - shape = expr.const(np.array([]).astype( - self.compute_dtype), dtype=self.compute_dtype) + shape = expr.const(np.empty((), dtype=self.compute_dtype), dtype=self.compute_dtype) else: shape = expr.const(np.array(shape), dtype=self.compute_dtype) size = self.compute_storage(tensor_type) alignment = self.compute_alignment(tensor_type.dtype) dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(i), self.alloc_storage( + sto = scope.let("storage_{0}".format(i), alloc_storage( size, alignment, self.default_context, dtype)) # TODO(@jroesch): There is a bug with typing based on the constant shape. - tensor = self.alloc_tensor(sto, shape, dtype, tensor_type.shape) + tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) return scope.let("tensor_{0}".format(i), tensor) def visit_let(self, let): @@ -172,14 +173,14 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): size = self.compute_storage_in_relay( out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) - sto = scope.let("storage_{i}".format(i=i), self.alloc_storage( + sto = scope.let("storage_{i}".format(i=i), alloc_storage( size, alignment, self.default_context, out_type.dtype)) storages.append(sto) outs = [] sh_ty_storage = zip(out_shapes, out_types, storages) for i, (out_shape, out_type, storage) in enumerate(sh_ty_storage): - alloc = self.alloc_tensor( + alloc = alloc_tensor( storage, out_shape, out_type.dtype, @@ -204,6 +205,7 @@ def visit_call(self, call): # Because we are in ANF we do not need to visit the arguments. scope = self.current_scope() new_args = [self.visit(arg) for arg in call.args] + ins = expr.Tuple(new_args) ret_type = call.checked_type out_types = flatten_tuple_type(ret_type) @@ -233,7 +235,7 @@ def __init__(self, target_host): self.target_host = target_host def transform_function(self, func, mod, _): - # TODO(@jroesch): Is there a way to do one shot initilization? + # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? 
mod.import_from_std("core.rly") ea = ManifestAllocPass(self.target_host) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py new file mode 100644 index 000000000000..a6c2c11745ed --- /dev/null +++ b/python/tvm/relay/transform/memory_plan.py @@ -0,0 +1,355 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks +""" +A pass for manifesting explicit memory allocations. +""" +from typing import Optional, Dict, List, Tuple +from collections import defaultdict +import attr + +from ..expr_functor import ExprMutator +from .. import op, expr +from ..function import Function +from ... import register_func, ir, cpu +from ..._ffi.runtime_ctypes import TVMContext +from ... import IRModule +from .. import transform +from . import function_pass + + +def is_primitive(call): + return ( + hasattr(call, "op") + and hasattr(call.op, "attrs") + and hasattr(call.op.attrs, "Primitive") + and int(call.op.attrs.Primitive) == 1 + ) + + +@attr.s(auto_attribs=True) +class Region: + """ + Represents a control-free allocation region. + + The below pass groups sets of allocations into regions, + then replaces the region with a single allocation. + """ + var: expr.Var + size: expr.Expr + alignment: Optional[expr.Expr] + dtype: Optional[str] + ctx: TVMContext + offsets: Dict[expr.Var, Tuple[expr.Expr, expr.Expr]] + + @staticmethod + def empty(region_no): + zero = expr.const(0, dtype="int64") + assert len(zero.data.shape) == 0 + region_var = expr.var(f"region{region_no}") + return Region(region_var, zero, None, None, None, {}) + + def grow( + self, old_storage: expr.Var, + size: expr.Expr, alignment: expr.Expr, + ctx: TVMContext, + dtype: str) -> None: + """Grow the region by a given allocation as well as track the old storage + for later rewriting the program to use the allocated region. + """ + if self.dtype: + assert self.dtype == dtype, "must have matching dtypes in a region" + else: + self.dtype = dtype + + if self.alignment: + assert ir.structural_equal( + self.alignment, alignment + ), "must have matching alignments in a region" + else: + self.alignment = alignment + + if self.ctx: + assert (self.ctx.device_type == ctx.device_type and + self.ctx.device_id == ctx.device_id), "must have matching context" + else: + assert ctx + self.ctx = ctx + + new_size = (size + self.alignment - expr.const(1, "int64")) \ + / self.alignment * self.alignment + + # Record the offset at which we allocate the storage. 
+ offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}") + self.offsets[old_storage] = (offset_var, self.size) + + self.size = self.size + new_size + + def offset_for(self, alloc: expr.Expr) -> expr.Expr: + return self.offsets.get(alloc, [None])[0] + + def to_expr(self, body: expr.Expr) -> expr.Expr: + """ + Generate the prelude code for a region, wrapping the body in it. + + The prelude contains the single allocation for a region, and + all offset computations. + """ + + if self.ctx is None: + self.ctx = cpu(0) + + # Generate bindings for each and every size computation + # we must do this to maintain ANF. + bindings: List[Tuple[expr.Expr, expr.Expr]] = [] + + # First compute the total size. + total_size = expr.var(f"total_size{hash(body)}") + bindings.append((total_size, self.size)) + + # Allocate the entire region with a single call. + alloc = op.memory.alloc_storage(total_size, self.alignment, self.ctx, self.dtype) + bindings.append((self.var, alloc)) + + # Generate variables which contain all of the offset math. + # Ensure we constant evaluate away all the math here. + # + # In theory we can support dynamic offsets but this + # requires another round of memory planning and + # potentially colaescing. + for alloc in self.offsets: + (var, offset) = self.offsets[alloc] + bindings.append((var, offset)) + + body = mk_let(bindings, body) + return body + + +def iterative_let(let, each_binding, kont): + bindings = [] + while isinstance(let, expr.Let): + lhs = let.var + rhs = let.value + bindings.append(each_binding(lhs, rhs)) + let = let.body + + return kont(bindings, let) + + + +def mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value + assert body + body = expr.Let(var, value, body) + + return body + +def const_eval(mod, exp): + mod = IRModule.from_expr(exp, type_defs=mod.type_definitions) + mod = transform.FoldConstant()(mod) + return mod["main"] + +class StorageCoalesce(ExprMutator): + """ + A pass for coalescing allocations into region/arena allocations. + + After this pass each allocation comes from the same backing storage, + but will never overlap even in time, i.e. the allocations are just + packed into a contiguous block of memory. + + A secondary part of memory planning will perform liveness analysis to + overlap these in time, i.e when an early tensor dies we will attempt + to reuse its slot. 
+ """ + + def __init__(self): + super().__init__() + self.regions = [] + + def enter_scope(self) -> None: + region_no = len(self.regions) + self.regions.append(defaultdict(lambda: Region.empty(region_no))) + + def exit_scope(self, body: expr.Expr) -> expr.Expr: + """When leaving a scope build a region allocation for the scope.""" + dtype_region = self.regions.pop() + for _, region in reversed(list(dtype_region.items())): + if len(region.offsets) != 0: + body = region.to_expr(body) + + return body + + def current_region(self, dtype) -> Region: + current_scope = self.regions[-1] + return current_scope[dtype] + + def new_region_and_offset(self, old_storage): + for dtype_region in reversed(self.regions): + for dtype in dtype_region: + region = dtype_region[dtype] + offset = region.offset_for(old_storage) + if offset: + return region, offset + + raise Exception("could not find offset in any valid region") + + def visit_function(self, fn): + """Transform the function body to use region allocation scheme.""" + func = fn + if getattr(func.attrs, "Primitive", 0) == 1: + return super().visit_function(func) + else: + self.enter_scope() + body = self.visit(func.body) + body = self.exit_scope(body) + return Function( + func.params, + body, + func.ret_type, + func.type_params, + func.attrs, + ) + + def visit_if(self, ite): + self.enter_scope() + true_branch = self.visit(ite.true_branch) + true_branch = self.exit_scope(true_branch) + + self.enter_scope() + false_branch = self.visit(ite.false_branch) + false_branch = self.exit_scope(false_branch) + + return expr.If(ite.cond, true_branch, false_branch) + + + def mk_let(self, dynamic_regions): + """Let bind the dynamic regions""" + def _mk_let(bindings, body): + for var, value in reversed(bindings): + assert var + assert value + assert body + body = expr.Let(var, value, body) + if var in dynamic_regions: + body = self.exit_scope(body) + + return body + + return _mk_let + + def visit_let(self, let): + dynamic_regions = [] + def _each_binding(lhs, rhs): + if isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_storage" + ): + return self.process_alloc_storage(dynamic_regions, lhs, rhs) + elif isinstance(rhs, expr.Call) and rhs.op == op.op.get( + "memory.alloc_tensor" + ): + return self.process_alloc_tensor(lhs, rhs) + else: + return lhs, rhs + + result = iterative_let(let, _each_binding, self.mk_let(dynamic_regions)) + assert result + return result + + def process_alloc_storage(self, dynamic_regions, lhs, call): + """Process alloc_storage""" + size, alignment = call.args + dtype = call.attrs.dtype + ctx = TVMContext(call.attrs.device_type, call.attrs.device_id) + + if not isinstance(size, expr.Constant): + self.enter_scope() + dynamic_regions.append(lhs) + + region = self.current_region(dtype) + region.grow(lhs, size, alignment, ctx, dtype) + return lhs, region.var + + def process_alloc_tensor(self, lhs, call): + """Process alloc tensor. 
Region and offset are computed""" + storage, old_offset, shape = call.args + region, offset = self.new_region_and_offset(storage) + + assert ( + old_offset.data.asnumpy().item() == 0 + ), "no offsets should yet be allocated" + return ( + lhs, + expr.Call(call.op, [region.var, offset, shape], call.attrs), + ) + +class LiftConst(ExprMutator): + """An internal pass to lift constants to the top level of function.""" + def __init__(self): + self.i = 0 + self.constants = [] + self.top_level = True + super().__init__() + + def visit_constant(self, const): + var = expr.var(f"const{self.i}") + self.i += 1 + self.constants.append((var, const)) + return var + + def visit_function(self, fn): + if int(getattr(fn.attrs, "Primitive", 0)) == 1: + return fn + + outer_constant = self.constants + self.constants = [] + # Populates self.constants. + body = self.visit(fn.body) + body = mk_let(self.constants, body) + self.constants = outer_constant + + return Function( + fn.params, + body, + fn.ret_type, + fn.type_params, + fn.attrs) + +@function_pass(opt_level=0) +class MemoryPlan: + """An explicit pass wrapper around StorageCoalesce.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + sc = StorageCoalesce() + func = sc.visit(func) + return func + +register_func("relay.transform.MemoryPlan", MemoryPlan) + +@function_pass(opt_level=0) +class LiftConstants: + """An explicit pass wrapper around LiftConst.""" + + def transform_function(self, func, mod, _): + mod.import_from_std("core.rly") + func = LiftConst().visit(func) + return func + + +register_func("relay.transform.LiftConstants", LiftConstants) diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index b2a5e83ef43c..810664e58c93 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -45,6 +45,7 @@ #include "../../op/op_common.h" #include "../../transforms/pass_util.h" #include "../utils.h" +#include "compiler.h" namespace tvm { namespace relay { @@ -56,10 +57,22 @@ Pass InlinePrimitives(); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); - CHECK(f != nullptr) << "could not load memory allocation pass"; + CHECK(f != nullptr) << "unable to load allocation manifestation pass"; return (*f)(target_host); } +Pass MemoryPlan() { + auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan"); + CHECK(f != nullptr) << "unable to load the memory planning pass"; + return (*f)(); +} + +Pass LiftConstants() { + auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants"); + CHECK(f != nullptr) << "unable to load the constant lifting pass"; + return (*f)(); +} + } // namespace transform namespace vm { @@ -281,6 +294,15 @@ class VMFunctionCompiler : ExprFunctor { } void VisitExpr_(const ConstantNode* const_node) { + // Check the shape is valid + NDArray data = const_node->data; + const DLTensor* tensor = data.operator->(); + if (tensor->ndim > 0) { + int64_t* shapes = reinterpret_cast(tensor->shape); + for (auto i = 0; i < tensor->ndim; i++) { + CHECK_GT(shapes[i], 0U); + } + } size_t konst_idx = context_->constants.size(); context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); @@ -501,37 +523,41 @@ class VMFunctionCompiler : ExprFunctor { CHECK_EQ(args.size(), 3); EmitInvokeTVMOp(Downcast(args[0]), args[1], args[2]); }) - .Match( - "memory.alloc_tensor", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { - CHECK_EQ(args.size(), 
2); - - // Get the attributes. - auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; - auto dtype = alloc_attrs->dtype; - - // The storage will be passed dynamically. - this->VisitExpr(args[0]); - auto storage_register = last_register_; - - // If the shape is constant then we will emit a static tensor allocation - // instruction. - auto const_shape = args[1].as(); - - if (const_shape) { - NDArray shape = const_shape->data; - // TODO(@jroesch): we need to get an RFC done to standarize shape dtype - std::vector raw_shape = ToAllocTensorShape(shape); - // Add context field. - Emit(Instruction::AllocTensor(storage_register, raw_shape, dtype, NewRegister())); - } else { - this->VisitExpr(args[1]); - auto shape_register = last_register_; - Emit(Instruction::AllocTensorReg(storage_register, shape_register, dtype, - NewRegister())); - } - }) + .Match("memory.alloc_tensor", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 3); + + // Get the attributes. + auto alloc_attrs = attrs.as(); + CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + auto dtype = alloc_attrs->dtype; + + // The storage will be passed dynamically. + this->VisitExpr(args[0]); + auto storage_register = last_register_; + + // The storage will be passed dynamically. + this->VisitExpr(args[1]); + auto offset_register = last_register_; + + // If the shape is constant then we will emit a static tensor allocation + // instruction. + auto const_shape = args[2].as(); + + if (const_shape) { + NDArray shape = const_shape->data; + // TODO(@jroesch): we need to get an RFC done to standarize shape dtype + std::vector raw_shape = ToAllocTensorShape(shape); + // Add context field. + Emit(Instruction::AllocTensor(storage_register, offset_register, raw_shape, + dtype, NewRegister())); + } else { + this->VisitExpr(args[2]); + auto shape_register = last_register_; + Emit(Instruction::AllocTensorReg(storage_register, offset_register, + shape_register, dtype, NewRegister())); + } + }) .Match("memory.alloc_storage", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { CHECK_EQ(args.size(), 2); @@ -830,6 +856,44 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe } } +transform::Sequential MemoryOpt(tvm::Target host_target) { + Array pass_seqs; + // Manifest the allocations. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Manifest the allocations needed for the shape functions. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Perform memory planning in order to coalesce/reduce allocations. + pass_seqs.push_back(transform::MemoryPlan()); + + // Compute away constant computation introduced by coalescing allocations. + pass_seqs.push_back(transform::FoldConstant()); + + // Fuse the shape functions. + pass_seqs.push_back(transform::FuseOps()); + + // Create allocations for math introduced by dynamic region math. + pass_seqs.push_back(transform::ManifestAlloc(host_target)); + + // Compute away possibly introduced constant computation. + pass_seqs.push_back(transform::FoldConstant()); + + // Lift constants to the top-level of the block to simplify VM code generation. 
+ pass_seqs.push_back(transform::LiftConstants()); + + return transform::Sequential(pass_seqs); +} + IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targets) { Array pass_seqs; Array entry_functions{"main"}; @@ -890,15 +954,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); - // Compute away possibly introduced constant computation. - pass_seqs.push_back(transform::FoldConstant()); - // Fuse the shape functions. - pass_seqs.push_back(transform::FuseOps()); - - // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); + pass_seqs.push_back(MemoryOpt(this->target_host_)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index ec96e23a01fb..76a3315dbb03 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -92,7 +92,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") }); TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") - .set_body_typed([](Expr storage, tvm::relay::Expr shape, DataType dtype, + .set_body_typed([](Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { auto attrs = make_object(); attrs->dtype = dtype; @@ -102,7 +102,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") attrs->const_shape = Downcast(shape); } static const Op& op = Op::Get("memory.alloc_tensor"); - return Call(op, {storage, shape}, Attrs(attrs), {}); + return Call(op, {storage, offset, shape}, Attrs(attrs), {}); }); std::vector FromConstShape(Constant konst) { @@ -132,7 +132,7 @@ std::vector FromConstShape(Constant konst) { bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 3u); + CHECK_EQ(types.size(), 4u); auto alloc_attrs = attrs.as(); CHECK(alloc_attrs != nullptr) << "must be alloc_tensor attributes"; // First argument should be storage. @@ -141,18 +141,28 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs auto storage_name = mod->GetGlobalTypeVar("Storage"); auto storage = relay::TypeCall(storage_name, {}); reporter->Assign(types[0], storage); - // Second argument should be shape tensor. - auto tt = types[1].as(); + // Second argument should be the offset. + auto offset_type = types[1].as(); + CHECK(offset_type != nullptr) << "must be a scalar type"; + + // Third argument should be shape tensor. + auto tt = types[2].as(); CHECK(tt != nullptr) << "must be tensor type"; - auto rank = tt->shape[0].as(); - CHECK(rank != nullptr); - auto dims = rank->value; + + // Be careful about having to allocate scalars. + int64_t dims = 0; + if (tt->shape.size() != 0) { + auto rank = tt->shape[0].as(); + CHECK(rank != nullptr); + dims = rank->value; + } // Constant node case. 
Type alloc_type; if (alloc_attrs->const_shape.defined()) { auto con = alloc_attrs->const_shape; auto sh = FromConstShape(con); + CHECK_EQ(sh.size(), dims); Array out_shape; for (auto i = 0u; i < dims; i++) { out_shape.push_back(tvm::Integer(sh[i])); @@ -165,14 +175,15 @@ bool AllocTensorRel(const Array& types, int num_inputs, const Attrs& attrs return true; } - reporter->Assign(types[2], alloc_type); + reporter->Assign(types[3], alloc_type); return true; } RELAY_REGISTER_OP("memory.alloc_tensor") .describe(R"code(Explicitly allocate storage to be used by tensors.)code" TVM_ADD_FILELINE) - .set_num_inputs(2) + .set_num_inputs(3) .add_argument("storage", "Storage", "The storage to allocate from.") + .add_argument("offset", "Tensor", "The offset into the backing storage.") .add_argument("shape", "Tensor", "The shape of the tensor to allocate.") .add_type_rel("AllocTensor", AllocTensorRel) .set_support_level(10) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index d97d01b0feab..800a9167dadc 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -236,6 +236,9 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str } std::vector NDArray::Shape() const { return get_mutable()->shape_; } +runtime::DataType NDArray::DataType() const { + return runtime::DataType(get_mutable()->dl_tensor.dtype); +} TVM_REGISTER_OBJECT_TYPE(NDArray::Container); diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index c72e70fd6f66..47bdd1c705de 100644 --- a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -307,9 +307,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensor: { - // Number of fields = 5 + instr.alloc_tensor.ndim + // Number of fields = 7 + instr.alloc_tensor.ndim fields.push_back(instr.alloc_tensor.storage); - + fields.push_back(instr.alloc_tensor.offset); // Save `DLDataType` and the dst register. const auto& dtype = instr.alloc_tensor.dtype; fields.push_back(dtype.code); @@ -330,8 +330,9 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { break; } case Opcode::AllocTensorReg: { - // Number of fields = 6 + // Number of fields = 7 fields.push_back(instr.alloc_tensor_reg.storage); + fields.push_back(instr.alloc_tensor_reg.offset); fields.push_back(instr.alloc_tensor_reg.shape_register); // Save `DLDataType` and the dst register. 
const auto& dtype = instr.alloc_tensor_reg.dtype; @@ -549,39 +550,41 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::InvokePacked(packed_index, arity, output_size, args); } case Opcode::AllocTensor: { - // Number of fields = 6 + instr.alloc_tensor.ndim - DCHECK_GE(instr.fields.size(), 6U); - DCHECK_EQ(instr.fields.size(), 6U + static_cast(instr.fields[4])); + // Number of fields = 7 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 7U); + DCHECK_EQ(instr.fields.size(), 7U + static_cast(instr.fields[5])); RegName storage_reg = instr.fields[0]; + RegName offset = instr.fields[1]; DLDataType dtype; - dtype.code = instr.fields[1]; - dtype.bits = instr.fields[2]; - dtype.lanes = instr.fields[3]; + dtype.code = instr.fields[2]; + dtype.bits = instr.fields[3]; + dtype.lanes = instr.fields[4]; - Index ndim = instr.fields[4]; - RegName dst = instr.fields[5]; + Index ndim = instr.fields[5]; + RegName dst = instr.fields[6]; - std::vector shape = ExtractFields(instr.fields, 6, ndim); + std::vector shape = ExtractFields(instr.fields, 7, ndim); - return Instruction::AllocTensor(storage_reg, shape, dtype, dst); + return Instruction::AllocTensor(storage_reg, offset, shape, dtype, dst); } case Opcode::AllocTensorReg: { - // Number of fields = 5 - DCHECK_EQ(instr.fields.size(), 6U); + // Number of fields = 7 + DCHECK_EQ(instr.fields.size(), 7U); RegName storage_reg = instr.fields[0]; - Index shape_register = instr.fields[1]; + RegName offset = instr.fields[1]; + Index shape_register = instr.fields[2]; DLDataType dtype; - dtype.code = instr.fields[2]; - dtype.bits = instr.fields[3]; - dtype.lanes = instr.fields[4]; + dtype.code = instr.fields[3]; + dtype.bits = instr.fields[4]; + dtype.lanes = instr.fields[5]; - RegName dst = instr.fields[5]; + RegName dst = instr.fields[6]; - return Instruction::AllocTensorReg(storage_reg, shape_register, dtype, dst); + return Instruction::AllocTensorReg(storage_reg, offset, shape_register, dtype, dst); } case Opcode::AllocADT: { // Number of fields = 3 + instr.num_fields diff --git a/src/runtime/vm/memory_manager.cc b/src/runtime/vm/memory_manager.cc index c0fd441bb0ca..4c220bbe61c8 100644 --- a/src/runtime/vm/memory_manager.cc +++ b/src/runtime/vm/memory_manager.cc @@ -77,8 +77,6 @@ inline size_t GetDataAlignment(const DLTensor& arr) { } NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDataType dtype) { - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK_EQ(offset, 0u); VerifyDataType(dtype); // crtical zone: allocate header, cannot throw @@ -87,14 +85,26 @@ NDArray StorageObj::AllocNDArray(size_t offset, std::vector shape, DLDa container->SetDeleter(StorageObj::Deleter); size_t needed_size = GetDataSize(container->dl_tensor); this->IncRef(); + // The manager context pointer must continue to point to the storage object + // which owns the backing memory, and keeps track of the reference count. + // + // When we free a container we extract the storage object, decrement its + // reference count, then destroy the container, but leave the underlying + // buffer intact. container->manager_ctx = reinterpret_cast(this); - container->dl_tensor.data = this->buffer.data; - NDArray ret(GetObjectPtr(container)); + // is this UB? + // The only change we make w.r.t offset is modifying the data pointer + // of the backing tensor to point into the buffer instead of its start.
+ auto offset_ptr = reinterpret_cast(this->buffer.data) + offset; + container->dl_tensor.data = reinterpret_cast(offset_ptr); + + NDArray ret(GetObjectPtr(container)); // RAII in effect, now run the check. - // TODO(@jroesch): generalize later to non-overlapping allocations. - CHECK(needed_size == this->buffer.size) - << "size mistmatch required " << needed_size << " found " << this->buffer.size; + + CHECK(offset + needed_size <= this->buffer.size) + << "storage allocation failure, attempted to allocate " << needed_size << " at offset " + << offset << " in region that is " << this->buffer.size << " bytes"; return ret; } diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 0714709a0718..22102c93083b 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -85,6 +85,7 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); @@ -92,6 +93,7 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return; @@ -174,7 +176,8 @@ Instruction& Instruction::operator=(const Instruction& instr) { this->result = instr.result; return *this; case Opcode::AllocTensor: this->alloc_tensor.storage = instr.alloc_tensor.storage; + this->alloc_tensor.offset = instr.alloc_tensor.offset; this->alloc_tensor.ndim = instr.alloc_tensor.ndim; this->alloc_tensor.shape = Duplicate(instr.alloc_tensor.shape, instr.alloc_tensor.ndim); @@ -182,6 +185,7 @@ Instruction& Instruction::operator=(const Instruction& instr) { return *this; case Opcode::AllocTensorReg: this->alloc_tensor_reg.storage = instr.alloc_tensor_reg.storage; + this->alloc_tensor_reg.offset = instr.alloc_tensor_reg.offset; this->alloc_tensor_reg.shape_register = instr.alloc_tensor_reg.shape_register; this->alloc_tensor_reg.dtype = instr.alloc_tensor_reg.dtype; return *this; @@ -307,12 +311,14 @@ Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index out return instr; } -Instruction Instruction::AllocTensor(RegName storage, const std::vector& shape, - DLDataType dtype, Index dst) { +Instruction Instruction::AllocTensor(RegName storage, RegName offset, + const std::vector& shape, DLDataType dtype, + Index dst) { Instruction instr; instr.op = Opcode::AllocTensor; instr.dst = dst; instr.alloc_tensor.storage = storage; + instr.alloc_tensor.offset = offset; instr.alloc_tensor.ndim = shape.size(); instr.alloc_tensor.shape = new int64_t[shape.size()]; for (size_t i = 0; i < shape.size(); ++i) { @@ -322,12 +328,13 @@ Instruction Instruction::AllocTensor(RegName storage, const std::vector return instr; } -Instruction Instruction::AllocTensorReg(RegName storage, RegName shape_register, DLDataType dtype, - Index dst) { +Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName shape_register, + DLDataType dtype, Index dst) { Instruction instr; instr.op = Opcode::AllocTensorReg; instr.dst = dst; instr.alloc_tensor_reg.storage = storage; + instr.alloc_tensor_reg.offset =
offset; instr.alloc_tensor_reg.shape_register = shape_register; instr.alloc_tensor_reg.dtype = dtype; return instr; @@ -514,13 +521,15 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::AllocTensor: { - os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " [" + os << "alloc_tensor $" << instr.dst << " $" << instr.alloc_tensor.storage << " $" + << instr.alloc_tensor.offset << " [" << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); break; } case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.storage << " $" + << instr.alloc_tensor_reg.offset << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); break; } @@ -610,6 +619,36 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { return src; } +std::vector ToShape(NDArray shape_tensor) { + std::vector shape; + auto rank = shape_tensor.Shape().size(); + auto dtype = shape_tensor.DataType(); + + // For 0-rank shapes we need to allocate a single scalar. + if (rank == 0) { + return shape; + } + + // Otherwise we should be rank-1, and we will extract the number of dimensions + // for the output vector. + CHECK_EQ(rank, 1U) << "shape tensor should be a k-length vector, found " << rank; + int64_t ndim = shape_tensor.Shape().at(0); + shape.resize(ndim); + + const DLTensor* dl_tensor = shape_tensor.operator->(); + if (dtype.is_int() && dtype.bits() == 32 && dtype.lanes() == 1) { + int32_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else if (dtype.is_int() && dtype.bits() == 64 && dtype.lanes() == 1) { + int64_t* dims = reinterpret_cast(dl_tensor->data); + shape.assign(dims, dims + ndim); + } else { + LOG(FATAL) << "invalid shape tensor datatype: " << dtype; + } + + return shape; +} + PackedFunc VirtualMachine::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { if (name == "invoke") { @@ -945,8 +984,9 @@ void VirtualMachine::RunLoop() { } auto storage_obj = ReadRegister(instr.alloc_tensor.storage); + auto offset = LoadScalarInt(instr.alloc_tensor.offset); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor.dtype); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor.dtype); WriteRegister(instr.dst, obj); pc_++; @@ -959,17 +999,11 @@ void VirtualMachine::RunLoop() { auto shape_tensor_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); const auto shape_arr = Downcast(shape_tensor_obj); NDArray shape_tensor = shape_arr.CopyTo(cpu_ctx); - const DLTensor* dl_tensor = shape_tensor.operator->(); - CHECK_EQ(dl_tensor->dtype.code, 0u); - CHECK_LE(dl_tensor->dtype.bits, 64); - int64_t* dims = reinterpret_cast(dl_tensor->data); - auto num_dims = shape_tensor->shape[0]; - auto shape = std::vector(num_dims); - shape.assign(dims, dims + num_dims); - + auto shape = ToShape(shape_tensor); auto storage_obj = ReadRegister(instr.alloc_tensor_reg.storage); auto storage = Downcast(storage_obj); - auto obj = storage->AllocNDArray(0, shape, instr.alloc_tensor_reg.dtype); + auto offset = LoadScalarInt(instr.alloc_tensor_reg.offset); + auto obj = storage->AllocNDArray(offset, shape, instr.alloc_tensor_reg.dtype); WriteRegister(instr.dst, obj); pc_++; diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index
614041401026..2e61b4c62c73 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -64,6 +64,7 @@ def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data) mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset) + with relay.build_config(opt_level=1): graph, lib, params = relay.build(mod, target, @@ -1667,7 +1668,7 @@ def verify_prelu(x_shape, a_shape): onnx_out = get_onnxruntime_output(model, [indata, slopedata]) for target, ctx in [('llvm', tvm.cpu())]: - tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), + tvm_out = get_tvm_output(model, [indata, slopedata], target, ctx, list(x_shape), output_dtype='float32') tvm.testing.assert_allclose(onnx_out[0], tvm_out, rtol=1e-05, atol=1e-05) @@ -2572,7 +2573,7 @@ def verify_topk(input_dims, K, axis=-1): inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list(input_dims)), helper.make_tensor_value_info("K", TensorProto.INT64, [1,])], initializer=[helper.make_tensor("K", TensorProto.INT64, [1], [K])], - outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), + outputs=[helper.make_tensor_value_info("Values", TensorProto.FLOAT, output_dims), helper.make_tensor_value_info("Indicies", TensorProto.INT64, output_dims)]) model = helper.make_model(graph, producer_name='topk_test') @@ -2581,10 +2582,10 @@ def verify_topk(input_dims, K, axis=-1): onnx_out = get_onnxruntime_output(model, [indata, k]) for target, ctx in [('llvm', tvm.cpu())]: - tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], + tvm_out = get_tvm_output(model, indata, target, ctx, [output_dims, output_dims], output_dtype=['float32', 'int64']) tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05) - + for n in [12, 32]: for shape in [[n], [n, n], [n, n, n]]: for k in [1, 5, 10]: @@ -2593,7 +2594,7 @@ def verify_topk(input_dims, K, axis=-1): verify_topk([n, n, n], 5, 0) verify_topk([n, n, n], 5, 1) verify_topk([n, n, n], 5, 2) - + def test_roi_align(): def verify_roi_align(input_dims, num_roi, output_height, output_width, sampling_ratio=0, spatial_scale=1.0): diff --git a/tests/python/relay/test_pass_memory_alloc.py b/tests/python/relay/test_memory_passes.py similarity index 62% rename from tests/python/relay/test_pass_memory_alloc.py rename to tests/python/relay/test_memory_passes.py index c3c53121d934..70e7086cef4d 100644 --- a/tests/python/relay/test_pass_memory_alloc.py +++ b/tests/python/relay/test_memory_passes.py @@ -18,21 +18,38 @@ from tvm import te import numpy as np from tvm import relay -from tvm.relay.transform import memory_alloc +from tvm.relay import memory_alloc -def check_vm_alloc(func, check_fn): - mod = tvm.IRModule() - mod['main'] = func - ex = relay.create_executor('vm', mod) +def check_memory_plan(func, check_fn): + # Build Module + mod = tvm.IRModule().from_expr(func) + + # Convert arguments. args = [] for param in func.params: param = param.type_annotation sh = [int(sh) for sh in param.shape] data = np.random.rand(*sh).astype(param.dtype) args.append(tvm.nd.array(data)) - result = ex.evaluate(mod['main'])(*args) + + # Compute without memory planning. + ex = relay.create_executor('vm', mod) + no_plan_result = ex.evaluate(mod['main'])(*args) + + # Compute with memory planning. 
+ with relay.build_config(opt_level=1, disabled_pass=["MemoryPlan"]): + plan_result = ex.evaluate(mod['main'])(*args) + + # Compute Python result. py_res = check_fn(*[arg.asnumpy() for arg in args]) - np.testing.assert_allclose(result.asnumpy(), py_res) + + # First check that the two VM results agree. + np.testing.assert_allclose( + no_plan_result.asnumpy(), + plan_result.asnumpy()) + + # Finally check that the results match the Python result. + np.testing.assert_allclose(plan_result.asnumpy(), py_res) def storage_type(mod): return relay.TypeCall(mod.get_global_type_var("Storage"), []) @@ -46,7 +63,7 @@ def test_tyck_alloc_tensor(): mod.import_from_std("core.rly") sto = relay.Var("x", storage_type(mod)) sh = relay.const(np.array([1, 2]), dtype="int64") - at = relay.op.memory.alloc_tensor(sto, sh) + at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) mod['main'] = relay.Function([sto], at) relay.transform.InferType()(mod) @@ -58,20 +75,34 @@ def test_add(): x = relay.var('x', shape=(2,)) z = x + x func = relay.Function([x,], z) - check_vm_alloc(func, check_add) + check_memory_plan(func, check_add) def check_add_sub(x, y): z = x + x return z - y + def test_add_sub(): x = relay.var('x', shape=(10,)) y = relay.var('y', shape=(10,)) z = x + x z = z - y func = relay.Function([x, y], z) - check_vm_alloc(func, check_add_sub) + check_memory_plan(func, check_add_sub) + +def check_no_fuse(x, y, w): + z = x + y + return np.matmul(z, np.transpose(w)) + +def test_no_fuse(): + x = relay.var('x', shape=(5, 1)) + y = relay.var('y', shape=(5, 1)) + w = relay.var('w', shape=(5, 1)) + z = x + y + out = relay.op.nn.dense(z, w) + func = relay.Function([x, y, w], out) + check_memory_plan(func, check_no_fuse) if __name__ == "__main__": test_tyck_alloc_tensor()