diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 1d2fa5472993f..18ec944f54f52 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -184,6 +184,26 @@ class VarNode : public ExprNode {
 RELAY_DEFINE_NODE_REF(Var, VarNode, Expr);
 
+/*! \brief Hash a Var by its id.
+ * Different VarNodes might have the same vid, and in that case they are considered the same var.
+ * Use VarHash to hash a Var by id.
+ */
+struct VarHash {
+  size_t operator()(const Var& v) const {
+    return v->vid.hash();
+  }
+};
+
+/*! \brief Compare Vars by their id.
+ * Different VarNodes might have the same vid, and in that case they are considered the same var.
+ * Use VarEqual to compare Vars by id.
+ */
+struct VarEqual {
+  bool operator()(const Var& l, const Var& r) const {
+    return l->vid.get() == r->vid.get();
+  }
+};
+
 /*!
  * \brief Global variable that lives in the top-level module.
  * This is used to enable recursive calls between functions.
@@ -521,7 +541,7 @@ RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);
  * rewriting pass such as layout or type transformation.
  *
  * Subclass TempExprNode allows us to pattern match on
- * specific kind TempExpr and use them for expression rewriting.
+ * specific kind of TempExpr and use them for expression rewriting.
  *
  * TempExpr should only be used within a pass,
  */
@@ -539,6 +559,25 @@ class TempExprNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(TempExpr, TempExprNode, Expr);
 
+class Annotate;
+class AnnotateNode : public ExprNode {
+ public:
+  Expr expr;
+  NodeRef annotation;
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("expr", &expr);
+    v->Visit("annotation", &annotation);
+    v->Visit("_checked_type_", &checked_type_);
+  }
+
+  TVM_DLL static Annotate make(Expr expr, NodeRef annotation);
+
+  static constexpr const char* _type_key = "relay.AnnotateNode";
+  TVM_DECLARE_NODE_TYPE_INFO(AnnotateNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(Annotate, AnnotateNode, Expr);
+
 // implementations
 inline const Type& ExprNode::checked_type() const {
   CHECK(checked_type_.defined()) << "internal error: the type checker has "
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 3b179f8e53300..d3154d28bb272 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -116,6 +116,7 @@ class ExprFunctor {
   virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const ConstructorNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const MatchNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const AnnotateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...)
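  /* A minimal sketch (illustrative only; StripAnnotate is not part of this patch)
   * of how the new AnnotateNode hook can be consumed: an ExprMutator subclass can
   * override the AnnotateNode case to drop the annotation and keep the wrapped
   * expression.
   *
   *   struct StripAnnotate : ExprMutator {
   *     Expr VisitExpr_(const AnnotateNode* op) final {
   *       return VisitExpr(op->expr);  // discard op->annotation
   *     }
   *   };
   */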
{ throw Error(std::string("Do not have a default for ") + op->type_key()); } @@ -140,6 +141,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); RELAY_EXPR_FUNCTOR_DISPATCH(ConstructorNode); RELAY_EXPR_FUNCTOR_DISPATCH(MatchNode); + RELAY_EXPR_FUNCTOR_DISPATCH(AnnotateNode); return vtable; } }; @@ -170,6 +172,7 @@ class ExprVisitor void VisitExpr_(const RefWriteNode* op) override; void VisitExpr_(const ConstructorNode* op) override; void VisitExpr_(const MatchNode* op) override; + void VisitExpr_(const AnnotateNode* op) override; virtual void VisitType(const Type& t); virtual void VisitClause(const Clause& c); virtual void VisitPattern(const Pattern& c); @@ -212,6 +215,7 @@ class ExprMutator Expr VisitExpr_(const RefWriteNode* op) override; Expr VisitExpr_(const ConstructorNode* op) override; Expr VisitExpr_(const MatchNode* op) override; + Expr VisitExpr_(const AnnotateNode* op) override; /*! * \brief Used to visit the types inside of expressions. diff --git a/python/tvm/relay/network.py b/python/tvm/relay/network.py new file mode 100644 index 0000000000000..cc3ca35b9b934 --- /dev/null +++ b/python/tvm/relay/network.py @@ -0,0 +1,159 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay import op +from tvm.relay import create_executor, Module +from tvm.relay.backend.interpreter import TensorValue +from tvm.relay.prelude import Prelude +import aot +import collections + +class OrderedSet(collections.MutableSet): + + def __init__(self, iterable=None): + self.end = end = [] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] + if iterable is not None: + self |= iterable + + def __len__(self): + return len(self.map) + + def __contains__(self, key): + return key in self.map + + def add(self, key): + if key not in self.map: + end = self.end + curr = end[1] + curr[2] = end[1] = self.map[key] = [key, curr, end] + + def discard(self, key): + if key in self.map: + key, prev, next = self.map.pop(key) + prev[2] = next + next[1] = prev + + def __iter__(self): + end = self.end + curr = end[2] + while curr is not end: + yield curr[0] + curr = curr[2] + + def __reversed__(self): + end = self.end + curr = end[1] + while curr is not end: + yield curr[0] + curr = curr[1] + + def pop(self): + key = self.last() + self.discard(key) + return key + + def last(self): + return self.end[1][0] + + def __repr__(self): + if not self: + return '%s()' % (self.__class__.__name__,) + return '%s(%r)' % (self.__class__.__name__, list(self)) + + def __eq__(self, other): + if isinstance(other, OrderedSet): + return len(self) == len(other) and list(self) == list(other) + return set(self) == set(other) + +def initialize(param): + ty = param.type_annotation + shape = [int(i) for i in ty.shape] + return np.random.normal(0, 1, shape).astype('float32') + +def copy_var(v): + return relay.Var(v.name_hint, v.type_annotation) + +class Network: + stack = [] + cnt = 0 + + def __init__(self, *, name="f", **kwargs): + name = f"{name}_{Network.cnt}" + Network.cnt += 1 + if len(Network.stack) is not 0: + mod = Network.stack[-1].mod + p = Network.stack[-1].p + else: + mod = Module() + p = Prelude(mod) + + self.mod = mod + self.p = p + self.inputs = [] + self.weights = OrderedSet() + self.sub_network = OrderedSet() + self.f = relay.GlobalVar(name) + self.recurse = relay.Var("recurse") + self.use_recurse = False + self.ret_type = None + body = self.build(**kwargs) + assert isinstance(body, relay.Expr) + if self.use_recurse: + inputs = [copy_var(v) for v 
in self.inputs] + body = relay.Let(self.recurse, relay.Function(inputs, self.call_from_outside(*inputs)), body) + self.mod[self.f] = relay.Function(self.inputs + self.all_weights(), body, self.ret_type) + + def build(self, **kwargs): + Network.stack.append(self) + try: + return self.build_impl(**kwargs) + finally: + Network.stack.pop() + + def build_impl(self, *args): + raise NotImplementedError + + def weight(self, w): + assert isinstance(w, relay.Var) + self.weights.add(w) + return w + + def input(self, i): + assert isinstance(i, relay.Var) + self.inputs.append(i) + return i + + def all_weights(self): + return list(set(list(self.weights) + [w for n in self.sub_network for w in n.all_weights()])) + + def call_from_outside(self, *inputs): + return self.f(*(list(inputs) + self.all_weights())) + + def __call__(self, *inputs): + if self in Network.stack: + self.use_recurse = True + return self.recurse(*inputs) + else: + assert len(Network.stack) > 0 + assert Network.stack[-1].mod == self.mod + assert Network.stack[-1].p == self.p + Network.stack[-1].sub_network.add(self) + return self.call_from_outside(*inputs) + + def interface_type(self): + t = relay.ir_pass.infer_type(self.mod[self.f], mod=self.mod).checked_type + return relay.FuncType(t.arg_types[:len(self.inputs)], t.ret_type, t.type_params, t.type_constraints) + + def get(self): + weights = [] + for x in self.all_weights(): + ty = x.type_annotation + assert isinstance(ty, relay.TensorType) + assert ty.dtype == 'float32' + shape = [int(i) for i in ty.shape] + weight = relay.const(np.random.normal(0, 1, shape).astype('float32')) + weights.append(weight) + inputs = [copy_var(v) for v in self.inputs] + return relay.Function(inputs, self.f(*inputs, *weights)) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 5a47b1d42ed31..f1e3942743dbb 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -74,7 +74,7 @@ def schedule_batch_matmul(attrs, outputs, target): with target: return topi.generic.schedule_batch_matmul(outputs) -reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE) +reg.register_pattern("nn.batch_matmul", reg.OpPattern.OPAQUE) # conv2d diff --git a/python/tvm/relay/test_network.py b/python/tvm/relay/test_network.py new file mode 100644 index 0000000000000..b84895503b784 --- /dev/null +++ b/python/tvm/relay/test_network.py @@ -0,0 +1,93 @@ +from .network import Network +from tvm import relay +from tvm.relay import op, var, Var, Function, Clause, PatternConstructor, PatternVar, Match +from tvm.relay import TupleGetItem, Tuple, TensorType, TupleType + +class Linear(Network): + def build_impl(self, input_size, output_size, dtype="float32"): + x = self.input(var("linear_input", shape=(1, input_size), dtype=dtype)) + w = self.weight(var("linear_weight", shape=(output_size, input_size), dtype=dtype)) + b = self.weight(var("linear_bias", shape=(output_size,), dtype=dtype)) + return op.add(op.nn.dense(x, w), b) + +def lam(names, func): + args = [Var(name) for name in names] + return Function(args, func(*args)) + +class LSTMCell(Network): + def build_impl(self, input_size, memory_size, dtype="float32"): + t = TensorType(shape=(1, memory_size), dtype=dtype) + i = self.input(var("lstmcell_input", shape=(1, input_size), dtype=dtype)) + c = self.input(Var("lstmcell_children", self.p.l(TupleType([t, t])))) + sum = lam(["x", "y"], lambda x, y: x + y) + child_h_sum = self.p.foldl(sum, + op.zeros(shape=(1, memory_size), dtype=dtype), + self.p.map(lam(["z"], lambda z: 
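                                            # take each child's hidden state h (element 1 of the (c, h) pair produced by LSTMCell)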
TupleGetItem(z, 1)), c)) + ioux = Linear(input_size=input_size, output_size=memory_size * 3)(i) + iouh = Linear(input_size=memory_size, output_size=memory_size * 3)(child_h_sum) + iou = ioux + iouh + fx = Linear(input_size=input_size, output_size=memory_size)(i) + fh = Linear(input_size=memory_size, output_size=memory_size) + i, o, u = op.split(iou, 3, axis=1) + i, o, u = op.sigmoid(i), op.sigmoid(o), op.tanh(u) + def foreach_children(children): + f = op.sigmoid(fh(TupleGetItem(children, 1)) + fx) + return f * TupleGetItem(children, 0) + c = self.p.foldl(sum, i * u, self.p.map(lam(["z"], foreach_children), c)) + return Tuple([c, o * op.tanh(c)]) + +class LSTMEncoder(Network): + def build_impl(self, input_size, memory_size, dtype="float32"): + l = self.input(Var("l", self.p.l(TensorType(shape=(1, input_size), dtype=dtype)))) + cell = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype) + return self.p.foldl(lam(["c", "x"], lambda c, x: cell(x, self.p.cons(c, self.p.nil()))), + Tuple([op.zeros(shape=(1, memory_size), dtype=dtype), + op.zeros(shape=(1, memory_size), dtype=dtype)]), l) + +class LSTMTransformer(Network): + def build_impl(self, input_size, memory_size, dtype="float32"): + l = self.input(Var("l", self.p.l(TensorType(shape=(1, input_size), dtype=dtype)))) + def f(c, x): + cell = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype) + o = cell(x, self.p.cons(c, self.p.nil())) + return Tuple([o, TupleGetItem(o, 1)]) + res = self.p.map_accuml(lam(["c", "x"], f), + Tuple([op.zeros(shape=(1, memory_size), dtype=dtype), + op.zeros(shape=(1, memory_size), dtype=dtype)]), + l) + return Tuple([TupleGetItem(TupleGetItem(res, 0), 1), TupleGetItem(res, 1)]) + +class TreeLSTM(Network): + def build_impl(self, input_size, memory_size, dtype="float32"): + t = TensorType(shape=(1, memory_size), dtype=dtype) + self.ret_type = TupleType([t, t]) + tree_type = self.p.tree(TensorType(shape=(1, input_size), dtype=dtype)) + t = self.input(Var("tlstm_input", tree_type)) + i = Var("i", TensorType(shape=(1, input_size), dtype=dtype)) + c = Var("c", self.p.l(tree_type)) + cell = LSTMCell(input_size=input_size, memory_size=memory_size, dtype=dtype) + rose_case = Clause(PatternConstructor(self.p.rose, [PatternVar(i), PatternVar(c)]), + cell(i, self.p.map(lam(["x"], self), c))) + return Match(t, [rose_case]) + +class BiLSTM(Network): + def build_impl(self, input_size, memory_size, dtype="float32"): + l = self.input(Var("l", self.p.l(TensorType(shape=(1, input_size), dtype=dtype)))) + def LSTM(l): + return LSTMTransformer(input_size=input_size, + memory_size=memory_size, + dtype=dtype)(l) + fwd = LSTM(l) + rev = LSTM(self.p.rev(l)) + lhs = op.concatenate([TupleGetItem(fwd, 0), TupleGetItem(rev, 0)], axis=1) + t = TensorType(shape=(1, memory_size), dtype=dtype) + x = Var("x", TupleType([t, t])) # cannot infer here + rhs = self.p.map(Function([x], op.concatenate([TupleGetItem(x, 0), + TupleGetItem(x, 1)], + axis=1)), + self.p.zip(TupleGetItem(fwd, 1), TupleGetItem(rev, 1))) + return Tuple([lhs, rhs]) + +# t = BiLSTM(input_size=128, memory_size=256) +# print("type of BidirectionalLSTM, with input_size=128, memory_size=256, is:") +# print(t.interface_type()) diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3108bc2501fed..422163758a2f2 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -232,8 +232,7 @@ TVM_REGISTER_API("relay._make.Call") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const CallNode* node, tvm::IRPrinter* p) { - p->stream 
<< "CallNode(" << node->op << ", " << node->args << ", " - << node->attrs << ", " << node->type_args << ")"; + p->stream << "CallNode(" << node->op << ")"; }); Let LetNode::make(Var var, Expr value, Expr body) { @@ -349,5 +348,17 @@ TVM_REGISTER_API("relay._expr.TempExprRealize") *ret = temp->Realize(); }); +Annotate AnnotateNode::make(Expr expr, NodeRef annotation) { + NodePtr n = make_node(); + n->expr = std::move(expr); + n->annotation = std::move(annotation); + return Annotate(n); +} + +TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) +.set_dispatch([](const AnnotateNode* node, tvm::IRPrinter* p) { + p->stream << "AnnotateNode(" << node->expr << ")"; + }); + } // namespace relay } // namespace tvm diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index d0cd30adda29f..aaaf34d261a17 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -221,6 +221,10 @@ Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; } Type ExprMutator::VisitType(const Type& t) { return t; } +Expr ExprMutator::VisitExpr_(const AnnotateNode* op) { + return AnnotateNode::make(VisitExpr(op->expr), op->annotation); +} + void ExprVisitor::VisitExpr(const Expr& expr) { auto it = visit_counter_.find(expr.get()); if (it != visit_counter_.end()) { @@ -315,6 +319,10 @@ void ExprVisitor::VisitExpr_(const MatchNode* op) { } } +void ExprVisitor::VisitExpr_(const AnnotateNode* op) { + this->VisitExpr(op->expr); +} + void ExprVisitor::VisitClause(const Clause& op) { this->VisitPattern(op->lhs); this->VisitExpr(op->rhs); diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 969f08b32e830..2b043fc7f4189 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -585,6 +585,24 @@ class PrettyPrinter : return doc << "), " << PrintDType(node->dtype) << "]"; } + Doc VisitType_(const GlobalTypeVarNode* node) final { + Doc doc; + doc << node->var->name_hint; + return doc; + } + + Doc VisitType_(const TypeCallNode* node) final { + Doc doc = PrintType(node->func, false); + std::vector args; + for (const Type& t : node->args) { + args.push_back(PrintType(t, false)); + } + doc << "("; + doc << PrintVec(args); + doc << ")"; + return doc; + } + Doc VisitType_(const TupleTypeNode* node) final { std::vector fields; for (Type field : node->fields) { diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index fb0d919b46c38..718bad63693b1 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -113,6 +113,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) }); TypeCall TypeCallNode::make(Type func, tvm::Array args) { + CHECK(func.as()); NodePtr n = make_node(); n->func = std::move(func); n->args = std::move(args); diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index d24431347f808..c2f24e2179bef 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -683,7 +683,7 @@ bool BatchMatmulRel(const Array& types, const auto* x = types[0].as(); const auto* y = types[1].as(); if (x == nullptr || y == nullptr) return false; - if (x->shape.size() != 3 || y->shape.size() != 3) return false; + CHECK (x->shape.size() == 3 && y->shape.size() == 3); CHECK(reporter->AssertEQ(x->shape[0], y->shape[0])) << "BatchDot: batch dimension doesn't match, " << " x shape=" << x->shape diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index f86156bdbddcd..0f83f2cf194f2 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -729,9 +729,19 @@ bool TakeRel(const Array& 
types,
   // `types` contains: [data, indices, result]
   CHECK_EQ(types.size(), 3);
   const auto* data = types[0].as<TensorTypeNode>();
-  CHECK(data != nullptr);
+  if (data == nullptr) {
+    CHECK(types[0].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return false;
+  }
+
   const auto* indices = types[1].as<TensorTypeNode>();
-  CHECK(indices != nullptr);
+  if (indices == nullptr) {
+    CHECK(types[1].as<IncompleteTypeNode>())
+        << "must be tensor type or incomplete type";
+    return true;
+  }
+
   const auto param = attrs.as<TakeAttrs>();
   CHECK(param != nullptr);
diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc
index 06cd9091749bc..091e7cf836033 100644
--- a/src/relay/pass/dead_code.cc
+++ b/src/relay/pass/dead_code.cc
@@ -1,22 +1,3 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
 /*!
  * Copyright (c) 2018 by Contributors
  *
@@ -27,28 +8,278 @@
  * The algorithm is implemented by two visitors:
  * CalcDep turns an expr into a dependency graph of exprs,
  * GenLet turns the dependency graph into a let list, taking only the used values.
+ *
+ * Also, the Dead Code Eliminator has to take effects into account:
+ * a call to a foreign function should not be eliminated,
+ * and a write to a reference should not be eliminated if that reference is used.
+ *
+ * To do this we implement a simple escape analysis.
+ * We abstract a reference value as the set of StoreIds it may point to.
+ * Each RefCreate gets a unique StoreId,
+ * and each parameter is also assigned a unique StoreId (as it might be, or contain, a Ref).
+ * We then create a map of Expr -> Set StoreId, which records which StoreIds an Expr might depend on.
+ * The map is updated until a fixpoint (it terminates since there are finitely many StoreIds).
+ * The StoreIds reachable from the inputs and the body are the StoreIds that are alive,
+ * and effects on all other StoreIds can be removed.
+ *
+ * We choose to implement StoreId as Expr for simplicity.
+ *
+ * Whenever a function is called, or a reference is written into,
+ * we make the set of references involved depend on that call/write.
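 *
 * For example, in `let r = ref_create(0); ref_write(r, 1); ref_read(r)`, the
 * StoreId of r is reachable from the result (through the read), so the write
 * has a live effect and must be kept. If the final read were absent, r's
 * StoreId would not be live, and the write (and then r itself) could be dropped.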
*/ #include #include +#include #include "let_list.h" namespace tvm { namespace relay { +using Sid = Expr; +using SidSet = std::unordered_set; +using ExprSet = std::unordered_set; +template +using ExprMap = std::unordered_map; +template +using VarMap = std::unordered_map; +using VarSet = std::unordered_set; + +struct EscapeAnalysis : ExprFunctor, + PatternFunctor { + ExprMap map_; + SidSet live_sid_; + ExprSet root_expr_; + bool converge = false; + bool HasEffect(const Expr& e) { + struct EffectVisitor : ExprVisitor { + EscapeAnalysis* ea_; + explicit EffectVisitor(EscapeAnalysis* ea) : ea_(ea) { } + bool has_effect = false; + void Touch(const Expr& e) { + for (const Sid& s: ea_->Get(e)) { + has_effect |= (ea_->live_sid_.count(s) > 0); + if (ea_->live_sid_.count(s) > 0) { + std::cout << "HAS EFFECT" << std::endl; + return; + } + } + } + void VisitExpr_(const RefReadNode* op) final { + Touch(op->ref); + VisitExpr(op->ref); + } + void VisitExpr_(const RefWriteNode* op) final { + Touch(op->ref); + VisitExpr(op->ref); + VisitExpr(op->value); + } + void VisitExpr_(const CallNode* op) final { + // The args contain same sid as op, so no need to touch them. + Touch(op->op); + VisitExpr(op->op); + for (const Expr& arg: op->args) { + VisitExpr(arg); + } + } + void VisitExpr_(const FunctionNode* op) final { } + }; + std::cout << "CHECK EFFECT:" << e << std::endl; + EffectVisitor ev(this); + ev(e); + return ev.has_effect; + } + explicit EscapeAnalysis(const Expr& e) { + for (const Var& v: FreeVars(e)) { + AllocRoot(v); + } + while (!converge) { + converge = true; + Analysis(e); + } + Alive(e); + for (const Expr& r: root_expr_) { + Alive(r); + } + while (!converge) { + converge = true; + std::vector live_sid_old; + for (const Sid& s: live_sid_) { + live_sid_old.push_back(s); + } + for (const Sid& s: live_sid_old) { + Alive(s); + } + } + } + void Alive(const Expr& e) { + for (const Sid& s: Get(e)) { + if (live_sid_.count(s) == 0) { + converge = false; + live_sid_.insert(s); + } + } + } + void Analysis(const Expr& e) { + VisitExpr(e, e); + } + ExprSet& Get(const Expr& e) { + if (map_.count(e) == 0) { + map_.insert({e, ExprSet()}); + } + return map_.at(e); + } + std::vector Range(const Expr& e) { + std::vector ret; + for (const auto& x: Get(e)) { + ret.push_back(x); + } + return ret; + } + void Insert(const Expr& from, const Expr& to) { + ExprSet& x = Get(from); + if (x.count(to) == 0) { + converge = false; + x.insert(to); + } + } + void Join(const Expr& from, const Expr& to) { + for (const Expr& e: Range(to)) { + Insert(from, e); + } + } + void Write(const Expr& from, const Expr& to) { + for (const Expr& e: Range(from)) { + Join(e, to); + } + } + void Alloc(const Expr& e) { + Insert(e, e); + } + void Root(const Expr& e) { + root_expr_.insert(e); + } + void AllocRoot(const Expr& e) { + Alloc(e); + Root(e); + } + void Depend(const Expr& val, const Expr& on) { + Analysis(on); + Join(val, on); + } + void VisitExpr_(const RefCreateNode* op, const Expr& e) final { + AllocRoot(e); + Depend(e, op->value); + } + void VisitExpr_(const RefWriteNode* op, const Expr& e) final { + Write(e, op->ref); + Analysis(op->ref); + Analysis(op->value); + } + void VisitExpr_(const FunctionNode* op, const Expr& e) final { + for (const Var& v: op->params) { + AllocRoot(v); + } + Root(op->body); + Depend(e, op->body); + } + void VisitExpr_(const CallNode* op, const Expr& e) final { + std::vector exprs; + Depend(e, op->op); + exprs.push_back(op->op); + for (const Expr& arg: op->args) { + Depend(e, arg); + exprs.push_back(arg); + } + 
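      // Conservative aliasing across the unknown callee: anything reachable from
      // one operand may be stored into a reference reachable from any other
      // operand, so the operands' StoreId sets are joined pairwise in both
      // directions below.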
for (size_t i = 0; i < exprs.size(); ++i) { + for (size_t j = i + 1; j < exprs.size(); ++j) { + Write(exprs[i], exprs[j]); + Write(exprs[j], exprs[i]); + } + } + } + void RecordVar(const Var& v) { + Get(v); + } + void VisitExpr_(const LetNode* op, const Expr& e) final { + RecordVar(op->var); + Depend(op->var, op->value); + Depend(e, op->body); + } + // From here on the uninteresting case: just declare Depend on children + void VisitExpr_(const VarNode* op, const Expr& e) final { + CHECK_GT(map_.count(GetRef(op)), 0); + } + void VisitExpr_(const ConstructorNode* op, const Expr& e) final { } + void VisitExpr_(const OpNode* op, const Expr& e) final { + // TODO(@M.K.): handle stateful op + } + void VisitExpr_(const ConstantNode* op, const Expr& e) final { } + void VisitExpr_(const GlobalVarNode* op, const Expr& e) final { } + void VisitExpr_(const MatchNode* op, const Expr& e) final { + Depend(e, op->data); + for (const Clause& c: op->clauses) { + VisitPattern(c->lhs, op->data); + Depend(e, c->rhs); + } + } + void VisitPattern_(const PatternWildcardNode* op, const Expr& e) final { } + void VisitPattern_(const PatternVarNode* op, const Expr& e) final { + Depend(op->var, e); + } + void VisitPattern_(const PatternConstructorNode* op, const Expr& e) final { + for (const Pattern& pat: op->patterns) { + VisitPattern(pat, e); + } + } + void VisitExpr_(const RefReadNode* op, const Expr& e) final { + Depend(e, op->ref); + } + void VisitExpr_(const TupleNode* op, const Expr& e) final { + for (const Expr& c: op->fields) { + Depend(e, c); + } + } + void VisitExpr_(const TupleGetItemNode* op, const Expr& e) final { + Depend(e, op->tuple); + } + void VisitExpr_(const IfNode* op, const Expr& e) final { + Depend(e, op->cond); + Depend(e, op->true_branch); + Depend(e, op->false_branch); + } +}; + + // calculate the dependency graph from expression class CalcDep : private ExprVisitor { public: - static Expr Eliminate(const Expr& e) { - CalcDep cd; - cd.Calculate(e); - Eliminator el(cd.expr_map_, cd.use_map_, cd.letrec_set_); - return el(e); + explicit CalcDep(const Expr& v) { + VisitExpr(v); + return; + count_ = false; + while (!dead_worklist_.empty()) { + Var dead = *(dead_worklist_.begin()); + dead_worklist_.erase(dead); + CHECK_EQ(use_map_[dead], 0); + if (expr_map_.count(dead) > 0) { + LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead); + } + } + } + + bool Used(const Var& v) { + return use_map_[v] > 0; + } + + bool HasLet(const Var& v) { + return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); + } + + Expr Map(const Var& v) { + return expr_map_.count(v) == 0 ? 
Expr(v) : expr_map_[v]; } private: - template - using VarMap = std::unordered_map; - using VarSet = std::unordered_set; VarMap expr_map_; VarMap use_map_; VarSet letrec_set_; @@ -98,53 +329,43 @@ class CalcDep : private ExprVisitor { letrec_set_.insert(var); } } +}; - void Calculate(const Expr& v) { - VisitExpr(v); - count_ = false; - while (!dead_worklist_.empty()) { - Var dead = *(dead_worklist_.begin()); - dead_worklist_.erase(dead); - CHECK_EQ(use_map_[dead], 0); - if (expr_map_.count(dead) > 0) { - LetRec([&]() { VisitExpr(expr_map_[dead]); }, dead); - } - } +class Eliminator : private ExprMutator { + public: + static Expr Eliminate(const Expr& e) { + Eliminator elm(e); + return elm(e); } + private: + EscapeAnalysis ea_; + CalcDep cd_; + explicit Eliminator(const Expr& e) : ea_(e), cd_(e) { } + friend CalcDep; - class Eliminator : private ExprMutator { - private: - VarMap expr_map_; - VarMap use_map_; - VarSet letrec_set_; - explicit Eliminator(const VarMap& expr_map, - const VarMap& use_map, - const VarSet& letrec_set) : - expr_map_(expr_map), use_map_(use_map), letrec_set_(letrec_set) { } - friend CalcDep; - - bool HasLet(const Var& v) { - return (use_map_[v] > 1 || (use_map_[v] != 0 && letrec_set_.count(v) != 0)); - } - - Expr VisitExpr_(const VarNode* op) final { - Var v = GetRef(op); - return (expr_map_.count(v) == 0 || HasLet(v)) ? v : VisitExpr(expr_map_[v]); - } + Expr VisitExpr_(const VarNode* op) final { + Var v = GetRef(op); + return v; + std::cout << v << " map to " << cd_.Map(v) << std::endl; + return (cd_.Used(v) || ea_.HasEffect(cd_.Map(v))) ? v : cd_.Map(v); + } - Expr VisitExpr_(const LetNode* op) final { - Var v = op->var; - if (HasLet(v)) { - return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); - } else { - return VisitExpr(op->body); - } + Expr VisitExpr_(const LetNode* op) final { + Var v = op->var; + CHECK_EQ(cd_.Map(v), op->value); + if (cd_.Used(v) || ea_.HasEffect(op->value)) { + return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body)); + } else { + return VisitExpr(op->body); } - }; + } + //Expr VisitExpr_(const IfNode* op) final { + // return IfNode::make(op->cond, Descend(op->true_branch), Descend(op->false_branch)); + //} }; Expr DeadCodeElimination(const Expr& e) { - return CalcDep::Eliminate(e); + return Eliminator::Eliminate(e); } TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") diff --git a/src/relay/pass/partial_eval.cc b/src/relay/pass/partial_eval.cc index f6283d380176a..63ad11d2e1182 100644 --- a/src/relay/pass/partial_eval.cc +++ b/src/relay/pass/partial_eval.cc @@ -1,22 +1,3 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - /*! 
* Copyright (c) 2018 by Contributors * @@ -104,6 +85,7 @@ #include #include #include +#include "../ir/type_functor.h" #include "pass_util.h" #include "let_list.h" @@ -112,26 +94,7 @@ namespace relay { using namespace runtime; -/*! \brief Hash Var by it's id. - * Different VarNode might has same vid, and they are considered to be the same var in such case. - * Use VarHash to hash Var by id. - */ -struct VarHash { - size_t operator()(const Var& v) const { - return v->vid.hash(); - } -}; - -/*! \brief Compare Var by it's id. - * Different VarNode might has same vid, and they are considered to be the same var in such case. - * Use VarEqual to compare Var by id. - */ -struct VarEqual { - bool operator()(const Var& l, const Var& r) const { - return l->vid.get() == r->vid.get(); - } -}; - +Expr PostProcess(const Expr&); /*! \brief The base container type of Relay values. */ class StaticNode : public RelayNode { public: @@ -150,10 +113,20 @@ class Static : public NodeRef { using ContainerType = StaticNode; }; +using Time = size_t; + struct PStaticNode : Node { + static Time time() { + static Time time_ = 0; + Time ret = time_; + time_++; + return ret; + } Static pstatic; // may be null Expr dynamic; - PStaticNode(const Static& pstatic, const Expr& dynamic) : pstatic(pstatic), dynamic(dynamic) { } + Time created_time; + PStaticNode(const Static& pstatic, const Expr& dynamic) : + pstatic(pstatic), dynamic(dynamic), created_time(time()) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); }; @@ -261,7 +234,7 @@ class Environment { } ++rit; } - LOG(FATAL) << "Unknown Variable: " << v; + LOG(FATAL) << "Unknown Variable: " << v << v.as(); throw; } @@ -341,6 +314,7 @@ class Store { }; PStatic HasStatic(const Static& stat, const Expr& dynamic) { + CHECK(stat.defined()); return PStatic(make_node(stat, dynamic)); } @@ -383,15 +357,61 @@ FInterpreter CPUInterpreter() { return CreateInterpreter(Module(nullptr), CPUContext(), target); } -class PartialEvaluator : public ExprFunctor, +bool IsAtomic(const Expr& e) { + return e.as() || e.as() || e.as() || e.as(); +} + +using FuncId = size_t; + +struct WithFuncId; + +struct WithFuncIdNode : Node { + FuncId fid; + WithFuncIdNode(FuncId fid) : fid(fid) { } + static constexpr const char* _type_key = "relay.WithFuncId"; + TVM_DECLARE_NODE_TYPE_INFO(WithFuncIdNode, Node); +}; + +RELAY_DEFINE_NODE_REF(WithFuncId, WithFuncIdNode, NodeRef); + +Annotate MkWithFuncId(const Expr& expr, FuncId fid) { + return AnnotateNode::make(expr, WithFuncId(make_node(fid))); +} + +Expr StripWithFuncId(const Expr& e); + +Expr DeDup(const Expr& e); + +Function AsFunc(const Expr& e) { + if (e.as()) { + return Downcast(e); + } else if (const AnnotateNode* a = e.as()) { + CHECK(a->annotation.as()); + return AsFunc(a->expr); + } else { + LOG(FATAL) << "Unknown case"; + throw; + } +} + +class PartialEvaluator : public ExprFunctor, public PatternFunctor { public: - PartialEvaluator(const tvm::Array& free_vars) { + PartialEvaluator(const tvm::Array& free_vars, + const Module& mod) : + mod_(mod) { for (const Var& v : free_vars) { env_.Insert(v, NoStatic(v)); } } + size_t depth = 0; + PStatic VisitExpr(const Expr& e, LetList* ll) final { + PStatic ret = ExprFunctor::VisitExpr(e, ll); + CHECK(IsAtomic(ret->dynamic)) << ret->dynamic; + return ret; + } + PStatic VisitExpr_(const ConstantNode* op, LetList* ll) final { return HasStatic(MkSTensor(op->data.CopyTo(context_)), ll->Push(GetRef(op))); } @@ -421,7 +441,20 @@ class 
PartialEvaluator : public ExprFunctor } PStatic VisitExpr_(const GlobalVarNode* op, LetList* ll) final { - return NoStatic(GetRef(op)); + GlobalVar gv = GetRef(op); + if (gv_map_.count(gv) == 0) { + if (mod_.defined()) { + Function func = mod_->Lookup(gv); + InitializeFuncId(func); + Func f = VisitFuncStatic(func, gv); + gv_map_.insert({gv, HasStatic(MkSFunc(f), gv)}); + func = AsFunc(PostProcess(VisitFuncDynamic(func, f))); + mod_->Update(gv, func); + } else { + gv_map_.insert({gv, NoStatic(gv)}); + } + } + return gv_map_.at(gv); } PStatic VisitExpr_(const LetNode* op, LetList* ll) final { @@ -501,19 +534,45 @@ class PartialEvaluator : public ExprFunctor } } - PStatic VisitExpr_(const FunctionNode* op, LetList* ll) final { - Function func = GetRef(op); + PStatic VisitExpr_(const AnnotateNode* op, LetList* ll) final { + CHECK(op->annotation.as()); + return VisitExpr(op->expr, ll); + } + + struct TimeFrame { + PartialEvaluator* pe_; + FuncId fid_; + std::vector
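/* A minimal usage sketch of the annotation helpers above (the concrete FuncId
 * value is hypothetical): the partial evaluator wraps a Function in an Annotate
 * carrying a WithFuncId, and AsFunc strips the wrapper again.
 *
 *   FuncId fid = 0;                              // hypothetical id
 *   Annotate tagged = MkWithFuncId(func, fid);   // Annotate(func, WithFuncId(fid))
 *   Function original = AsFunc(tagged);          // unwraps back to the Function
 */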