Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WAVM support. #47

Merged
merged 8 commits into from
Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/wasm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) {

if (started_from_ != Cloneable::InstantiatedModule) {
registerCallbacks();
wasm_vm_->link(vm_id_);
if (!wasm_vm_->link(vm_id_)) {
return false;
}
}

vm_context_.reset(createVmContext());
Expand Down
167 changes: 106 additions & 61 deletions src/wavm/wavm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

#include "include/proxy-wasm/wavm.h"

#include <cstdlib>
#include <iostream>
#include <map>
#include <memory>
#include <optional>
#include <unordered_map>
Expand Down Expand Up @@ -43,8 +45,16 @@
#include "WAVM/Runtime/Runtime.h"
#include "WAVM/WASM/WASM.h"
#include "WAVM/WASTParse/WASTParse.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/match.h"

#ifdef NDEBUG
#define ASSERT(_x) _x
#else
#define ASSERT(_x) \
do { \
if (!_x) \
::exit(1); \
} while (0)
#endif

using namespace WAVM;
using namespace WAVM::IR;
Expand Down Expand Up @@ -74,15 +84,21 @@ struct Wavm;

namespace {

#define CALL_WITH_CONTEXT(_x, _context) \
#define CALL_WITH_CONTEXT(_x, _context, _wavm) \
do { \
SaveRestoreContext _saved_context(static_cast<ContextBase *>(_context)); \
WAVM::Runtime::catchRuntimeExceptions([&] { _x; }, \
[&](WAVM::Runtime::Exception *exception) { \
auto description = describeException(exception); \
destroyException(exception); \
throw WasmException(description); \
}); \
try { \
SaveRestoreContext _saved_context(static_cast<ContextBase *>(_context)); \
WAVM::Runtime::catchRuntimeExceptions( \
[&] { _x; }, \
[&](WAVM::Runtime::Exception *exception) { \
auto description = describeException(exception); \
_wavm->fail(FailState::RuntimeError, \
"Function: " + std::string(function_name) + " failed: " + description); \
destroyException(exception); \
throw std::exception(); \
}); \
} catch (...) { \
} \
} while (0)

struct WasmUntaggedValue : public WAVM::IR::UntaggedValue {
Expand All @@ -96,11 +112,9 @@ struct WasmUntaggedValue : public WAVM::IR::UntaggedValue {
WasmUntaggedValue(F64 inF64) { f64 = inF64; }
};

const Logger::Id wasmId = Logger::Id::wasm;

class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<wasmId> {
class RootResolver : public WAVM::Runtime::Resolver {
public:
RootResolver(WAVM::Runtime::Compartment *, WavmVm *vm) : vm_(vm) {}
RootResolver(WAVM::Runtime::Compartment *, WasmVm *vm) : vm_(vm) {}

virtual ~RootResolver() { module_name_to_instance_map_.clear(); }

Expand All @@ -113,10 +127,12 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
if (isA(out_object, type)) {
return true;
} else {
vm_->error("Failed to load WASM module due to a type mismatch in an import: " +
std::string(module_name) + "." + export_name + " " +
asString(WAVM::Runtime::getExternType(out_object)) +
" but was expecting type: " + asString(type));
vm_->fail(FailState::UnableToInitializeCode,
"Failed to load WASM module due to a type mismatch in an import: " +
std::string(module_name) + "." + export_name + " " +
asString(WAVM::Runtime::getExternType(out_object)) +
" but was expecting type: " + asString(type));
return false;
}
}
}
Expand All @@ -125,19 +141,21 @@ class RootResolver : public WAVM::Runtime::Resolver, public Logger::Loggable<was
return true;
}
}
vm_->error("Failed to load Wasm module due to a missing import: " + std::string(module_name) +
"." + std::string(export_name) + " " + asString(type));
vm_->fail(FailState::MissingFunction,
"Failed to load Wasm module due to a missing import: " + std::string(module_name) +
"." + std::string(export_name) + " " + asString(type));
return false;
}

HashMap<std::string, WAVM::Runtime::ModuleInstance *> &moduleNameToInstanceMap() {
HashMap<std::string, WAVM::Runtime::Instance *> &moduleNameToInstanceMap() {
return module_name_to_instance_map_;
}

void addResolver(WAVM::Runtime::Resolver *r) { resolvers_.push_back(r); }

private:
WavmVm *vm_;
HashMap<std::string, WAVM::Runtime::ModuleInstance *> module_name_to_instance_map_;
WasmVm *vm_;
HashMap<std::string, WAVM::Runtime::Instance *> module_name_to_instance_map_;
std::vector<WAVM::Runtime::Resolver *> resolvers_;
};

Expand Down Expand Up @@ -173,23 +191,24 @@ struct PairHash {
}
};

struct Wavm : public WasmVmBase {
Wavm(Stats::ScopeSharedPtr scope) : WasmVmBase(scope, WasmRuntimeNames::get().Wavm) {}
struct Wavm : public WasmVm {
Wavm() : WasmVm() {}
~Wavm() override;

// WasmVm
std::string_view runtime() override { return WasmRuntimeNames::get().Wavm; }
std::string_view runtime() override { return "wavm"; }
Cloneable cloneable() override { return Cloneable::InstantiatedModule; };
std::unique_ptr<WasmVm> clone() override;
bool load(const std::string &code, bool allow_precompiled) override;
void link(std::string_view debug_name) override;
bool link(std::string_view debug_name) override;
uint64_t getMemorySize() override;
std::optional<std::string_view> getMemory(uint64_t pointer, uint64_t size) override;
bool setMemory(uint64_t pointer, uint64_t size, const void *data) override;
bool getWord(uint64_t pointer, Word *data) override;
bool setWord(uint64_t pointer, Word data) override;
std::string_view getCustomSection(std::string_view name) override;
std::string_view getPrecompiledSectionName() override;
AbiVersion getAbiVersion() override;

#define _GET_FUNCTION(_T) \
void getFunction(std::string_view function_name, _T *f) override { \
Expand All @@ -209,15 +228,16 @@ struct Wavm : public WasmVmBase {
bool has_instantiated_module_ = false;
IR::Module ir_module_;
WAVM::Runtime::ModuleRef module_ = nullptr;
WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance> module_instance_;
WAVM::Runtime::GCPointer<WAVM::Runtime::Instance> module_instance_;
WAVM::Runtime::Memory *memory_;
WAVM::Runtime::GCPointer<WAVM::Runtime::Compartment> compartment_;
WAVM::Runtime::GCPointer<WAVM::Runtime::Context> context_;
node_hash_map<std::string, Intrinsics::Module> intrinsic_modules_;
node_hash_map<std::string, WAVM::Runtime::GCPointer<WAVM::Runtime::ModuleInstance>>
std::map<std::string, Intrinsics::Module> intrinsic_modules_;
std::map<std::string, WAVM::Runtime::GCPointer<WAVM::Runtime::Instance>>
intrinsic_module_instances_;
std::vector<std::unique_ptr<Intrinsics::Function>> envoyFunctions_;
uint8_t *memory_base_ = nullptr;
AbiVersion abi_version_ = AbiVersion::Unknown;
};

Wavm::~Wavm() {
Expand All @@ -232,11 +252,12 @@ Wavm::~Wavm() {
}

std::unique_ptr<WasmVm> Wavm::clone() {
auto wavm = std::make_unique<Wavm>(scope_);
auto wavm = std::make_unique<Wavm>();
wavm->compartment_ = WAVM::Runtime::cloneCompartment(compartment_);
wavm->memory_ = WAVM::Runtime::remapToClonedCompartment(memory_, wavm->compartment_);
wavm->memory_base_ = WAVM::Runtime::getMemoryBaseAddress(wavm->memory_);
wavm->context_ = WAVM::Runtime::createContext(wavm->compartment_);
wavm->abi_version_ = abi_version_;
for (auto &p : intrinsic_module_instances_) {
wavm->intrinsic_module_instances_.emplace(
p.first, WAVM::Runtime::remapToClonedCompartment(p.second, wavm->compartment_));
Expand All @@ -254,7 +275,7 @@ bool Wavm::load(const std::string &code, bool allow_precompiled) {
if (!loadModule(code, ir_module_)) {
return false;
}
// todo check percompiled section is permitted
getAbiVersion(); // Cache ABI version.
const CustomSection *precompiled_object_section = nullptr;
if (allow_precompiled) {
for (const CustomSection &customSection : ir_module_.customSections) {
Expand All @@ -272,21 +293,48 @@ bool Wavm::load(const std::string &code, bool allow_precompiled) {
return true;
}

AbiVersion Wavm::getAbiVersion() { return AbiVersion::Unknown; }
AbiVersion Wavm::getAbiVersion() {
if (abi_version_ != AbiVersion::Unknown) {
return abi_version_;
}
for (auto &e : ir_module_.exports) {
if (e.name == "proxy_abi_version_0_1_0") {
abi_version_ = AbiVersion::ProxyWasm_0_1_0;
return abi_version_;
}
if (e.name == "proxy_abi_version_0_2_0") {
abi_version_ = AbiVersion::ProxyWasm_0_2_0;
return abi_version_;
}
if (e.name == "proxy_abi_version_0_2_1") {
abi_version_ = AbiVersion::ProxyWasm_0_2_1;
return abi_version_;
}
}
jplevyak marked this conversation as resolved.
Show resolved Hide resolved
return AbiVersion::Unknown;
}

void Wavm::link(std::string_view debug_name) {
RootResolver rootResolver(compartment_);
bool Wavm::link(std::string_view debug_name) {
RootResolver rootResolver(compartment_, this);
for (auto &p : intrinsic_modules_) {
auto instance = Intrinsics::instantiateModule(compartment_, {&intrinsic_modules_[p.first]},
std::string(p.first));
intrinsic_module_instances_.emplace(p.first, instance);
rootResolver.moduleNameToInstanceMap().set(p.first, instance);
}
WAVM::Runtime::LinkResult link_result = linkModule(ir_module_, rootResolver);
if (!link_result.missingImports.empty()) {
for (auto &i : link_result.missingImports) {
error("Missing Wasm import " + i.moduleName + " " + i.exportName);
}
fail(FailState::MissingFunction, "Failed to load Wasm module due to a missing import(s)");
return false;
}
module_instance_ = instantiateModule(
compartment_, module_, std::move(link_result.resolvedImports), std::string(debug_name));
memory_ = getDefaultMemory(module_instance_);
memory_base_ = WAVM::Runtime::getMemoryBaseAddress(memory_);
return true;
}

uint64_t Wavm::getMemorySize() { return WAVM::Runtime::getMemoryNumPages(memory_) * WasmPageSize; }
Expand Down Expand Up @@ -326,7 +374,7 @@ bool Wavm::setWord(uint64_t pointer, Word data) {
return setMemory(pointer, sizeof(uint32_t), &data32);
}

std::string_view Wavm::getCustomSection(string_view name) {
std::string_view Wavm::getCustomSection(std::string_view name) {
for (auto &section : ir_module_.customSections) {
if (section.name == name) {
return {reinterpret_cast<char *>(section.data.data()), section.data.size()};
Expand All @@ -337,12 +385,10 @@ std::string_view Wavm::getCustomSection(string_view name) {

std::string_view Wavm::getPrecompiledSectionName() { return "wavm.precompiled_object"; }

std::unique_ptr<WasmVm> createVm(Stats::ScopeSharedPtr scope) {
return std::make_unique<Wavm>(scope);
}

} // namespace Wavm

std::unique_ptr<WasmVm> createWavmVm() { return std::make_unique<proxy_wasm::Wavm::Wavm>(); }

template <typename R, typename... Args>
IR::FunctionType inferEnvoyFunctionType(R (*)(void *, Args...)) {
return IR::FunctionType(IR::inferResultType<R>(), IR::TypeTuple({IR::inferValueType<Args>()...}),
Expand All @@ -354,10 +400,10 @@ using namespace Wavm;
template <typename R, typename... Args>
void registerCallbackWavm(WasmVm *vm, std::string_view module_name, std::string_view function_name,
R (*f)(Args...)) {
auto wavm = static_cast<Wavm *>(vm);
wavm->envoyFunctions_.emplace_back(
new Intrinsics::Function(&wavm->intrinsic_modules_[module_name], function_name.data(),
reinterpret_cast<void *>(f), inferEnvoyFunctionType(f)));
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
wavm->envoyFunctions_.emplace_back(new Intrinsics::Function(
&wavm->intrinsic_modules_[std::string(module_name)], function_name.data(),
reinterpret_cast<void *>(f), inferEnvoyFunctionType(f)));
}

template void registerCallbackWavm<void, void *>(WasmVm *vm, std::string_view module_name,
Expand Down Expand Up @@ -452,7 +498,7 @@ static bool checkFunctionType(WAVM::Runtime::Function *f, IR::FunctionType t) {
template <typename R, typename... Args>
void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
std::function<R(ContextBase *, Args...)> *function, uint32_t) {
auto wavm = static_cast<proxy_wasm::Wavm *>(vm);
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
auto f =
asFunctionNullable(getInstanceExport(wavm->module_instance_, std::string(function_name)));
if (!f)
Expand All @@ -462,18 +508,19 @@ void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
return;
}
if (!checkFunctionType(f, inferStdFunctionType(function))) {
error("Bad function signature for: " + std::string(function_name));
wavm->fail(FailState::UnableToInitializeCode,
"Bad function signature for: " + std::string(function_name));
}
*function = [wavm, f, function_name, this](ContextBase *context, Args... args) -> R {
*function = [wavm, f, function_name](ContextBase *context, Args... args) -> R {
WasmUntaggedValue values[] = {args...};
WasmUntaggedValue return_value;
try {
CALL_WITH_CONTEXT(
invokeFunction(wavm->context_, f, getFunctionType(f), &values[0], &return_value),
context);
CALL_WITH_CONTEXT(
invokeFunction(wavm->context_, f, getFunctionType(f), &values[0], &return_value), context,
wavm);
if (!wavm->isFailed()) {
return static_cast<uint32_t>(return_value.i32);
} catch (const std::exception &e) {
error("Function: " + std::string(function_name) + " failed: " + e.what());
} else {
return 0;
}
};
}
Expand All @@ -483,7 +530,7 @@ struct Void {};
template <typename R, typename... Args>
void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
std::function<R(ContextBase *, Args...)> *function, Void) {
auto wavm = static_cast<proxy_wasm::Wavm *>(vm);
auto wavm = static_cast<proxy_wasm::Wavm::Wavm *>(vm);
auto f =
asFunctionNullable(getInstanceExport(wavm->module_instance_, std::string(function_name)));
if (!f)
Expand All @@ -493,15 +540,13 @@ void getFunctionWavmReturn(WasmVm *vm, std::string_view function_name,
return;
}
if (!checkFunctionType(f, inferStdFunctionType(function))) {
vm->error("Bad function signature for: " + std::string(function_name));
wavm->fail(FailState::UnableToInitializeCode,
"Bad function signature for: " + std::string(function_name));
}
*function = [wavm, f, function_name, this](ContextBase *context, Args... args) -> R {
*function = [wavm, f, function_name](ContextBase *context, Args... args) -> R {
WasmUntaggedValue values[] = {args...};
try {
CALL_WITH_CONTEXT(invokeFunction(wavm->context_, f, getFunctionType(f), &values[0]), context);
} catch (const std::exception &e) {
error("Function: " + std::string(function_name) + " failed: " + e.what());
}
CALL_WITH_CONTEXT(invokeFunction(wavm->context_, f, getFunctionType(f), &values[0]), context,
wavm);
};
}

Expand Down