diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index ecb74393ffea..7607f13a46e3 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -870,8 +870,8 @@ REPOSITORY_LOCATIONS_SPEC = dict( project_name = "WebAssembly for Proxies (C++ host implementation)", project_desc = "WebAssembly for Proxies (C++ host implementation)", project_url = "https://github.com/proxy-wasm/proxy-wasm-cpp-host", - version = "4741d2f1cd5eb250f66d0518238c333353259d56", - sha256 = "30fc4becfcc5a95ac875fc5a0658a91aa7ddedd763b52d7810c13ed35d9d81aa", + version = "eceb02d5b7772ec1cd78a4d35356e57d2e6d59bb", + sha256 = "ae9d9b87d21d95647ebda197d130b37bddc5c6ee3e6630909a231fd55fcc9069", strip_prefix = "proxy-wasm-cpp-host-{version}", urls = ["https://github.com/proxy-wasm/proxy-wasm-cpp-host/archive/{version}.tar.gz"], use_category = ["dataplane_ext"], diff --git a/source/extensions/access_loggers/wasm/config.cc b/source/extensions/access_loggers/wasm/config.cc index 718adb0fad93..8ca765442e9c 100644 --- a/source/extensions/access_loggers/wasm/config.cc +++ b/source/extensions/access_loggers/wasm/config.cc @@ -23,8 +23,6 @@ WasmAccessLogFactory::createAccessLogInstance(const Protobuf::Message& proto_con const auto& config = MessageUtil::downcastAndValidate< const envoy::extensions::access_loggers::wasm::v3::WasmAccessLog&>( proto_config, context.messageValidationVisitor()); - auto access_log = - std::make_shared(config.config().root_id(), nullptr, std::move(filter)); // Create a base WASM to verify that the code loads before setting/cloning the for the // individual threads. @@ -35,25 +33,15 @@ WasmAccessLogFactory::createAccessLogInstance(const Protobuf::Message& proto_con envoy::config::core::v3::TrafficDirection::UNSPECIFIED, context.localInfo(), nullptr /* listener_metadata */); - auto callback = [access_log, &context, plugin](Common::Wasm::WasmHandleSharedPtr base_wasm) { - auto tls_slot = context.threadLocal().allocateSlot(); + auto access_log = std::make_shared(plugin, nullptr, std::move(filter)); + auto callback = [access_log, &context, plugin](Common::Wasm::WasmHandleSharedPtr base_wasm) { // NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call. - tls_slot->set( - [base_wasm, - plugin](Event::Dispatcher& dispatcher) -> std::shared_ptr { - if (!base_wasm) { - // There is no way to prevent the connection at this point. The user could choose to use - // an HTTP Wasm plugin and only handle onLog() which would correctly close the - // connection in onRequestHeaders(). - if (!plugin->fail_open_) { - ENVOY_LOG(critical, "Plugin configured to fail closed failed to load"); - } - return nullptr; - } - return std::static_pointer_cast( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher)); - }); + auto tls_slot = + ThreadLocal::TypedSlot::makeUnique(context.threadLocal()); + tls_slot->set([base_wasm, plugin](Event::Dispatcher& dispatcher) { + return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher); + }); access_log->setTlsSlot(std::move(tls_slot)); }; diff --git a/source/extensions/access_loggers/wasm/wasm_access_log_impl.h b/source/extensions/access_loggers/wasm/wasm_access_log_impl.h index 5a7654b97bee..94910ef56712 100644 --- a/source/extensions/access_loggers/wasm/wasm_access_log_impl.h +++ b/source/extensions/access_loggers/wasm/wasm_access_log_impl.h @@ -12,13 +12,15 @@ namespace Extensions { namespace AccessLoggers { namespace Wasm { -using Envoy::Extensions::Common::Wasm::WasmHandle; +using Envoy::Extensions::Common::Wasm::PluginHandle; +using Envoy::Extensions::Common::Wasm::PluginSharedPtr; class WasmAccessLog : public AccessLog::Instance { public: - WasmAccessLog(absl::string_view root_id, ThreadLocal::SlotPtr tls_slot, + WasmAccessLog(const PluginSharedPtr& plugin, ThreadLocal::TypedSlotPtr&& tls_slot, AccessLog::FilterPtr filter) - : root_id_(root_id), tls_slot_(std::move(tls_slot)), filter_(std::move(filter)) {} + : plugin_(plugin), tls_slot_(std::move(tls_slot)), filter_(std::move(filter)) {} + void log(const Http::RequestHeaderMap* request_headers, const Http::ResponseHeaderMap* response_headers, const Http::ResponseTrailerMap* response_trailers, @@ -30,20 +32,21 @@ class WasmAccessLog : public AccessLog::Instance { } } - if (tls_slot_->get()) { - tls_slot_->getTyped().wasm()->log(root_id_, request_headers, response_headers, - response_trailers, stream_info); + auto handle = tls_slot_->get(); + if (handle.has_value()) { + handle->wasm()->log(plugin_, request_headers, response_headers, response_trailers, + stream_info); } } - void setTlsSlot(ThreadLocal::SlotPtr tls_slot) { + void setTlsSlot(ThreadLocal::TypedSlotPtr&& tls_slot) { ASSERT(tls_slot_ == nullptr); tls_slot_ = std::move(tls_slot); } private: - std::string root_id_; - ThreadLocal::SlotPtr tls_slot_; + PluginSharedPtr plugin_; + ThreadLocal::TypedSlotPtr tls_slot_; AccessLog::FilterPtr filter_; }; diff --git a/source/extensions/bootstrap/wasm/config.cc b/source/extensions/bootstrap/wasm/config.cc index 3cc0068b9a16..0e8f4caa99ac 100644 --- a/source/extensions/bootstrap/wasm/config.cc +++ b/source/extensions/bootstrap/wasm/config.cc @@ -37,18 +37,18 @@ void WasmFactory::createWasm(const envoy::extensions::wasm::v3::WasmService& con } if (singleton) { // Return a Wasm VM which will be stored as a singleton by the Server. - cb(std::make_unique( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, context.dispatcher()))); + cb(std::make_unique(plugin, Common::Wasm::getOrCreateThreadLocalPlugin( + base_wasm, plugin, context.dispatcher()))); return; } // Per-thread WASM VM. // NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call. - auto tls_slot = context.threadLocal().allocateSlot(); + auto tls_slot = + ThreadLocal::TypedSlot::makeUnique(context.threadLocal()); tls_slot->set([base_wasm, plugin](Event::Dispatcher& dispatcher) { - return std::static_pointer_cast( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher)); + return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher); }); - cb(std::make_unique(std::move(tls_slot))); + cb(std::make_unique(plugin, std::move(tls_slot))); }; if (!Common::Wasm::createWasm( diff --git a/source/extensions/bootstrap/wasm/config.h b/source/extensions/bootstrap/wasm/config.h index e70306746389..b8f3850ef621 100644 --- a/source/extensions/bootstrap/wasm/config.h +++ b/source/extensions/bootstrap/wasm/config.h @@ -16,14 +16,21 @@ namespace Extensions { namespace Bootstrap { namespace Wasm { +using Envoy::Extensions::Common::Wasm::PluginHandle; +using Envoy::Extensions::Common::Wasm::PluginHandleSharedPtr; +using Envoy::Extensions::Common::Wasm::PluginSharedPtr; + class WasmService { public: - WasmService(Common::Wasm::WasmHandleSharedPtr singleton) : singleton_(std::move(singleton)) {} - WasmService(ThreadLocal::SlotPtr tls_slot) : tls_slot_(std::move(tls_slot)) {} + WasmService(PluginSharedPtr plugin, PluginHandleSharedPtr singleton) + : plugin_(plugin), singleton_(std::move(singleton)) {} + WasmService(PluginSharedPtr plugin, ThreadLocal::TypedSlotPtr&& tls_slot) + : plugin_(plugin), tls_slot_(std::move(tls_slot)) {} private: - Common::Wasm::WasmHandleSharedPtr singleton_; - ThreadLocal::SlotPtr tls_slot_; + PluginSharedPtr plugin_; + PluginHandleSharedPtr singleton_; + ThreadLocal::TypedSlotPtr tls_slot_; }; using WasmServicePtr = std::unique_ptr; diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index e6e4f8ae0f05..f5d0f3183a18 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -810,8 +810,8 @@ BufferInterface* Context::getBuffer(WasmBufferType type) { case WasmBufferType::VmConfiguration: return buffer_.set(wasm()->vm_configuration()); case WasmBufferType::PluginConfiguration: - if (plugin_) { - return buffer_.set(plugin_->plugin_configuration_); + if (temp_plugin_) { + return buffer_.set(temp_plugin_->plugin_configuration_); } return nullptr; case WasmBufferType::HttpRequestBody: @@ -1182,18 +1182,18 @@ bool Context::validateConfiguration(absl::string_view configuration, if (!wasm()->validate_configuration_) { return true; } - plugin_ = plugin_base; + temp_plugin_ = plugin_base; auto result = wasm() ->validate_configuration_(this, id_, static_cast(configuration.size())) .u64_ != 0; - plugin_.reset(); + temp_plugin_.reset(); return result; } absl::string_view Context::getConfiguration() { - if (plugin_) { - return plugin_->plugin_configuration_; + if (temp_plugin_) { + return temp_plugin_->plugin_configuration_; } else { return wasm()->vm_configuration(); } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index e288c1e50602..657a0331addd 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -31,6 +31,7 @@ using proxy_wasm::ContextBase; using proxy_wasm::Pairs; using proxy_wasm::PairsWithStringValues; using proxy_wasm::PluginBase; +using proxy_wasm::PluginHandleBase; using proxy_wasm::SharedQueueDequeueToken; using proxy_wasm::SharedQueueEnqueueToken; using proxy_wasm::WasmBase; @@ -45,6 +46,7 @@ using GrpcService = envoy::config::core::v3::GrpcService; class Wasm; +using PluginHandleBaseSharedPtr = std::shared_ptr; using WasmHandleBaseSharedPtr = std::shared_ptr; // Opaque context object. diff --git a/source/extensions/common/wasm/wasm.cc b/source/extensions/common/wasm/wasm.cc index ab2a45f0aaf7..1b0e8e513be2 100644 --- a/source/extensions/common/wasm/wasm.cc +++ b/source/extensions/common/wasm/wasm.cc @@ -243,16 +243,16 @@ ContextBase* Wasm::createRootContext(const std::shared_ptr& plugin) ContextBase* Wasm::createVmContext() { return new Context(this); } -void Wasm::log(absl::string_view root_id, const Http::RequestHeaderMap* request_headers, +void Wasm::log(const PluginSharedPtr& plugin, const Http::RequestHeaderMap* request_headers, const Http::ResponseHeaderMap* response_headers, const Http::ResponseTrailerMap* response_trailers, const StreamInfo::StreamInfo& stream_info) { - auto context = getRootContext(root_id); + auto context = getRootContext(plugin, true); context->log(request_headers, response_headers, response_trailers, stream_info); } -void Wasm::onStatsUpdate(absl::string_view root_id, Envoy::Stats::MetricSnapshot& snapshot) { - auto context = getRootContext(root_id); +void Wasm::onStatsUpdate(const PluginSharedPtr& plugin, Envoy::Stats::MetricSnapshot& snapshot) { + auto context = getRootContext(plugin, true); context->onStatsUpdate(snapshot); } @@ -281,6 +281,14 @@ getCloneFactory(WasmExtension* wasm_extension, Event::Dispatcher& dispatcher, }; } +static proxy_wasm::PluginHandleFactory getPluginFactory(WasmExtension* wasm_extension) { + auto wasm_plugin_factory = wasm_extension->pluginFactory(); + return [wasm_plugin_factory](WasmHandleBaseSharedPtr base_wasm, + absl::string_view plugin_key) -> std::shared_ptr { + return wasm_plugin_factory(std::static_pointer_cast(base_wasm), plugin_key); + }; +} + WasmEvent toWasmEvent(const std::shared_ptr& wasm) { if (!wasm) { return WasmEvent::UnableToCreateVM; @@ -474,13 +482,21 @@ bool createWasm(const VmConfig& vm_config, const PluginSharedPtr& plugin, create_root_context_for_testing); } -WasmHandleSharedPtr getOrCreateThreadLocalWasm(const WasmHandleSharedPtr& base_wasm, - const PluginSharedPtr& plugin, - Event::Dispatcher& dispatcher, - CreateContextFn create_root_context_for_testing) { - return std::static_pointer_cast(proxy_wasm::getOrCreateThreadLocalWasm( +PluginHandleSharedPtr +getOrCreateThreadLocalPlugin(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin, + Event::Dispatcher& dispatcher, + CreateContextFn create_root_context_for_testing) { + if (!base_wasm) { + if (!plugin->fail_open_) { + ENVOY_LOG_TO_LOGGER(Envoy::Logger::Registry::getLog(Envoy::Logger::Id::wasm), critical, + "Plugin configured to fail closed failed to load"); + } + return nullptr; + } + return std::static_pointer_cast(proxy_wasm::getOrCreateThreadLocalPlugin( std::static_pointer_cast(base_wasm), plugin, - getCloneFactory(getWasmExtension(), dispatcher, create_root_context_for_testing))); + getCloneFactory(getWasmExtension(), dispatcher, create_root_context_for_testing), + getPluginFactory(getWasmExtension()))); } } // namespace Wasm diff --git a/source/extensions/common/wasm/wasm.h b/source/extensions/common/wasm/wasm.h index a812d1a1a522..37091ff9a523 100644 --- a/source/extensions/common/wasm/wasm.h +++ b/source/extensions/common/wasm/wasm.h @@ -54,8 +54,8 @@ class Wasm : public WasmBase, Logger::Loggable { Upstream::ClusterManager& clusterManager() const { return cluster_manager_; } Event::Dispatcher& dispatcher() { return dispatcher_; } - Context* getRootContext(absl::string_view root_id) { - return static_cast(WasmBase::getRootContext(root_id)); + Context* getRootContext(const std::shared_ptr& plugin, bool allow_closed) { + return static_cast(WasmBase::getRootContext(plugin, allow_closed)); } void setTimerPeriod(uint32_t root_context_id, std::chrono::milliseconds period) override; virtual void tickHandler(uint32_t root_context_id); @@ -72,12 +72,13 @@ class Wasm : public WasmBase, Logger::Loggable { void getFunctions() override; // AccessLog::Instance - void log(absl::string_view root_id, const Http::RequestHeaderMap* request_headers, + void log(const PluginSharedPtr& plugin, const Http::RequestHeaderMap* request_headers, const Http::ResponseHeaderMap* response_headers, const Http::ResponseTrailerMap* response_trailers, const StreamInfo::StreamInfo& stream_info); - void onStatsUpdate(absl::string_view root_id, Envoy::Stats::MetricSnapshot& snapshot); + void onStatsUpdate(const PluginSharedPtr& plugin, Envoy::Stats::MetricSnapshot& snapshot); + virtual std::string buildVersion() { return BUILD_VERSION_NUMBER; } void initializeLifecycle(Server::ServerLifecycleNotifier& lifecycle_notifier); @@ -136,6 +137,23 @@ class WasmHandle : public WasmHandleBase, public ThreadLocal::ThreadLocalObject WasmSharedPtr wasm_; }; +using WasmHandleSharedPtr = std::shared_ptr; + +class PluginHandle : public PluginHandleBase, public ThreadLocal::ThreadLocalObject { +public: + explicit PluginHandle(const WasmHandleSharedPtr& wasm_handle, absl::string_view plugin_key) + : PluginHandleBase(std::static_pointer_cast(wasm_handle), plugin_key), + wasm_handle_(wasm_handle) {} + + WasmSharedPtr& wasm() { return wasm_handle_->wasm(); } + WasmHandleSharedPtr& wasmHandleForTest() { return wasm_handle_; } + +private: + WasmHandleSharedPtr wasm_handle_; +}; + +using PluginHandleSharedPtr = std::shared_ptr; + using CreateWasmCallback = std::function; // Returns false if createWasm failed synchronously. This is necessary because xDS *MUST* report @@ -150,10 +168,10 @@ bool createWasm(const VmConfig& vm_config, const PluginSharedPtr& plugin, CreateWasmCallback&& callback, CreateContextFn create_root_context_for_testing = nullptr); -WasmHandleSharedPtr -getOrCreateThreadLocalWasm(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin, - Event::Dispatcher& dispatcher, - CreateContextFn create_root_context_for_testing = nullptr); +PluginHandleSharedPtr +getOrCreateThreadLocalPlugin(const WasmHandleSharedPtr& base_wasm, const PluginSharedPtr& plugin, + Event::Dispatcher& dispatcher, + CreateContextFn create_root_context_for_testing = nullptr); void clearCodeCacheForTesting(); std::string anyToBytes(const ProtobufWkt::Any& any); diff --git a/source/extensions/common/wasm/wasm_extension.cc b/source/extensions/common/wasm/wasm_extension.cc index c75168f1761c..1917fa792a82 100644 --- a/source/extensions/common/wasm/wasm_extension.cc +++ b/source/extensions/common/wasm/wasm_extension.cc @@ -46,6 +46,14 @@ EnvoyWasm::createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope, return std::make_unique(scope, runtime, short_runtime); } +PluginHandleExtensionFactory EnvoyWasm::pluginFactory() { + return [](const WasmHandleSharedPtr& base_wasm, + absl::string_view plugin_key) -> PluginHandleBaseSharedPtr { + return std::static_pointer_cast( + std::make_shared(base_wasm, plugin_key)); + }; +} + WasmHandleExtensionFactory EnvoyWasm::wasmFactory() { return [](const VmConfig vm_config, const Stats::ScopeSharedPtr& scope, Upstream::ClusterManager& cluster_manager, Event::Dispatcher& dispatcher, diff --git a/source/extensions/common/wasm/wasm_extension.h b/source/extensions/common/wasm/wasm_extension.h index 5d41a58bb337..22ae373162f2 100644 --- a/source/extensions/common/wasm/wasm_extension.h +++ b/source/extensions/common/wasm/wasm_extension.h @@ -31,6 +31,8 @@ class EnvoyWasmVmIntegration; using WasmHandleSharedPtr = std::shared_ptr; using CreateContextFn = std::function& plugin)>; +using PluginHandleExtensionFactory = std::function; using WasmHandleExtensionFactory = std::function { virtual std::unique_ptr createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope, absl::string_view runtime, absl::string_view short_runtime) = 0; + virtual PluginHandleExtensionFactory pluginFactory() = 0; virtual WasmHandleExtensionFactory wasmFactory() = 0; virtual WasmHandleExtensionCloneFactory wasmCloneFactory() = 0; enum class WasmEvent : int { @@ -100,6 +103,7 @@ class EnvoyWasm : public WasmExtension { std::unique_ptr createEnvoyWasmVmIntegration(const Stats::ScopeSharedPtr& scope, absl::string_view runtime, absl::string_view short_runtime) override; + PluginHandleExtensionFactory pluginFactory() override; WasmHandleExtensionFactory wasmFactory() override; WasmHandleExtensionCloneFactory wasmCloneFactory() override; void onEvent(WasmEvent event, const PluginSharedPtr& plugin) override; diff --git a/source/extensions/filters/http/wasm/wasm_filter.cc b/source/extensions/filters/http/wasm/wasm_filter.cc index c62b06c4102d..90713ba01989 100644 --- a/source/extensions/filters/http/wasm/wasm_filter.cc +++ b/source/extensions/filters/http/wasm/wasm_filter.cc @@ -1,14 +1,5 @@ #include "extensions/filters/http/wasm/wasm_filter.h" -#include "envoy/http/codes.h" - -#include "common/buffer/buffer_impl.h" -#include "common/common/assert.h" -#include "common/common/enum_to_int.h" -#include "common/http/header_map_impl.h" -#include "common/http/message_impl.h" -#include "common/http/utility.h" - namespace Envoy { namespace Extensions { namespace HttpFilters { @@ -16,7 +7,8 @@ namespace Wasm { FilterConfig::FilterConfig(const envoy::extensions::filters::http::wasm::v3::Wasm& config, Server::Configuration::FactoryContext& context) - : tls_slot_(context.threadLocal().allocateSlot()) { + : tls_slot_( + ThreadLocal::TypedSlot::makeUnique(context.threadLocal())) { plugin_ = std::make_shared( config.config().name(), config.config().root_id(), config.config().vm_config().vm_id(), config.config().vm_config().runtime(), @@ -26,15 +18,9 @@ FilterConfig::FilterConfig(const envoy::extensions::filters::http::wasm::v3::Was auto plugin = plugin_; auto callback = [plugin, this](const Common::Wasm::WasmHandleSharedPtr& base_wasm) { // NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call. - tls_slot_->set( - [base_wasm, - plugin](Event::Dispatcher& dispatcher) -> std::shared_ptr { - if (!base_wasm) { - return nullptr; - } - return std::static_pointer_cast( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher)); - }); + tls_slot_->set([base_wasm, plugin](Event::Dispatcher& dispatcher) { + return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher); + }); }; if (!Common::Wasm::createWasm( diff --git a/source/extensions/filters/http/wasm/wasm_filter.h b/source/extensions/filters/http/wasm/wasm_filter.h index 36bfd1503b77..956862e5f09b 100644 --- a/source/extensions/filters/http/wasm/wasm_filter.h +++ b/source/extensions/filters/http/wasm/wasm_filter.h @@ -16,8 +16,9 @@ namespace HttpFilters { namespace Wasm { using Envoy::Extensions::Common::Wasm::Context; +using Envoy::Extensions::Common::Wasm::PluginHandle; +using Envoy::Extensions::Common::Wasm::PluginSharedPtr; using Envoy::Extensions::Common::Wasm::Wasm; -using Envoy::Extensions::Common::Wasm::WasmHandle; class FilterConfig : Logger::Loggable { public: @@ -26,22 +27,23 @@ class FilterConfig : Logger::Loggable { std::shared_ptr createFilter() { Wasm* wasm = nullptr; - if (tls_slot_->get()) { - wasm = tls_slot_->getTyped().wasm().get(); + auto handle = tls_slot_->get(); + if (handle.has_value()) { + wasm = handle->wasm().get(); } if (plugin_->fail_open_ && (!wasm || wasm->isFailed())) { return nullptr; } if (wasm && !root_context_id_) { - root_context_id_ = wasm->getRootContext(plugin_->root_id_)->id(); + root_context_id_ = wasm->getRootContext(plugin_, false)->id(); } return std::make_shared(wasm, root_context_id_, plugin_); } private: uint32_t root_context_id_{0}; - Envoy::Extensions::Common::Wasm::PluginSharedPtr plugin_; - ThreadLocal::SlotPtr tls_slot_; + PluginSharedPtr plugin_; + ThreadLocal::TypedSlotPtr tls_slot_; Config::DataSource::RemoteAsyncDataProviderPtr remote_data_provider_; }; diff --git a/source/extensions/filters/network/wasm/wasm_filter.cc b/source/extensions/filters/network/wasm/wasm_filter.cc index 9d253b675abd..ccceeb9dc478 100644 --- a/source/extensions/filters/network/wasm/wasm_filter.cc +++ b/source/extensions/filters/network/wasm/wasm_filter.cc @@ -1,9 +1,5 @@ #include "extensions/filters/network/wasm/wasm_filter.h" -#include "common/buffer/buffer_impl.h" -#include "common/common/assert.h" -#include "common/common/enum_to_int.h" - namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -11,7 +7,8 @@ namespace Wasm { FilterConfig::FilterConfig(const envoy::extensions::filters::network::wasm::v3::Wasm& config, Server::Configuration::FactoryContext& context) - : tls_slot_(context.threadLocal().allocateSlot()) { + : tls_slot_( + ThreadLocal::TypedSlot::makeUnique(context.threadLocal())) { plugin_ = std::make_shared( config.config().name(), config.config().root_id(), config.config().vm_config().vm_id(), config.config().vm_config().runtime(), @@ -21,15 +18,9 @@ FilterConfig::FilterConfig(const envoy::extensions::filters::network::wasm::v3:: auto plugin = plugin_; auto callback = [plugin, this](Common::Wasm::WasmHandleSharedPtr base_wasm) { // NB: the Slot set() call doesn't complete inline, so all arguments must outlive this call. - tls_slot_->set( - [base_wasm, - plugin](Event::Dispatcher& dispatcher) -> std::shared_ptr { - if (!base_wasm) { - return nullptr; - } - return std::static_pointer_cast( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, dispatcher)); - }); + tls_slot_->set([base_wasm, plugin](Event::Dispatcher& dispatcher) { + return Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, dispatcher); + }); }; if (!Common::Wasm::createWasm( diff --git a/source/extensions/filters/network/wasm/wasm_filter.h b/source/extensions/filters/network/wasm/wasm_filter.h index 51adbcd7ac0c..6a6fe2584b2c 100644 --- a/source/extensions/filters/network/wasm/wasm_filter.h +++ b/source/extensions/filters/network/wasm/wasm_filter.h @@ -16,8 +16,9 @@ namespace NetworkFilters { namespace Wasm { using Envoy::Extensions::Common::Wasm::Context; +using Envoy::Extensions::Common::Wasm::PluginHandle; +using Envoy::Extensions::Common::Wasm::PluginSharedPtr; using Envoy::Extensions::Common::Wasm::Wasm; -using Envoy::Extensions::Common::Wasm::WasmHandle; class FilterConfig : Logger::Loggable { public: @@ -26,25 +27,25 @@ class FilterConfig : Logger::Loggable { std::shared_ptr createFilter() { Wasm* wasm = nullptr; - if (tls_slot_->get()) { - wasm = tls_slot_->getTyped().wasm().get(); + auto handle = tls_slot_->get(); + if (handle.has_value()) { + wasm = handle->wasm().get(); } if (plugin_->fail_open_ && (!wasm || wasm->isFailed())) { return nullptr; } if (wasm && !root_context_id_) { - root_context_id_ = wasm->getRootContext(plugin_->root_id_)->id(); + root_context_id_ = wasm->getRootContext(plugin_, false)->id(); } return std::make_shared(wasm, root_context_id_, plugin_); } - Envoy::Extensions::Common::Wasm::Wasm* wasm() { - return tls_slot_->getTyped().wasm().get(); - } + + Wasm* wasmForTest() { return tls_slot_->get()->wasm().get(); } private: uint32_t root_context_id_{0}; - Envoy::Extensions::Common::Wasm::PluginSharedPtr plugin_; - ThreadLocal::SlotPtr tls_slot_; + PluginSharedPtr plugin_; + ThreadLocal::TypedSlotPtr tls_slot_; Config::DataSource::RemoteAsyncDataProviderPtr remote_data_provider_; }; diff --git a/source/extensions/stat_sinks/wasm/config.cc b/source/extensions/stat_sinks/wasm/config.cc index ba94937a3b3a..da07bbdd5880 100644 --- a/source/extensions/stat_sinks/wasm/config.cc +++ b/source/extensions/stat_sinks/wasm/config.cc @@ -22,14 +22,14 @@ WasmSinkFactory::createStatsSink(const Protobuf::Message& proto_config, MessageUtil::downcastAndValidate( proto_config, context.messageValidationContext().staticValidationVisitor()); - auto wasm_sink = std::make_unique(config.config().root_id(), nullptr); - auto plugin = std::make_shared( config.config().name(), config.config().root_id(), config.config().vm_config().vm_id(), config.config().vm_config().runtime(), Common::Wasm::anyToBytes(config.config().configuration()), config.config().fail_open(), envoy::config::core::v3::TrafficDirection::UNSPECIFIED, context.localInfo(), nullptr); + auto wasm_sink = std::make_unique(plugin, nullptr); + auto callback = [&wasm_sink, &context, plugin](Common::Wasm::WasmHandleSharedPtr base_wasm) { if (!base_wasm) { if (plugin->fail_open_) { @@ -40,7 +40,7 @@ WasmSinkFactory::createStatsSink(const Protobuf::Message& proto_config, return; } wasm_sink->setSingleton( - Common::Wasm::getOrCreateThreadLocalWasm(base_wasm, plugin, context.dispatcher())); + Common::Wasm::getOrCreateThreadLocalPlugin(base_wasm, plugin, context.dispatcher())); }; if (!Common::Wasm::createWasm( diff --git a/source/extensions/stat_sinks/wasm/wasm_stat_sink_impl.h b/source/extensions/stat_sinks/wasm/wasm_stat_sink_impl.h index 5f2a9b6e13f9..5b339a6c80e9 100644 --- a/source/extensions/stat_sinks/wasm/wasm_stat_sink_impl.h +++ b/source/extensions/stat_sinks/wasm/wasm_stat_sink_impl.h @@ -9,20 +9,21 @@ namespace Extensions { namespace StatSinks { namespace Wasm { -using Envoy::Extensions::Common::Wasm::WasmHandle; +using Envoy::Extensions::Common::Wasm::PluginHandleSharedPtr; +using Envoy::Extensions::Common::Wasm::PluginSharedPtr; class WasmStatSink : public Stats::Sink { public: - WasmStatSink(absl::string_view root_id, Common::Wasm::WasmHandleSharedPtr singleton) - : root_id_(root_id), singleton_(std::move(singleton)) {} + WasmStatSink(const PluginSharedPtr& plugin, PluginHandleSharedPtr singleton) + : plugin_(plugin), singleton_(singleton) {} void flush(Stats::MetricSnapshot& snapshot) override { - singleton_->wasm()->onStatsUpdate(root_id_, snapshot); + singleton_->wasm()->onStatsUpdate(plugin_, snapshot); } - void setSingleton(Common::Wasm::WasmHandleSharedPtr singleton) { + void setSingleton(PluginHandleSharedPtr singleton) { ASSERT(singleton != nullptr); - singleton_ = std::move(singleton); + singleton_ = singleton; } void onHistogramComplete(const Stats::Histogram& histogram, uint64_t value) override { @@ -31,8 +32,8 @@ class WasmStatSink : public Stats::Sink { } private: - std::string root_id_; - Common::Wasm::WasmHandleSharedPtr singleton_; + PluginSharedPtr plugin_; + PluginHandleSharedPtr singleton_; }; } // namespace Wasm diff --git a/test/extensions/bootstrap/wasm/test_data/logging_cpp.cc b/test/extensions/bootstrap/wasm/test_data/logging_cpp.cc index 70fde8f6ae19..9a5becfab1b3 100644 --- a/test/extensions/bootstrap/wasm/test_data/logging_cpp.cc +++ b/test/extensions/bootstrap/wasm/test_data/logging_cpp.cc @@ -26,10 +26,7 @@ extern "C" PROXY_WASM_KEEPALIVE uint32_t proxy_on_configure(uint32_t, uint32_t c extern "C" PROXY_WASM_KEEPALIVE void proxy_on_context_create(uint32_t, uint32_t) {} -extern "C" PROXY_WASM_KEEPALIVE uint32_t proxy_on_vm_start(uint32_t, uint32_t) { - proxy_set_tick_period_milliseconds(10); - return 1; -} +extern "C" PROXY_WASM_KEEPALIVE uint32_t proxy_on_vm_start(uint32_t, uint32_t) { return 1; } extern "C" PROXY_WASM_KEEPALIVE void proxy_on_tick(uint32_t) { const char* root_id = nullptr; diff --git a/test/extensions/bootstrap/wasm/wasm_test.cc b/test/extensions/bootstrap/wasm/wasm_test.cc index fdfdd9536557..5384c66dcae8 100644 --- a/test/extensions/bootstrap/wasm/wasm_test.cc +++ b/test/extensions/bootstrap/wasm/wasm_test.cc @@ -223,7 +223,7 @@ TEST_P(WasmTest, Segv) { auto context = static_cast(wasm_->start(plugin_)); EXPECT_CALL(*context, log_(spdlog::level::err, Eq("before badptr"))); EXPECT_FALSE(wasm_->configure(context, plugin_)); - wasm_->isFailed(); + EXPECT_TRUE(wasm_->isFailed()); } TEST_P(WasmTest, DivByZero) { @@ -235,7 +235,7 @@ TEST_P(WasmTest, DivByZero) { auto context = static_cast(wasm_->start(plugin_)); EXPECT_CALL(*context, log_(spdlog::level::err, Eq("before div by zero"))); context->onLog(); - wasm_->isFailed(); + EXPECT_TRUE(wasm_->isFailed()); } TEST_P(WasmTest, IntrinsicGlobals) { diff --git a/test/extensions/common/wasm/test_data/test_cpp.cc b/test/extensions/common/wasm/test_data/test_cpp.cc index 1d990901846a..f2003072ee8c 100644 --- a/test/extensions/common/wasm/test_data/test_cpp.cc +++ b/test/extensions/common/wasm/test_data/test_cpp.cc @@ -267,6 +267,10 @@ WASM_EXPORT(uint32_t, proxy_on_done, (uint32_t)) { return 0; } +WASM_EXPORT(void, proxy_on_tick, (uint32_t)) { + proxy_done(); +} + WASM_EXPORT(void, proxy_on_delete, (uint32_t)) { std::string message = "on_delete logging"; proxy_log(LogLevel::info, message.c_str(), message.size()); diff --git a/test/extensions/common/wasm/wasm_test.cc b/test/extensions/common/wasm/wasm_test.cc index 110f7f720eec..6d903abdbec2 100644 --- a/test/extensions/common/wasm/wasm_test.cc +++ b/test/extensions/common/wasm/wasm_test.cc @@ -237,8 +237,8 @@ TEST_P(WasmCommonTest, Logging) { wasm_handle.reset(); dispatcher->run(Event::Dispatcher::RunType::NonBlock); // This will fault on nullptr if wasm has been deleted. - plugin->plugin_configuration_ = "done"; - wasm_weak.lock()->configure(root_context, plugin); + wasm_weak.lock()->setTimerPeriod(root_context->id(), std::chrono::milliseconds(10)); + wasm_weak.lock()->tickHandler(root_context->id()); dispatcher->run(Event::Dispatcher::RunType::NonBlock); dispatcher->clearDeferredDeleteList(); } @@ -648,7 +648,7 @@ TEST_P(WasmCommonTest, VmCache) { auto root_id = ""; auto vm_id = ""; auto vm_configuration = "vm_cache"; - auto plugin_configuration = "init"; + auto plugin_configuration = "done"; auto plugin = std::make_shared( name, root_id, vm_id, GetParam(), plugin_configuration, false, envoy::config::core::v3::TrafficDirection::UNSPECIFIED, local_info, nullptr); @@ -692,7 +692,7 @@ TEST_P(WasmCommonTest, VmCache) { EXPECT_NE(wasm_handle2, nullptr); EXPECT_EQ(wasm_handle, wasm_handle2); - auto wasm_handle_local = getOrCreateThreadLocalWasm( + auto plugin_handle_local = getOrCreateThreadLocalPlugin( wasm_handle, plugin, [&dispatcher](const WasmHandleBaseSharedPtr& base_wasm) -> WasmHandleBaseSharedPtr { auto wasm = @@ -701,22 +701,24 @@ TEST_P(WasmCommonTest, VmCache) { nullptr, [](Wasm* wasm, const std::shared_ptr& plugin) -> ContextBase* { auto root_context = new TestContext(wasm, plugin); EXPECT_CALL(*root_context, log_(spdlog::level::info, Eq("on_vm_start vm_cache"))); - EXPECT_CALL(*root_context, log_(spdlog::level::info, Eq("on_configuration init"))); EXPECT_CALL(*root_context, log_(spdlog::level::info, Eq("on_done logging"))); EXPECT_CALL(*root_context, log_(spdlog::level::info, Eq("on_delete logging"))); return root_context; }); return std::make_shared(wasm); + }, + [](const WasmHandleBaseSharedPtr& base_wasm, + absl::string_view plugin_key) -> PluginHandleBaseSharedPtr { + return std::make_shared(std::static_pointer_cast(base_wasm), + plugin_key); }); wasm_handle.reset(); wasm_handle2.reset(); - auto wasm = wasm_handle_local->wasm().get(); - wasm_handle_local.reset(); + auto wasm = plugin_handle_local->wasm(); + plugin_handle_local.reset(); dispatcher->run(Event::Dispatcher::RunType::NonBlock); - - plugin->plugin_configuration_ = "done"; wasm->configure(wasm->getContext(1), plugin); plugin.reset(); dispatcher->run(Event::Dispatcher::RunType::NonBlock); @@ -795,7 +797,7 @@ TEST_P(WasmCommonTest, RemoteCode) { EXPECT_NE(wasm_handle, nullptr); - auto wasm_handle_local = getOrCreateThreadLocalWasm( + auto plugin_handle_local = getOrCreateThreadLocalPlugin( wasm_handle, plugin, [&dispatcher](const WasmHandleBaseSharedPtr& base_wasm) -> WasmHandleBaseSharedPtr { auto wasm = @@ -809,11 +811,17 @@ TEST_P(WasmCommonTest, RemoteCode) { return root_context; }); return std::make_shared(wasm); + }, + [](const WasmHandleBaseSharedPtr& base_wasm, + absl::string_view plugin_key) -> PluginHandleBaseSharedPtr { + return std::make_shared(std::static_pointer_cast(base_wasm), + plugin_key); }); wasm_handle.reset(); - auto wasm = wasm_handle_local->wasm().get(); - wasm_handle_local.reset(); + auto wasm = plugin_handle_local->wasm(); + plugin_handle_local.reset(); + dispatcher->run(Event::Dispatcher::RunType::NonBlock); wasm->configure(wasm->getContext(1), plugin); plugin.reset(); @@ -905,7 +913,7 @@ TEST_P(WasmCommonTest, RemoteCodeMultipleRetry) { dispatcher->run(Event::Dispatcher::RunType::NonBlock); EXPECT_NE(wasm_handle, nullptr); - auto wasm_handle_local = getOrCreateThreadLocalWasm( + auto plugin_handle_local = getOrCreateThreadLocalPlugin( wasm_handle, plugin, [&dispatcher](const WasmHandleBaseSharedPtr& base_wasm) -> WasmHandleBaseSharedPtr { auto wasm = @@ -919,11 +927,16 @@ TEST_P(WasmCommonTest, RemoteCodeMultipleRetry) { return root_context; }); return std::make_shared(wasm); + }, + [](const WasmHandleBaseSharedPtr& base_wasm, + absl::string_view plugin_key) -> PluginHandleBaseSharedPtr { + return std::make_shared(std::static_pointer_cast(base_wasm), + plugin_key); }); wasm_handle.reset(); - auto wasm = wasm_handle_local->wasm().get(); - wasm_handle_local.reset(); + auto wasm = plugin_handle_local->wasm(); + plugin_handle_local.reset(); dispatcher->run(Event::Dispatcher::RunType::NonBlock); wasm->configure(wasm->getContext(1), plugin); diff --git a/test/extensions/filters/http/wasm/wasm_filter_test.cc b/test/extensions/filters/http/wasm/wasm_filter_test.cc index 9d3cda60b6ec..a80dcd64a8e2 100644 --- a/test/extensions/filters/http/wasm/wasm_filter_test.cc +++ b/test/extensions/filters/http/wasm/wasm_filter_test.cc @@ -82,7 +82,7 @@ class WasmHttpFilterTest : public Common::Wasm::WasmHttpFilterTestBase< } setupBase(std::get<0>(GetParam()), code, createContextFn(), root_id, vm_configuration); } - void setupFilter(const std::string root_id = "") { setupFilterBase(root_id); } + void setupFilter() { setupFilterBase(); } void setupGrpcStreamTest(Grpc::RawAsyncStreamCallbacks*& callbacks); @@ -294,7 +294,7 @@ TEST_P(WasmHttpFilterTest, HeadersStopAndWatermark) { // Script that reads the body. TEST_P(WasmHttpFilterTest, BodyRequestReadBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody hello")))); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "ReadBody"}}; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, false)); @@ -306,7 +306,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestReadBody) { // Script that prepends and appends to the body. TEST_P(WasmHttpFilterTest, BodyRequestPrependAndAppendToBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody prepend.hello.append")))); EXPECT_CALL(filter(), log_(spdlog::level::err, @@ -330,7 +330,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestPrependAndAppendToBody) { // Script that replaces the body. TEST_P(WasmHttpFilterTest, BodyRequestReplaceBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody replace")))).Times(2); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "ReplaceBody"}}; @@ -351,7 +351,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestReplaceBody) { // Script that removes the body. TEST_P(WasmHttpFilterTest, BodyRequestRemoveBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody ")))); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "RemoveBody"}}; @@ -364,7 +364,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestRemoveBody) { // Script that buffers the body. TEST_P(WasmHttpFilterTest, BodyRequestBufferBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "BufferBody"}}; @@ -407,7 +407,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestBufferBody) { // Script that prepends and appends to the buffered body. TEST_P(WasmHttpFilterTest, BodyRequestPrependAndAppendToBufferedBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody prepend.hello.append")))); Http::TestRequestHeaderMapImpl request_headers{ @@ -421,7 +421,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestPrependAndAppendToBufferedBody) { // Script that replaces the buffered body. TEST_P(WasmHttpFilterTest, BodyRequestReplaceBufferedBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody replace")))); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "ReplaceBufferedBody"}}; @@ -434,7 +434,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestReplaceBufferedBody) { // Script that removes the buffered body. TEST_P(WasmHttpFilterTest, BodyRequestRemoveBufferedBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody ")))); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"x-test-operation", "RemoveBufferedBody"}}; @@ -447,7 +447,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestRemoveBufferedBody) { // Script that buffers the first part of the body and streams the rest TEST_P(WasmHttpFilterTest, BodyRequestBufferThenStreamBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, false)); @@ -497,7 +497,7 @@ TEST_P(WasmHttpFilterTest, BodyRequestBufferThenStreamBody) { // Script that buffers the first part of the body and streams the rest TEST_P(WasmHttpFilterTest, BodyResponseBufferThenStreamBody) { setupTest("body"); - setupFilter("body"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, false)); @@ -580,7 +580,7 @@ TEST_P(WasmHttpFilterTest, AccessLogCreate) { TEST_P(WasmHttpFilterTest, AsyncCall) { setupTest("async_call"); - setupFilter("async_call"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; Http::MockAsyncClientRequest request(&cluster_manager_.async_client_); @@ -627,7 +627,7 @@ TEST_P(WasmHttpFilterTest, AsyncCallBadCall) { return; } setupTest("async_call"); - setupFilter("async_call"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; Http::MockAsyncClientRequest request(&cluster_manager_.async_client_); @@ -647,7 +647,7 @@ TEST_P(WasmHttpFilterTest, AsyncCallBadCall) { TEST_P(WasmHttpFilterTest, AsyncCallFailure) { setupTest("async_call"); - setupFilter("async_call"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; Http::MockAsyncClientRequest request(&cluster_manager_.async_client_); @@ -688,7 +688,7 @@ TEST_P(WasmHttpFilterTest, AsyncCallFailure) { TEST_P(WasmHttpFilterTest, AsyncCallAfterDestroyed) { setupTest("async_call"); - setupFilter("async_call"); + setupFilter(); Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}}; Http::MockAsyncClientRequest request(&cluster_manager_.async_client_); @@ -719,6 +719,7 @@ TEST_P(WasmHttpFilterTest, AsyncCallAfterDestroyed) { // Destroy the Context, Plugin and VM. context_.reset(); plugin_.reset(); + plugin_handle_.reset(); wasm_.reset(); Http::ResponseMessagePtr response_message(new Http::ResponseMessageImpl( @@ -738,7 +739,7 @@ TEST_P(WasmHttpFilterTest, GrpcCall) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); NiceMock request; Grpc::RawAsyncRequestCallbacks* callbacks = nullptr; Grpc::MockAsyncClientManager client_manager; @@ -793,7 +794,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallBadCall) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); Grpc::MockAsyncClientManager client_manager; auto client_factory = std::make_unique(); auto async_client = std::make_unique(); @@ -822,7 +823,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallFailure) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); NiceMock request; Grpc::RawAsyncRequestCallbacks* callbacks = nullptr; Grpc::MockAsyncClientManager client_manager; @@ -884,7 +885,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallCancel) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); NiceMock request; Grpc::RawAsyncRequestCallbacks* callbacks = nullptr; Grpc::MockAsyncClientManager client_manager; @@ -928,7 +929,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallClose) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); NiceMock request; Grpc::RawAsyncRequestCallbacks* callbacks = nullptr; Grpc::MockAsyncClientManager client_manager; @@ -972,7 +973,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallAfterDestroyed) { return; } setupTest("grpc_call"); - setupFilter("grpc_call"); + setupFilter(); Grpc::MockAsyncRequest request; Grpc::RawAsyncRequestCallbacks* callbacks = nullptr; Grpc::MockAsyncClientManager client_manager; @@ -1013,6 +1014,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallAfterDestroyed) { // Destroy the Context, Plugin and VM. context_.reset(); plugin_.reset(); + plugin_handle_.reset(); wasm_.reset(); ProtobufWkt::Value value; @@ -1029,7 +1031,7 @@ TEST_P(WasmHttpFilterTest, GrpcCallAfterDestroyed) { void WasmHttpFilterTest::setupGrpcStreamTest(Grpc::RawAsyncStreamCallbacks*& callbacks) { setupTest("grpc_stream"); - setupFilter("grpc_stream"); + setupFilter(); EXPECT_CALL(async_client_manager_, factoryForGrpcService(_, _, _)) .WillRepeatedly( @@ -1206,6 +1208,7 @@ TEST_P(WasmHttpFilterTest, GrpcStreamOpenAtShutdown) { // Destroy the Context, Plugin and VM. context_.reset(); plugin_.reset(); + plugin_handle_.reset(); wasm_.reset(); } @@ -1384,7 +1387,7 @@ TEST_P(WasmHttpFilterTest, SharedData) { TEST_P(WasmHttpFilterTest, SharedQueue) { setupTest("shared_queue"); - setupFilter("shared_queue"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::warn, Eq(absl::string_view("onRequestHeaders enqueue Ok")))); EXPECT_CALL(filter(), log_(spdlog::level::warn, @@ -1423,7 +1426,7 @@ TEST_P(WasmHttpFilterTest, RootId1) { return; } setupTest("context1"); - setupFilter("context1"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::debug, Eq(absl::string_view("onRequestHeaders1 2")))); Http::TestRequestHeaderMapImpl request_headers; EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, true)); @@ -1436,7 +1439,7 @@ TEST_P(WasmHttpFilterTest, RootId2) { return; } setupTest("context2"); - setupFilter("context2"); + setupFilter(); EXPECT_CALL(filter(), log_(spdlog::level::debug, Eq(absl::string_view("onRequestHeaders2 2")))); Http::TestRequestHeaderMapImpl request_headers; EXPECT_EQ(Http::FilterHeadersStatus::StopAllIterationAndWatermark, diff --git a/test/extensions/filters/network/wasm/config_test.cc b/test/extensions/filters/network/wasm/config_test.cc index f1507071ef78..68541490a82c 100644 --- a/test/extensions/filters/network/wasm/config_test.cc +++ b/test/extensions/filters/network/wasm/config_test.cc @@ -183,7 +183,7 @@ TEST_P(WasmNetworkFilterConfigTest, FilterConfigFailOpen) { envoy::extensions::filters::network::wasm::v3::Wasm proto_config; TestUtility::loadFromYaml(yaml, proto_config); NetworkFilters::Wasm::FilterConfig filter_config(proto_config, context_); - filter_config.wasm()->fail(proxy_wasm::FailState::RuntimeError, ""); + filter_config.wasmForTest()->fail(proxy_wasm::FailState::RuntimeError, ""); EXPECT_EQ(filter_config.createFilter(), nullptr); } diff --git a/test/extensions/filters/network/wasm/wasm_filter_test.cc b/test/extensions/filters/network/wasm/wasm_filter_test.cc index dd3a2e29a0c2..517d37cee3b5 100644 --- a/test/extensions/filters/network/wasm/wasm_filter_test.cc +++ b/test/extensions/filters/network/wasm/wasm_filter_test.cc @@ -58,7 +58,7 @@ class WasmNetworkFilterTest : public Common::Wasm::WasmNetworkFilterTestBase< "" /* root_id */, "" /* vm_configuration */, fail_open); } - void setupFilter() { setupFilterBase(""); } + void setupFilter() { setupFilterBase(); } TestFilter& filter() { return *static_cast(context_.get()); } diff --git a/test/test_common/wasm_base.h b/test/test_common/wasm_base.h index d049460d3e41..2cfc796084eb 100644 --- a/test/test_common/wasm_base.h +++ b/test/test_common/wasm_base.h @@ -80,12 +80,13 @@ template class WasmTestBase : public Base { lifecycle_notifier_, remote_data_provider_, [this](WasmHandleSharedPtr wasm) { wasm_ = wasm; }, create_root); if (wasm_) { - wasm_ = getOrCreateThreadLocalWasm( + plugin_handle_ = getOrCreateThreadLocalPlugin( wasm_, plugin_, dispatcher_, [this, create_root](Wasm* wasm, const std::shared_ptr& plugin) { root_context_ = static_cast(create_root(wasm, plugin)); return root_context_; }); + wasm_ = plugin_handle_->wasmHandleForTest(); } } @@ -101,6 +102,7 @@ template class WasmTestBase : public Base { NiceMock init_manager_; WasmHandleSharedPtr wasm_; PluginSharedPtr plugin_; + PluginHandleSharedPtr plugin_handle_; NiceMock ssl_; NiceMock connection_; NiceMock decoder_callbacks_; @@ -114,9 +116,9 @@ template class WasmTestBase : public Base { template class WasmHttpFilterTestBase : public WasmTestBase { public: - template void setupFilterBase(const std::string root_id = "") { + template void setupFilterBase() { auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; - int root_context_id = wasm ? wasm->getRootContext(root_id)->id() : 0; + int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = std::make_unique(wasm, root_context_id, WasmTestBase::plugin_); context_->setDecoderFilterCallbacks(decoder_callbacks_); context_->setEncoderFilterCallbacks(encoder_callbacks_); @@ -131,9 +133,9 @@ template class WasmHttpFilterTestBase : public W template class WasmNetworkFilterTestBase : public WasmTestBase { public: - template void setupFilterBase(const std::string root_id = "") { + template void setupFilterBase() { auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; - int root_context_id = wasm ? wasm->getRootContext(root_id)->id() : 0; + int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = std::make_unique(wasm, root_context_id, WasmTestBase::plugin_); context_->initializeReadFilterCallbacks(read_filter_callbacks_); context_->initializeWriteFilterCallbacks(write_filter_callbacks_);