From 3c1020dffb94e8ee4e076fa69374ad9df1d339ae Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 20 Jan 2017 12:41:52 -0800 Subject: [PATCH] [CODEGEN] Add CodeGenC (#22) --- HalideIR | 2 +- python/tvm/__init__.py | 1 + python/tvm/_ctypes/_api.py | 1 + python/tvm/codegen.py | 1 + src/base/common.h | 2 +- src/c_api/c_api_codegen.cc | 25 ++ src/codegen/codegen_c.cc | 483 ++++++++++++++++++++++++++++++ src/codegen/codegen_c.h | 140 +++++++++ tests/python/test_codegen_cuda.py | 8 + 9 files changed, 661 insertions(+), 2 deletions(-) create mode 100644 python/tvm/codegen.py create mode 100644 src/c_api/c_api_codegen.cc create mode 100644 src/codegen/codegen_c.cc create mode 100644 src/codegen/codegen_c.h diff --git a/HalideIR b/HalideIR index b6637f611f91..adfa66240265 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit b6637f611f91dd075dc251438f72ad38901d17fb +Subproject commit adfa662402650e2f9b02ea600ffb70d6e7bb5adf diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index b3a376de3d9f..91b5abb6cc4e 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -8,6 +8,7 @@ from . import stmt from . import make from . import ir_pass +from . import codegen from . import collections from . import schedule diff --git a/python/tvm/_ctypes/_api.py b/python/tvm/_ctypes/_api.py index 3ad5aa330468..4b9d9d4932db 100644 --- a/python/tvm/_ctypes/_api.py +++ b/python/tvm/_ctypes/_api.py @@ -281,6 +281,7 @@ def _init_function_module(root_namespace): namespace_match = { "_make_": sys.modules["%s.make" % root_namespace], "_pass_": sys.modules["%s.ir_pass" % root_namespace], + "_codegen_": sys.modules["%s.codegen" % root_namespace], "_schedule_": sys.modules["%s.schedule" % root_namespace] } diff --git a/python/tvm/codegen.py b/python/tvm/codegen.py new file mode 100644 index 000000000000..02dda155c19a --- /dev/null +++ b/python/tvm/codegen.py @@ -0,0 +1 @@ +"""Code generation related functions""" diff --git a/src/base/common.h b/src/base/common.h index 0485bdfc4af0..ea2f4bdad9e5 100644 --- a/src/base/common.h +++ b/src/base/common.h @@ -30,7 +30,7 @@ inline Type String2Type(std::string s) { } else if (s.substr(0, 5) == "float") { code = Type::Float; s = s.substr(5); } else if (s == "handle") { - return Type(Type::Handle, 0, 0); + return Type(Type::Handle, 32, 1); } else { LOG(FATAL) << "unknown type " << s; } diff --git a/src/c_api/c_api_codegen.cc b/src/c_api/c_api_codegen.cc new file mode 100644 index 000000000000..365033ea445f --- /dev/null +++ b/src/c_api/c_api_codegen.cc @@ -0,0 +1,25 @@ +/*! + * Copyright (c) 2016 by Contributors + * Implementation of API functions related to IR build + * \file c_api_ir.cc + */ +#include +#include + +#include "./c_api_registry.h" +#include "../codegen/codegen_c.h" + +namespace tvm { +namespace codegen { + +using ArgStack = const std::vector; +using RetValue = APIVariantValue; + +TVM_REGISTER_API(_codegen_CompileToC) +.set_body([](const ArgStack& args, RetValue *ret) { + *ret = CodeGenC().Compile( + args.at(0), args.at(1), args.at(2), args.at(3)); + }); + +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/codegen_c.cc b/src/codegen/codegen_c.cc new file mode 100644 index 000000000000..a42569e9ad32 --- /dev/null +++ b/src/codegen/codegen_c.cc @@ -0,0 +1,483 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file codegen_c.cc + */ +#include "./codegen_c.h" + +namespace tvm { +namespace codegen { + +using namespace ir; + +std::string CodeGenC::Compile( + Stmt stmt, std::string fun_name, + Array args, bool output_ssa) { + print_ssa_form_ = output_ssa; + // skip the first underscore, so SSA variable starts from _1 + if (print_ssa_form_) GetUniqueName("_"); + + this->indent += 2; + this->stream << "void " << fun_name << "("; + for (size_t i = 0; i < args.size(); ++i) { + Var v = args[i]; + std::string vid = AllocVarID(v.get()); + if (i != 0) stream << ", "; + PrintType(v.type(), stream); + stream << ' ' << vid; + } + stream << ") {\n"; + this->PrintStmt(stmt); + this->indent -= 2; + this->PrintIndent(); + this->stream << "}\n"; + return stream.str(); +} + +void CodeGenC::PrintStmt(const Stmt& n) { + static const FPrintStmt& f = vtable_print_stmt(); + f(n, this); +} + +std::string CodeGenC::SSAGetID(std::string src, Type t) { + if (name_alloc_map_.count(src)) return src; + auto it = ssa_assign_map_.find(src); + if (it != ssa_assign_map_.end()) { + return it->second; + } else { + this->PrintIndent(); + std::string id = GetUniqueName("_"); + ssa_assign_map_[src] = id; + if (src.length() > 3 && + src[0] == '(' && src[src.length() - 1] == ')') { + src = src.substr(1, src.length() - 2); + } + PrintType(t, stream); + stream << ' ' << id << " = " << src << ";\n"; + return id; + } +} + +void CodeGenC::PrintExpr(const Expr& n, std::ostream& os) { // NOLINT(*) + static const FPrintExpr& f = vtable_print_expr(); + if (print_ssa_form_) { + std::ostringstream temp; + f(n, temp, this); + os << SSAGetID(temp.str(), n.type()); + } else { + f(n, os, this); + } +} + +std::string CodeGenC::GetUniqueName(std::string prefix) { + auto it = name_alloc_map_.find(prefix); + if (it != name_alloc_map_.end()) { + while (true) { + std::ostringstream os; + os << prefix << (++it->second); + std::string name = os.str(); + if (name_alloc_map_.count(name) == 0) { + prefix = name; + break; + } + } + } + name_alloc_map_[prefix] = 0; + return prefix; +} + +std::string CodeGenC::AllocVarID(const Variable* v) { + CHECK(!var_idmap_.count(v)) + << "Need input to be in SSA form dup " << v->name_hint; + std::string key = v->name_hint; + for (size_t i = 0; i < key.size(); ++i) { + if (key[i] == '.') key[i] = '_'; + } + std::string vid = GetUniqueName(key); + var_idmap_[v] = vid; + return vid; +} + +std::string CodeGenC::GetVarID(const Variable* v) const { + auto it = var_idmap_.find(v); + CHECK(it != var_idmap_.end()) + << "Find undefined Variable " << v->name_hint; + return it->second; +} + +bool CodeGenC::BufferTypeMatch(const Variable* buf_var, Type t) const { + auto it = alloc_buf_type_.find(buf_var); + if (it == alloc_buf_type_.end()) return false; + return it->second == t; +} + +void CodeGenC::PrintIndent() { + for (int i = 0; i < this->indent; ++i) { + this->stream << ' '; + } +} + +void CodeGenC::MarkConst(std::string vid) { + if (print_ssa_form_) { + auto it = ssa_assign_map_.find(vid); + if (it == ssa_assign_map_.end()) { + ssa_assign_map_[vid] = vid; + } else { + CHECK_EQ(it->second, vid); + } + } +} + +void CodeGenC::PrintType(Type t, std::ostream& os) const { // NOLINT(*) + CHECK_EQ(t.lanes(), 1) + << "do not yet support vector types"; + if (t.is_handle()) { + os << "void*"; return; + } + if (t.is_float()) { + if (t.bits() == 32) { + os << "float"; return; + } + if (t.bits() == 64) { + os << "double"; return; + } + } else if (t.is_uint()) { + switch (t.bits()) { + case 8: case 16: case 32: case 64: { + os << "uint" << t.bits() << "_t"; return; + } + case 1: os << "int"; return; + } + } else if (t.is_int()) { + switch (t.bits()) { + case 8: case 16: case 32: case 64: { + os << "int" << t.bits() << "_t"; return; + } + } + } + LOG(FATAL) << "Cannot convert type " << t << " to C type"; +} + +CodeGenC::FPrintStmt& CodeGenC::vtable_print_stmt() { // NOLINT(*) + static FPrintStmt inst; return inst; +} + +CodeGenC::FPrintExpr& CodeGenC::vtable_print_expr() { // NOLINT(*) + static FPrintExpr inst; return inst; +} + +inline void PrintConst(const IntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (op->type == Int(32)) { + std::ostringstream temp; + temp << op->value; + p->MarkConst(temp.str()); + os << temp.str(); + } else { + os << "("; + p->PrintType(op->type, os); + os << ")" << op->value; + } +} + +inline void PrintConst(const UIntImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + if (op->type == UInt(32)) { + std::ostringstream temp; + temp << op->value << "U"; + p->MarkConst(temp.str()); + os << temp.str(); + } else { + os << "("; + p->PrintType(op->type, os); + os << ")" << op->value; + } +} + +inline void PrintConst(const FloatImm* op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + switch (op->type.bits()) { + case 64: case 32: { + std::ostringstream temp; + temp << op->value; + if (op->type.bits() == 32) temp << 'f'; + p->MarkConst(temp.str()); + os << temp.str(); + break; + } + case 16: { + os << '('; + p->PrintType(op->type, os); + os << ')' << op->value << 'f'; + break; + } + default: LOG(FATAL) << "Bad bit-width for float: " << op->type << "\n"; + } +} + +TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) +.set_dispatch([](const IntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) + PrintConst(op, os, p); + }) +.set_dispatch([](const UIntImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) + PrintConst(op, os, p); + }) +.set_dispatch([](const FloatImm *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) + PrintConst(op, os, p); + }); + +template +inline void PrintBinaryExpr(const T* op, + const char *opstr, + std::ostream& os, // NOLINT(*) + CodeGenC* p) { + os << '('; + p->PrintExpr(op->a, os); + os << opstr; + p->PrintExpr(op->b, os); + os << ')'; +} + +TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) +.set_dispatch([](const Cast *op, std::ostream& os, CodeGenC *p) { // NOLINT(*) + p->PrintType(op->type, os); + os << '('; + p->PrintExpr(op->value, os); + os << ')'; + }) +.set_dispatch([](const Variable *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + os << p->GetVarID(op); + }) +.set_dispatch([](const Add *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " + ", os, p); + }) +.set_dispatch([](const Sub *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " - ", os, p); + }) +.set_dispatch([](const Mul *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " * ", os, p); + }) +.set_dispatch
([](const Div *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " / ", os, p); + }) +.set_dispatch([](const Mod *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " % ", os, p); +}) +.set_dispatch([](const Min *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + os << "min("; + p->PrintExpr(op->a, os); + os << ", "; + p->PrintExpr(op->b, os); + os << ")"; +}) +.set_dispatch([](const Max *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + os << "max("; + p->PrintExpr(op->a, os); + os << ", "; + p->PrintExpr(op->b, os); + os << ")"; +}) +.set_dispatch([](const EQ *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " == ", os, p); +}) +.set_dispatch([](const NE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " != ", os, p); +}) +.set_dispatch([](const LT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " < ", os, p); +}) +.set_dispatch([](const LE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " <= ", os, p); +}) +.set_dispatch([](const GT *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " > ", os, p); +}) +.set_dispatch([](const GE *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " >= ", os, p); +}) +.set_dispatch([](const And *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " && ", os, p); +}) +.set_dispatch([](const Or *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + PrintBinaryExpr(op, " || ", os, p); +}) +.set_dispatch([](const Not *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + os << '!'; + p->PrintExpr(op->a, os); + }) +.set_dispatch([](const Call *op, std::ostream& os, CodeGenC* p) { // NOLINT(*) + os << op->name << "("; + for (size_t i = 0; i < op->args.size(); i++) { + p->PrintExpr(op->args[i], os); + if (i < op->args.size() - 1) { + os << ", "; + } + } + os << ")"; + }); + +TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) +.set_dispatch([](const AssertStmt *op, CodeGenC* p) { + std::string cond = p->PrintExpr(op->condition); + p->PrintIndent(); + p->stream << "assert(" << cond << ");\n"; + }) +.set_dispatch([](const ProducerConsumer *op, CodeGenC* p) { + p->PrintStmt(op->body); + }) +.set_dispatch([](const For *op, CodeGenC* p) { + std::string extent = p->PrintExpr(op->extent); + p->PrintIndent(); + std::string vid = p->AllocVarID(op->loop_var.get()); + CHECK(is_zero(op->min)); + p->stream << "for ("; + p->PrintType(op->loop_var.type(), p->stream); + p->stream << ' ' << vid << " = 0; " + << vid << " < " << extent + << "; ++" << vid << ") {\n"; + p->indent += 2; + p->PrintStmt(op->body); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }) +.set_dispatch([](const Block *op, CodeGenC* p) { + p->PrintStmt(op->first); + if (op->rest.defined()) p->PrintStmt(op->rest); + }) +.set_dispatch([](const Evaluate *op, CodeGenC* p) { + if (is_const(op->value)) return; + std::string vid = p->PrintExpr(op->value); + p->PrintIndent(); + p->stream << "(void)" << vid << ";\n"; + }) +.set_dispatch([](const IfThenElse *op, CodeGenC* p) { + std::string cond = p->PrintExpr(op->condition); + p->PrintIndent(); + p->stream << "if (" << cond << ") {\n"; + p->indent += 2; + p->PrintStmt(op->then_case); + p->indent -= 2; + if (op->else_case.defined()) { + p->PrintIndent(); + p->stream << "} else {\n"; + p->indent += 2; + p->PrintStmt(op->else_case); + p->indent -= 2; + } + p->PrintIndent(); + p->stream << "}\n"; +}); + + +#define DISPATCH_EXPR(OP) \ + set_dispatch([](const OP *op, std::ostream&os, CodeGenC* p) { \ + p->PrintExpr(op, os); }) + +TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_expr) +.DISPATCH_EXPR(Load) +.DISPATCH_EXPR(Let) +.DISPATCH_EXPR(Ramp) +.DISPATCH_EXPR(Broadcast) +.DISPATCH_EXPR(Select); + +void CodeGenC::PrintExpr(const Load* op, std::ostream& os) { // NOLINT(*) + std::string vid = GetVarID(op->buffer_var.get()); + if (!BufferTypeMatch(op->buffer_var.get(), op->type)) { + os << "((const "; + PrintType(op->type, os); + os << "*)" << vid << ')'; + } else { + os << vid; + } + os << '['; + PrintExpr(op->index, os); + os << ']'; +} + +void CodeGenC::PrintExpr(const Let* op, std::ostream& os) { // NOLINT(*) + CHECK(print_ssa_form_) + << "LetExpr is only supported by print SSA form"; + std::string value = PrintExpr(op->value); + CHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; +} + +void CodeGenC::PrintExpr(const Ramp* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "not supported "; +} + +void CodeGenC::PrintExpr(const Broadcast* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "not supported "; +} + +void CodeGenC::PrintExpr(const Select* op, std::ostream& os) { // NOLINT(*) + LOG(FATAL) << "not supported "; +} + +// Disoatch back to member functions +TVM_STATIC_IR_FUNCTOR(CodeGenC, vtable_print_stmt) +.set_dispatch([](const LetStmt *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const Store *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const Allocate *op, CodeGenC* p) { p->PrintStmt(op); }) +.set_dispatch([](const AttrStmt *op, CodeGenC* p) { p->PrintStmt(op); }); + + +void CodeGenC::PrintStmt(const LetStmt* op) { + std::string value = PrintExpr(op->value); + if (print_ssa_form_) { + CHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + PrintType(op->var.type(), this->stream); + this->stream << ' ' + << AllocVarID(op->var.get()) + << " = " << value << ";\n"; + } + PrintStmt(op->body); +} + +void CodeGenC::PrintStmt(const Store* op) { + std::string index = this->PrintExpr(op->index); + std::string value = this->PrintExpr(op->value); + this->PrintIndent(); + std::string vid = GetVarID(op->buffer_var.get()); + if (!BufferTypeMatch(op->buffer_var.get(), op->value.type())) { + this->stream << "(("; + PrintType(op->value.type(), this->stream); + this->stream << "*)" << vid << ')'; + } else { + this->stream << vid; + } + this->stream << '[' << index + << "] = " << value + << ";\n"; +} + +void CodeGenC::PrintStmt(const Allocate* op) { + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + std::string vid = AllocVarID(op->buffer_var.get()); + CHECK(!op->new_expr.defined()); + CHECK(!is_zero(op->condition)); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + PrintType(op->type, stream); + stream << ' '<< vid << '[' + << constant_size << "]\n;"; + this->PrintStmt(op->body); +} + +void CodeGenC::PrintStmt(const AttrStmt* op) { + if (op->type_key == "scope") { + IterVar iv(op->node.node_); + if (iv->thread_tag.length() != 0) { + this->PrintIndent(); + PrintType(iv->var.type(), stream); + stream << ' ' + << AllocVarID(iv->var.get()) + << " = " << iv->thread_tag << ";\n"; + } + } + this->PrintStmt(op->body); +} + +} // namespace codegen +} // namespace tvm diff --git a/src/codegen/codegen_c.h b/src/codegen/codegen_c.h new file mode 100644 index 000000000000..a8ce1828e4b4 --- /dev/null +++ b/src/codegen/codegen_c.h @@ -0,0 +1,140 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file codegen_c.h + * \brief Common utilities to generated C style code. + */ +#ifndef TVM_CODEGEN_CODEGEN_C_H_ +#define TVM_CODEGEN_CODEGEN_C_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace codegen { + +/*! + * \brief A base class to generate C code. + * + * CodeGenC have two modes: generate SSA formed C code or normal form. + */ +class CodeGenC { + public: + /*! + * \brief Generate the C code of statement + * \param body The body of the function. + * \param fun_name The name of the function. + * \param args The arguments to the function. + * \param output_ssa Whether output ssa form. + * \note Only call compile once, + * create a new codegen object each time. + */ + std::string Compile(Stmt body, + std::string fun_name, + Array args, + bool output_ssa); + /*! + * \brief Print the Stmt n to CodeGenC->stream + * \param n The statement to be printed. + */ + void PrintStmt(const Stmt& n); + /*! + * \brief Print the expression n(or its ssa id if in ssa mode) into os + * \param n The expression to be printed. + * \param os The output stream + */ + void PrintExpr(const Expr& n, std::ostream& os); // NOLINT(*) + /*! + * \brief Same as PrintExpr, but simply returns result string + * \param n The expression to be printed. + */ + inline std::string PrintExpr(const Expr& n) { + std::ostringstream os; + PrintExpr(n, os); + return os.str(); + } + /*! \brief print the current indented value */ + void PrintIndent(); + /*! + * \brief Register constant value appeared in expresion tree + * This avoid generated a ssa id for each appearance of the value + * \param value The constant value. + */ + void MarkConst(std::string value); + /*! + * \brief Allocate a variable name for a newly defined var. + * \param v The variable. + * \return the variable name. + */ + std::string AllocVarID(const Variable* v); + /*! + * \brief Get a variable name. + * \param v The variable. + * \return the variable name. + */ + std::string GetVarID(const Variable* v) const; + /*! + * Print Type represetnation of type t. + * \param t The type representation. + * \return os The stream to print the ctype into + */ + virtual void PrintType(Type t, std::ostream& os) const; // NOLINT(*) + // The following parts are overloadable print operations. + virtual void PrintStmt(const ir::LetStmt* op); + virtual void PrintStmt(const ir::Store* op); + virtual void PrintStmt(const ir::Allocate* op); + virtual void PrintStmt(const ir::AttrStmt* op); + virtual void PrintExpr(const ir::Load* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Let* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Ramp* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Broadcast* op, std::ostream& os); // NOLINT(*) + virtual void PrintExpr(const ir::Select* op, std::ostream& os); // NOLINT(*) + /*! \brief function print into the ostream */ + using FPrintExpr = IRFunctor; // NOLINT(*) + /*! \brief function to to print normal code */ + using FPrintStmt = IRFunctor; + // vtable to print code + static FPrintStmt& vtable_print_stmt(); + // vtable to print code + static FPrintExpr& vtable_print_expr(); + /*! \brief The current indentation value */ + int indent{0}; + /*! \brief the stream to be printed */ + std::ostringstream stream; + + private: + /*! + * \brief Get the SSA ID corresponds to src + * If necessary, generate new assignment + * \param src The source expression + * \param t The type of the expression. + */ + std::string SSAGetID(std::string src, Type t); + /*! + * \brief If buffer is allocated as type t. + * \param buf_var The buffer variable. + * \param t The type to be checked. + */ + bool BufferTypeMatch(const Variable* buf_var, Type t) const; + /*! + * \brief get a unique name with the corresponding prefix + * \param prefix The prefix of the name + * \return The returned name. + */ + std::string GetUniqueName(std::string prefix); + /*! \brief whether to print in SSA form */ + bool print_ssa_form_{true}; + /*! \brief name of each variable */ + std::unordered_map var_idmap_; + /*! \brief the data type of allocated buffers */ + std::unordered_map alloc_buf_type_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + /*! \brief assignment map of ssa */ + std::unordered_map ssa_assign_map_; +}; + +} // namespace codegen +} // namespace tvm +#endif // TVM_CODEGEN_CODEGEN_C_H_ diff --git a/tests/python/test_codegen_cuda.py b/tests/python/test_codegen_cuda.py index b93e80e52059..0f0a8df30506 100644 --- a/tests/python/test_codegen_cuda.py +++ b/tests/python/test_codegen_cuda.py @@ -19,10 +19,18 @@ def mock_test_add(): bounds = tvm.schedule.InferBound(s) stmt = tvm.ir_pass.ScheduleOps(s, bounds) + Ab = tvm.Buffer(A.shape, A.dtype, name='A') Bb = tvm.Buffer(B.shape, B.dtype, name='B') Cb = tvm.Buffer(C.shape, C.dtype, name='C') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb}) + print(stmt) + output_ssa = False + code = tvm.codegen.CompileToC(stmt, "myadd", + [Ab.ptr, Bb.ptr, Cb.ptr, n], + output_ssa) + print(code) def codegen(): # generate host/device code host_code, device_code = tvm.codegen.GenCUDA(