Skip to content

Commit

Permalink
Get of special register number for return value
Browse files Browse the repository at this point in the history
  • Loading branch information
wweic committed Apr 10, 2019
1 parent dfcd401 commit 5c56f6b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
4 changes: 3 additions & 1 deletion include/tvm/relay/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,10 @@ struct VMFrame {

std::vector<Object> register_file;

VirtualRegisterNum caller_return_register;

VMFrame(size_t pc, size_t func_index, size_t args, const Instruction* code, size_t register_file_size)
: pc(pc), func_index(func_index), args(args), code(code), register_file(register_file_size)
: pc(pc), func_index(func_index), args(args), code(code), register_file(register_file_size), caller_return_register(0)
{}
};

Expand Down
8 changes: 3 additions & 5 deletions src/relay/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
VMCompilerContext* context;

VMCompiler(VMCompilerContext* context) :
instructions(), last_register(0), registers_num(1),
instructions(), last_register(0), registers_num(0),
engine(CompileEngine::Global()), context(context) {}

size_t NewRegister() {
Expand Down Expand Up @@ -345,8 +345,6 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
Emit(AllocClosure(it->second, arity, free_var_registers, NewRegister()));
} else {
Emit(Invoke(it->second, args_registers, NewRegister()));
// 0 is return value register
Emit(Move(0, NewRegister()));
}
} else if (auto constructor_node = op.as<ConstructorNode>()) {
auto constructor = GetRef<Constructor>(constructor_node);
Expand Down Expand Up @@ -384,7 +382,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// We first layout the function arguments.
auto inner_func = Downcast<Function>(func->body);

size_t i = 1;
size_t i = 0;
for (auto param : inner_func->params) {
auto arg_register = NewRegister();
CHECK_EQ(i, arg_register);
Expand Down Expand Up @@ -415,7 +413,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {

for (auto i = 0; i < func->params.size(); ++i) {
auto arg_register = NewRegister();
CHECK_EQ(arg_register, i+1);
CHECK_EQ(arg_register, i);
var_register_map.insert({ func->params[i], arg_register });
}

Expand Down
6 changes: 5 additions & 1 deletion src/relay/vm/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector<Obje

PushFrame(func.params, this->pc + 1, func);
for (size_t i = 0; i < args.size(); ++i) {
WriteRegister(i+1, args[i]);
WriteRegister(i, args[i]);
}
RELAY_LOG(INFO) << "func.params= " << func.params << std::endl;

Expand Down Expand Up @@ -482,6 +482,7 @@ void VirtualMachine::Run() {
args.push_back(ReadRegister(instr.invoke_args_registers[i]));
}
InvokeGlobal(this->functions[instr.func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
}
case Opcode::InvokePacked: {
Expand Down Expand Up @@ -510,6 +511,7 @@ void VirtualMachine::Run() {
args.push_back(free_var);
}
InvokeGlobal(this->functions[closure->func_index], args);
frames.back().caller_return_register = instr.dst;
goto main_loop;
}
case Opcode::GetField: {
Expand Down Expand Up @@ -599,6 +601,7 @@ void VirtualMachine::Run() {
// running, we should return to the caller breaking
// the dispatch loop.
return_register = ReadRegister(instr.result);
auto caller_return_register = frames.back().caller_return_register;

if (PopFrame() == frame_start) {
return;
Expand All @@ -607,6 +610,7 @@ void VirtualMachine::Run() {
// Since we have already popped the stack we will just
// resume at the top of the dispatch loop.
} else {
WriteRegister(caller_return_register, return_register);
goto main_loop;
}
}
Expand Down

0 comments on commit 5c56f6b

Please sign in to comment.