Skip to content

Commit

Permalink
[PASS] Schedule Ops init working version (#6)
Browse files Browse the repository at this point in the history
* [PASS] Schedule Ops init working version

* bugfix in PassUp
  • Loading branch information
tqchen authored and icemelon committed Jan 10, 2017
1 parent 302c2e6 commit 78ea652
Show file tree
Hide file tree
Showing 33 changed files with 499 additions and 222 deletions.
2 changes: 1 addition & 1 deletion HalideIR
Submodule HalideIR updated from 5d1bd1 to 1ec478
3 changes: 3 additions & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace tvm {

using Halide::Type;
using Halide::Float;
using Halide::Bool;
using Halide::Int;
using Halide::UInt;
using Halide::Handle;
Expand All @@ -29,6 +30,8 @@ using Halide::Internal::Stmt;
using Halide::Internal::IRPrinter;
using Halide::Internal::Variable;

using Halide::Internal::make_const;

/*! \brief a named variable in TVM */
class Var : public Halide::VarExpr {
public:
Expand Down
18 changes: 10 additions & 8 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@
namespace tvm {
namespace ir {


/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \param dom_map The domain of each iter vars.
* \return the result Stmt
*/
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);

/*!
* \brief verifies whether the IR stmt or Expr is in SSA form.
* That is: each VarExpr is defined and assigned once(in Let/For)
Expand Down Expand Up @@ -51,14 +61,6 @@ Stmt Inline(FunctionRef f,
Expr body,
Stmt stmt);

/*!
* \brief Schedule s' dependent operations.
*
* \param s The schedule to be realized
* \return the result Stmt
*/
Stmt ScheduelOps(Schedule s);

} // namespace ir
} // namespace tvm

Expand Down
43 changes: 41 additions & 2 deletions include/tvm/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@

namespace tvm {

/*!
* \brief A placeholder op represents an input placeholder.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<Expr> shape;
/*! \brief The data type of the input. */
Type dtype;

int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;

void VisitAttrs(AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name,
Array<Expr> shape,
Type dtype);

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_NODE_TYPE_INFO(PlaceholderOpNode);
};

/*!
* \brief A Compute op that compute a tensor on certain domain.
*/
Expand All @@ -24,11 +54,10 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}

size_t num_outputs() const final {
int num_outputs() const final {
return 1;
}
Array<IterVar> root_iter_vars() const final;
std::string output_name(size_t i) const final;
Type output_dtype(size_t i) const final;
Array<Expr> output_shape(size_t i) const final;

Expand All @@ -49,6 +78,16 @@ class ComputeOpNode : public OperationNode {
/*! \brief The compute function to specify the input source of a Tensor */
using FCompute = std::function<Expr (const Array<Var>& i)>;

/*!
* \brief create a place holder tensor.
* \param shape The shape of the tensor.
* \param dtype the data type of the tensor.
* \param name The name of the Tensor.
*/
Tensor Placeholder(Array<Expr> shape,
Type dtype = Float(32),
std::string name = "placeholder");

/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
Expand Down
20 changes: 11 additions & 9 deletions src/schedule/bound.h → include/tvm/schedule_pass.h
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
/*!
* Copyright (c) 2016 by Contributors
* \file bound.h
* \brief The bound inference logics on the schedule.
* \file schedule_pass.h
* \brief Collection of Schedule pass functions.
*
* These passes works on the schedule hyper-graph
* and infers information such as bounds, check conditions
* read/write dependencies between the IterVar
*/
#ifndef TVM_SCHEDULE_BOUND_H_
#define TVM_SCHEDULE_BOUND_H_
#ifndef TVM_SCHEDULE_PASS_H_
#define TVM_SCHEDULE_PASS_H_

#include <tvm/expr.h>
#include <tvm/schedule.h>
#include <unordered_map>
#include "./base.h"
#include "./schedule.h"

namespace tvm {
namespace schedule {
Expand All @@ -23,5 +26,4 @@ Map<IterVar, Range> InferBound(Schedule sch);

} // namespace schedule
} // namespace tvm

#endif // TVM_SCHEDULE_BOUND_H_
#endif // TVM_SCHEDULE_PASS_H_
41 changes: 12 additions & 29 deletions include/tvm/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,11 @@ using Halide::IR::FunctionRef;
* \brief Tensor structure representing a possible input,
* or intermediate computation result.
*/
class Tensor : public FunctionRef {
class Tensor : public NodeRef {
public:
/*! \brief default constructor, used internally */
Tensor() {}
explicit Tensor(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief constructor of input tensor
* \param shape Shape of the tensor.
* \param name optional name of the Tensor.
* \param dtype The data type of the input tensor.
*/
explicit Tensor(Array<Expr> shape,
std::string name = "tensor",
Type dtype = Float(32));
explicit Tensor(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand Down Expand Up @@ -116,11 +107,11 @@ class Tensor : public FunctionRef {
};

/*! \brief Operation that produces tensors */
class Operation : public NodeRef {
class Operation : public FunctionRef {
public:
/*! \brief default constructor */
Operation() {}
explicit Operation(std::shared_ptr<Node> n) : NodeRef(n) {}
explicit Operation(std::shared_ptr<Node> n) : FunctionRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
Expand All @@ -137,12 +128,10 @@ class Operation : public NodeRef {
};

/*! \brief Node to represent a tensor */
class TensorNode : public FunctionBaseNode {
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief optional name of the tensor */
std::string name;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
Expand All @@ -154,19 +143,11 @@ class TensorNode : public FunctionBaseNode {

void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
const std::string& func_name() const final {
return name;
}
int outputs() const final {
return 1;
}
static Tensor make(Array<Expr> shape,
std::string name,
Type dtype,
Operation op,
int value_index);
Expand All @@ -178,16 +159,18 @@ class TensorNode : public FunctionBaseNode {
/*!
* \brief base class of operation node.
*/
class OperationNode : public Node {
class OperationNode : public FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \return name of the operation */
const std::string& func_name() const final {
return name;
}
/*! \return number of outputs of this op */
virtual int num_outputs() const = 0;
/*! \return the list of iteration variable at root */
virtual Array<IterVar> root_iter_vars() const = 0;
/*! \return number of outputs of this op */
virtual size_t num_outputs() const = 0;
/*! \return name of i-th output */
virtual std::string output_name(size_t i) const = 0;
/*! \return type of i-th output */
virtual Type output_dtype(size_t i) const = 0;
/*! \return shape of i-th output */
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def Var(name="tindex", dtype=int32):
return _function_internal._Var(name, dtype)


def placeholder(shape, dtype = None, name="TensorObj"):
def placeholder(shape, dtype = None, name="placeholder"):
"""Construct an empty tensor object.
Parameters
Expand All @@ -53,11 +53,11 @@ def placeholder(shape, dtype = None, name="TensorObj"):
The created tensor
"""
dtype = float32 if dtype is None else dtype
return _function_internal._Tensor(
shape, name, dtype, None, 0)
return _function_internal._Placeholder(
shape, dtype, name)


def compute(shape, fcompute, name="TensorCompute"):
def compute(shape, fcompute, name="compute"):
"""Construct a new tensor by computing over the shape domain.
The compute rule is result[axis] = fcompute(axis)
Expand Down
8 changes: 7 additions & 1 deletion python/tvm/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def __call__(self, *indices):
else:
raise ValueError("The indices must be expression")

return _make.Call(self.dtype, self.name, args, _expr.Call.Halide, self, 0)
return _make.Call(self.dtype, self.op.name,
args, _expr.Call.Halide,
self.op, self.value_index)

def __getitem__(self, indices):
return TensorSlice(self, indices)
Expand Down Expand Up @@ -71,3 +73,7 @@ def output(self, index):
@register_node
class ComputeOp(Operation):
pass

@register_node
class PlaceholderOp(Operation):
pass
14 changes: 12 additions & 2 deletions src/c_api/c_api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ TVM_REGISTER_API(_make_For)
args.at(5));
});

TVM_REGISTER_API(_make_Realize)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Realize::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4),
args.at(5));
});


TVM_REGISTER_API(_make_Call)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Call::make(args.at(0),
Expand Down Expand Up @@ -113,9 +124,8 @@ REGISTER_MAKE3(LetStmt);
REGISTER_MAKE2(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE3(Store);
REGISTER_MAKE3(Provide);
REGISTER_MAKE4(Provide);
REGISTER_MAKE1(Free);
// TODO(tqchen) Realize;
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
Expand Down
8 changes: 7 additions & 1 deletion src/c_api/c_api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ TVM_REGISTER_API(Range)
TVM_REGISTER_API(_Tensor)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = TensorNode::make(args.at(0),
args.at(1),
args.at(2),
args.at(3),
args.at(4));
Expand All @@ -160,6 +159,13 @@ TVM_REGISTER_API(_TensorHash)
std::hash<Tensor>()(args.at(0).operator Tensor()));
});

TVM_REGISTER_API(_Placeholder)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = Placeholder(args.at(0),
args.at(1),
args.at(2));
});

TVM_REGISTER_API(_ComputeOp)
.set_body([](const ArgStack& args, RetValue *ret) {
*ret = ComputeOpNode::make(args.at(0),
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"

namespace tvm {
namespace ir {
Expand Down Expand Up @@ -36,6 +35,7 @@ using RetValue = APIVariantValue;
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS4(Inline);
REGISTER_PASS2(ScheduleOps);

} // namespace ir
} // namespace tvm
2 changes: 1 addition & 1 deletion src/c_api/c_api_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/schedule.h>
#include <tvm/schedule_pass.h>
#include "./c_api_registry.h"
#include "../schedule/bound.h"
#include "../schedule/graph.h"

namespace tvm {
Expand Down
1 change: 0 additions & 1 deletion src/lang/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,4 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_REGISTER_NODE_TYPE(IterVarNode);


} // namespace tvm
3 changes: 2 additions & 1 deletion src/lang/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)

TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) {
p->stream << "attr " << op->type_key << " = ";
p->do_indent();
p->stream << "// attr " << op->type_key << " = ";
p->print(op->value);
p->stream << '\n';
p->print(op->body);
Expand Down
Loading

0 comments on commit 78ea652

Please sign in to comment.