diff --git a/include/proxy-wasm/wasm.h b/include/proxy-wasm/wasm.h index 7860039e..b8199ace 100644 --- a/include/proxy-wasm/wasm.h +++ b/include/proxy-wasm/wasm.h @@ -59,6 +59,7 @@ class WasmBase : public std::enable_shared_from_this<WasmBase> { std::string_view vm_key() const { return vm_key_; } WasmVm *wasm_vm() const { return wasm_vm_.get(); } ContextBase *vm_context() const { return vm_context_.get(); } + ContextBase *createAndSaveRootContext(const std::shared_ptr<PluginBase> &plugin); ContextBase *getRootContext(std::string_view plugin_key) { return root_contexts_[std::string(plugin_key)].get(); } diff --git a/src/wasm.cc b/src/wasm.cc index 496b09a0..dd99b21c 100644 --- a/src/wasm.cc +++ b/src/wasm.cc @@ -314,12 +314,17 @@ bool WasmBase::initialize(const std::string &code, bool allow_precompiled) { return !isFailed(); } +ContextBase *WasmBase::createAndSaveRootContext(const std::shared_ptr<PluginBase> &plugin) { + auto context = std::unique_ptr<ContextBase>(createRootContext(plugin)); + auto root_context = context.get(); + root_contexts_[plugin->key()] = std::move(context); + return root_context; +} + ContextBase *WasmBase::getOrCreateRootContext(const std::shared_ptr<PluginBase> &plugin) { auto root_context = getRootContext(plugin->key()); if (!root_context) { - auto context = std::unique_ptr<ContextBase>(createRootContext(plugin)); - root_context = context.get(); - root_contexts_[plugin->key()] = std::move(context); + root_context = createAndSaveRootContext(plugin); } return root_context; } @@ -344,13 +349,11 @@ ContextBase *WasmBase::start(std::shared_ptr<PluginBase> plugin) { it->second->onStart(plugin); return it->second.get(); } - auto context = std::unique_ptr<ContextBase>(createRootContext(plugin)); - auto context_ptr = context.get(); - root_contexts_[plugin->key()] = std::move(context); - if (!context_ptr->onStart(plugin)) { + auto root_context = createAndSaveRootContext(plugin); + if (!root_context->onStart(plugin)) { return nullptr; } - return context_ptr; + return root_context; }; uint32_t WasmBase::allocContextId() {