From d7da0c53a8225510d601f8aac851406420877eb2 Mon Sep 17 00:00:00 2001 From: "Mark D. Roth" Date: Fri, 23 Aug 2024 08:44:55 -0700 Subject: [PATCH] [TokenFetcherCredentials] add backoff and pre-fetching (#37531) This adds functionality that is intended to be used for the new GcpServiceAccountIdentityCallCredentials implementation, as per gRFC A83 (https://github.com/grpc/proposal/pull/438). However, it is also a useful improvement for all token-fetching call credentials types, so I am adding it to the base class. Closes #37531 COPYBARA_INTEGRATE_REVIEW=https://github.com/grpc/grpc/pull/37531 from markdroth:token_fetcher_call_creds_prefetch_and_backoff 0fcdb48465dc01893beeda8dd277d91491be3420 PiperOrigin-RevId: 666809903 --- doc/trace_flags.md | 1 + src/core/BUILD | 4 +- src/core/lib/debug/trace_flags.cc | 2 + src/core/lib/debug/trace_flags.h | 1 + src/core/lib/debug/trace_flags.yaml | 3 + src/core/lib/promise/map.h | 2 +- .../external/external_account_credentials.cc | 6 +- .../external/external_account_credentials.h | 5 - .../file_external_account_credentials.cc | 1 - .../token_fetcher_credentials.cc | 284 ++++++++++++--- .../token_fetcher/token_fetcher_credentials.h | 83 ++++- test/core/security/credentials_test.cc | 335 +++++++++++++++++- 12 files changed, 638 insertions(+), 89 deletions(-) diff --git a/doc/trace_flags.md b/doc/trace_flags.md index 56ad7191781d38..4ea7129a28f48d 100644 --- a/doc/trace_flags.md +++ b/doc/trace_flags.md @@ -73,6 +73,7 @@ processing requests via debug logs. Available tracers include: - tcp - Bytes in and out of a channel. - timer - Timers (alarms) in the grpc internals. - timer_check - more detailed trace of timer logic in grpc internals. + - token_fetcher_credentials - Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials. - tsi - TSI transport security. - weighted_round_robin_lb - Weighted round robin load balancing policy. - weighted_target_lb - Weighted target LB policy. diff --git a/src/core/BUILD b/src/core/BUILD index 50ce0f5c9ccaba..641ba363ab8f21 100644 --- a/src/core/BUILD +++ b/src/core/BUILD @@ -4335,14 +4335,17 @@ grpc_cc_library( deps = [ "arena_promise", "context", + "default_event_engine", "metadata", "poll", "pollset_set", "ref_counted", "time", "useful", + "//:backoff", "//:gpr", "//:grpc_security_base", + "//:grpc_trace", "//:httpcli", "//:iomgr", "//:orphanable", @@ -4436,7 +4439,6 @@ grpc_cc_library( language = "c++", deps = [ "closure", - "default_event_engine", "env", "error", "error_utils", diff --git a/src/core/lib/debug/trace_flags.cc b/src/core/lib/debug/trace_flags.cc index f3fa6c63314ab8..8f7183782c8195 100644 --- a/src/core/lib/debug/trace_flags.cc +++ b/src/core/lib/debug/trace_flags.cc @@ -114,6 +114,7 @@ TraceFlag subchannel_pool_trace(false, "subchannel_pool"); TraceFlag tcp_trace(false, "tcp"); TraceFlag timer_trace(false, "timer"); TraceFlag timer_check_trace(false, "timer_check"); +TraceFlag token_fetcher_credentials_trace(false, "token_fetcher_credentials"); TraceFlag tsi_trace(false, "tsi"); TraceFlag weighted_round_robin_lb_trace(false, "weighted_round_robin_lb"); TraceFlag weighted_target_lb_trace(false, "weighted_target_lb"); @@ -206,6 +207,7 @@ const absl::flat_hash_map& GetAllTraceFlags() { {"tcp", &tcp_trace}, {"timer", &timer_trace}, {"timer_check", &timer_check_trace}, + {"token_fetcher_credentials", &token_fetcher_credentials_trace}, {"tsi", &tsi_trace}, {"weighted_round_robin_lb", &weighted_round_robin_lb_trace}, {"weighted_target_lb", &weighted_target_lb_trace}, diff --git a/src/core/lib/debug/trace_flags.h b/src/core/lib/debug/trace_flags.h index 4aaf5e0111169d..9aa4df691f0ec1 100644 --- a/src/core/lib/debug/trace_flags.h +++ b/src/core/lib/debug/trace_flags.h @@ -112,6 +112,7 @@ extern TraceFlag subchannel_pool_trace; extern TraceFlag tcp_trace; extern TraceFlag timer_trace; extern TraceFlag timer_check_trace; +extern TraceFlag token_fetcher_credentials_trace; extern TraceFlag tsi_trace; extern TraceFlag weighted_round_robin_lb_trace; extern TraceFlag weighted_target_lb_trace; diff --git a/src/core/lib/debug/trace_flags.yaml b/src/core/lib/debug/trace_flags.yaml index 247d830e81f432..64c665af79958b 100644 --- a/src/core/lib/debug/trace_flags.yaml +++ b/src/core/lib/debug/trace_flags.yaml @@ -308,6 +308,9 @@ timer: timer_check: default: false description: more detailed trace of timer logic in grpc internals. +token_fetcher_credentials: + default: false + description: Token fetcher call credentials framework, used for (e.g.) oauth2 token fetcher credentials. tsi: default: false description: TSI transport security. diff --git a/src/core/lib/promise/map.h b/src/core/lib/promise/map.h index a2a2a773eea2bf..3ba8c19c2f6ac9 100644 --- a/src/core/lib/promise/map.h +++ b/src/core/lib/promise/map.h @@ -86,7 +86,7 @@ GPR_ATTRIBUTE_ALWAYS_INLINE_FUNCTION auto CheckDelayed(Promise promise) { delayed = true; return Pending{}; } - return std::make_tuple(r.value(), delayed); + return std::make_tuple(std::move(r.value()), delayed); }; } diff --git a/src/core/lib/security/credentials/external/external_account_credentials.cc b/src/core/lib/security/credentials/external/external_account_credentials.cc index 100b4ddd8e4d65..c83351e67e2966 100644 --- a/src/core/lib/security/credentials/external/external_account_credentials.cc +++ b/src/core/lib/security/credentials/external/external_account_credentials.cc @@ -45,7 +45,6 @@ #include #include -#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/status_helper.h" #include "src/core/lib/security/credentials/credentials.h" #include "src/core/lib/security/credentials/external/aws_external_account_credentials.h" @@ -591,10 +590,7 @@ ExternalAccountCredentials::Create( ExternalAccountCredentials::ExternalAccountCredentials( Options options, std::vector scopes, std::shared_ptr event_engine) - : event_engine_( - event_engine == nullptr - ? grpc_event_engine::experimental::GetDefaultEventEngine() - : std::move(event_engine)), + : TokenFetcherCredentials(std::move(event_engine)), options_(std::move(options)) { if (scopes.empty()) { scopes.push_back(GOOGLE_CLOUD_PLATFORM_DEFAULT_SCOPE); diff --git a/src/core/lib/security/credentials/external/external_account_credentials.h b/src/core/lib/security/credentials/external/external_account_credentials.h index 0617c3e6f07173..5dfbb24a36ede4 100644 --- a/src/core/lib/security/credentials/external/external_account_credentials.h +++ b/src/core/lib/security/credentials/external/external_account_credentials.h @@ -185,10 +185,6 @@ class ExternalAccountCredentials : public TokenFetcherCredentials { absl::string_view audience() const { return options_.audience; } - grpc_event_engine::experimental::EventEngine& event_engine() const { - return *event_engine_; - } - private: OrphanablePtr FetchToken( Timestamp deadline, @@ -204,7 +200,6 @@ class ExternalAccountCredentials : public TokenFetcherCredentials { Timestamp deadline, absl::AnyInvocable)> on_done) = 0; - std::shared_ptr event_engine_; Options options_; std::vector scopes_; }; diff --git a/src/core/lib/security/credentials/external/file_external_account_credentials.cc b/src/core/lib/security/credentials/external/file_external_account_credentials.cc index cad9b7f7ee43a3..086af8c79c7be6 100644 --- a/src/core/lib/security/credentials/external/file_external_account_credentials.cc +++ b/src/core/lib/security/credentials/external/file_external_account_credentials.cc @@ -26,7 +26,6 @@ #include #include -#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/gprpp/load_file.h" #include "src/core/lib/slice/slice.h" #include "src/core/lib/slice/slice_internal.h" diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc index eb5bee6478a962..596641f0fffdcf 100644 --- a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.cc @@ -18,6 +18,8 @@ #include "src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h" +#include "src/core/lib/debug/trace.h" +#include "src/core/lib/event_engine/default_event_engine.h" #include "src/core/lib/iomgr/pollset_set.h" #include "src/core/lib/promise/context.h" #include "src/core/lib/promise/poll.h" @@ -28,8 +30,11 @@ namespace grpc_core { namespace { // Amount of time before the token's expiration that we consider it -// invalid and start a new fetch. Also determines the timeout for the -// fetch request. +// invalid to account for server processing time and clock skew. +constexpr Duration kTokenExpirationAdjustmentDuration = Duration::Seconds(30); + +// Amount of time before the token's expiration that we pre-fetch a new +// token. Also determines the timeout for the fetch request. constexpr Duration kTokenRefreshDuration = Duration::Seconds(60); } // namespace @@ -38,18 +43,193 @@ constexpr Duration kTokenRefreshDuration = Duration::Seconds(60); // TokenFetcherCredentials::Token // +TokenFetcherCredentials::Token::Token(Slice token, Timestamp expiration) + : token_(std::move(token)), + expiration_(expiration - kTokenExpirationAdjustmentDuration) {} + void TokenFetcherCredentials::Token::AddTokenToClientInitialMetadata( ClientMetadata& metadata) const { metadata.Append(GRPC_AUTHORIZATION_METADATA_KEY, token_.Ref(), [](absl::string_view, const Slice&) { abort(); }); } +// +// TokenFetcherCredentials::FetchState::BackoffTimer +// + +TokenFetcherCredentials::FetchState::BackoffTimer::BackoffTimer( + RefCountedPtr fetch_state) + : fetch_state_(std::move(fetch_state)) { + const Timestamp next_attempt_time = fetch_state_->backoff_.NextAttemptTime(); + const Duration duration = next_attempt_time - Timestamp::Now(); + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": starting backoff timer for " << next_attempt_time << " (" + << duration << " from now)"; + timer_handle_ = fetch_state_->creds_->event_engine().RunAfter( + duration, [self = Ref()]() mutable { + ApplicationCallbackExecCtx callback_exec_ctx; + ExecCtx exec_ctx; + self->OnTimer(); + self.reset(); + }); +} + +void TokenFetcherCredentials::FetchState::BackoffTimer::Orphan() { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": backoff timer shut down"; + if (timer_handle_.has_value()) { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": cancelling timer"; + fetch_state_->creds_->event_engine().Cancel(*timer_handle_); + timer_handle_.reset(); + fetch_state_->ResumeQueuedCalls( + absl::CancelledError("credentials shutdown")); + } + Unref(); +} + +void TokenFetcherCredentials::FetchState::BackoffTimer::OnTimer() { + MutexLock lock(&fetch_state_->creds_->mu_); + if (!timer_handle_.has_value()) return; + timer_handle_.reset(); + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": backoff timer fired"; + if (fetch_state_->queued_calls_.empty()) { + // If there are no pending calls when the timer fires, then orphan + // the FetchState object. Note that this drops the backoff state, + // but that's probably okay, because if we didn't have any pending + // calls during the backoff period, we probably won't see any + // immediately now either. + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": no pending calls, clearing state"; + fetch_state_->creds_->fetch_state_.reset(); + } else { + // If there are pending calls, then start a new fetch attempt. + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << fetch_state_->creds_.get() + << "]: fetch_state=" << fetch_state_.get() << " backoff_timer=" << this + << ": starting new fetch attempt"; + fetch_state_->StartFetchAttempt(); + } +} + +// +// TokenFetcherCredentials::FetchState +// + +TokenFetcherCredentials::FetchState::FetchState( + WeakRefCountedPtr creds) + : creds_(std::move(creds)), + backoff_(BackOff::Options() + .set_initial_backoff(Duration::Seconds(1)) + .set_multiplier(1.6) + .set_jitter(creds_->test_only_use_backoff_jitter_ ? 0.2 : 0) + .set_max_backoff(Duration::Seconds(120))) { + StartFetchAttempt(); +} + +void TokenFetcherCredentials::FetchState::Orphan() { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << creds_.get() + << "]: fetch_state=" << this << ": shutting down"; + // Cancels fetch or backoff timer, if any. + state_ = Shutdown{}; + Unref(); +} + +void TokenFetcherCredentials::FetchState::StartFetchAttempt() { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << creds_.get() + << "]: fetch_state=" << this << ": starting fetch"; + state_ = creds_->FetchToken( + /*deadline=*/Timestamp::Now() + kTokenRefreshDuration, + [self = Ref()](absl::StatusOr> token) mutable { + self->TokenFetchComplete(std::move(token)); + }); +} + +void TokenFetcherCredentials::FetchState::TokenFetchComplete( + absl::StatusOr> token) { + MutexLock lock(&creds_->mu_); + // If we were shut down, clean up. + if (absl::holds_alternative(state_)) { + if (token.ok()) token = absl::CancelledError("credentials shutdown"); + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << creds_.get() + << "]: fetch_state=" << this + << ": shut down before fetch completed: " << token.status(); + ResumeQueuedCalls(std::move(token)); + return; + } + // If succeeded, update cache in creds object. + if (token.ok()) { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << creds_.get() + << "]: fetch_state=" << this << ": token fetch succeeded"; + creds_->token_ = *token; + creds_->fetch_state_.reset(); // Orphan ourselves. + } else { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << creds_.get() + << "]: fetch_state=" << this + << ": token fetch failed: " << token.status(); + // If failed, start backoff timer. + state_ = OrphanablePtr(new BackoffTimer(Ref())); + } + ResumeQueuedCalls(std::move(token)); +} + +void TokenFetcherCredentials::FetchState::ResumeQueuedCalls( + absl::StatusOr> token) { + // Invoke callbacks for all pending requests. + for (auto& queued_call : queued_calls_) { + queued_call->result = token; + queued_call->done.store(true, std::memory_order_release); + queued_call->waker.Wakeup(); + grpc_polling_entity_del_from_pollset_set( + queued_call->pollent, + grpc_polling_entity_pollset_set(&creds_->pollent_)); + } + queued_calls_.clear(); +} + +RefCountedPtr +TokenFetcherCredentials::FetchState::QueueCall( + ClientMetadataHandle initial_metadata) { + // Add call to pending list. + auto queued_call = MakeRefCounted(); + queued_call->waker = GetContext()->MakeNonOwningWaker(); + queued_call->pollent = GetContext(); + grpc_polling_entity_add_to_pollset_set( + queued_call->pollent, grpc_polling_entity_pollset_set(&creds_->pollent_)); + queued_call->md = std::move(initial_metadata); + queued_calls_.insert(queued_call); + return queued_call; +} + // // TokenFetcherCredentials // -TokenFetcherCredentials::TokenFetcherCredentials() - : pollent_(grpc_polling_entity_create_from_pollset_set( +TokenFetcherCredentials::TokenFetcherCredentials( + std::shared_ptr event_engine, + bool test_only_use_backoff_jitter) + : event_engine_( + event_engine == nullptr + ? grpc_event_engine::experimental::GetDefaultEventEngine() + : std::move(event_engine)), + test_only_use_backoff_jitter_(test_only_use_backoff_jitter), + pollent_(grpc_polling_entity_create_from_pollset_set( grpc_pollset_set_create())) {} TokenFetcherCredentials::~TokenFetcherCredentials() { @@ -58,73 +238,63 @@ TokenFetcherCredentials::~TokenFetcherCredentials() { void TokenFetcherCredentials::Orphaned() { MutexLock lock(&mu_); - auto* fetch_request = absl::get_if>(&token_); - if (fetch_request != nullptr) fetch_request->reset(); + fetch_state_.reset(); } ArenaPromise> TokenFetcherCredentials::GetRequestMetadata( ClientMetadataHandle initial_metadata, const GetRequestMetadataArgs*) { - RefCountedPtr pending_call; + RefCountedPtr queued_call; { MutexLock lock(&mu_); - // Check if we can use the cached token. - auto* cached_token = absl::get_if>(&token_); - if (cached_token != nullptr && *cached_token != nullptr && - ((*cached_token)->ExpirationTime() - Timestamp::Now()) > - kTokenRefreshDuration) { - (*cached_token)->AddTokenToClientInitialMetadata(*initial_metadata); - return Immediate(std::move(initial_metadata)); + // If we don't have a cached token or the token is within the + // refresh duration, start a new fetch if there isn't a pending one. + if ((token_ == nullptr || (token_->ExpirationTime() - Timestamp::Now()) <= + kTokenRefreshDuration) && + fetch_state_ == nullptr) { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << this + << "]: " << GetContext()->DebugTag() + << " triggering new token fetch"; + fetch_state_ = OrphanablePtr( + new FetchState(WeakRefAsSubclass())); } - // Couldn't get the token from the cache. - // Add this call to the pending list. - pending_call = MakeRefCounted(); - pending_call->waker = GetContext()->MakeNonOwningWaker(); - pending_call->pollent = GetContext(); - grpc_polling_entity_add_to_pollset_set( - pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); - pending_call->md = std::move(initial_metadata); - pending_calls_.insert(pending_call); - // Start a new fetch if needed. - if (!absl::holds_alternative>(token_)) { - token_ = FetchToken( - /*deadline=*/Timestamp::Now() + kTokenRefreshDuration, - [self = WeakRefAsSubclass()]( - absl::StatusOr> token) mutable { - self->TokenFetchComplete(std::move(token)); - }); + // If we have a cached non-expired token, use it. + if (token_ != nullptr && + (token_->ExpirationTime() - Timestamp::Now()) > Duration::Zero()) { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << this + << "]: " << GetContext()->DebugTag() + << " using cached token"; + token_->AddTokenToClientInitialMetadata(*initial_metadata); + return Immediate(std::move(initial_metadata)); } + // If we don't have a cached token, this call will need to be queued. + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << this + << "]: " << GetContext()->DebugTag() + << " no cached token; queuing call"; + queued_call = fetch_state_->QueueCall(std::move(initial_metadata)); } - return [pending_call = std::move( - pending_call)]() -> Poll> { - if (!pending_call->done.load(std::memory_order_acquire)) { + return [this, queued_call = std::move(queued_call)]() + -> Poll> { + if (!queued_call->done.load(std::memory_order_acquire)) { return Pending{}; } - if (!pending_call->result.ok()) { - return pending_call->result.status(); + if (!queued_call->result.ok()) { + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << this + << "]: " << GetContext()->DebugTag() + << " token fetch failed; failing call"; + return queued_call->result.status(); } - (*pending_call->result)->AddTokenToClientInitialMetadata(*pending_call->md); - return std::move(pending_call->md); + GRPC_TRACE_LOG(token_fetcher_credentials, INFO) + << "[TokenFetcherCredentials " << this + << "]: " << GetContext()->DebugTag() + << " token fetch complete; resuming call"; + (*queued_call->result)->AddTokenToClientInitialMetadata(*queued_call->md); + return std::move(queued_call->md); }; } -void TokenFetcherCredentials::TokenFetchComplete( - absl::StatusOr> token) { - // Update cache and grab list of pending requests. - absl::flat_hash_set> pending_calls; - { - MutexLock lock(&mu_); - token_ = token.value_or(nullptr); - pending_calls_.swap(pending_calls); - } - // Invoke callbacks for all pending requests. - for (auto& pending_call : pending_calls) { - pending_call->result = token; - pending_call->done.store(true, std::memory_order_release); - pending_call->waker.Wakeup(); - grpc_polling_entity_del_from_pollset_set( - pending_call->pollent, grpc_polling_entity_pollset_set(&pollent_)); - } -} - } // namespace grpc_core diff --git a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h index 72793afcdee253..c73d469234b066 100644 --- a/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h +++ b/src/core/lib/security/credentials/token_fetcher/token_fetcher_credentials.h @@ -26,6 +26,9 @@ #include "absl/status/statusor.h" #include "absl/types/variant.h" +#include + +#include "src/core/lib/backoff/backoff.h" #include "src/core/lib/gprpp/orphanable.h" #include "src/core/lib/gprpp/ref_counted.h" #include "src/core/lib/gprpp/ref_counted_ptr.h" @@ -47,8 +50,7 @@ class TokenFetcherCredentials : public grpc_call_credentials { // Represents a token. class Token : public RefCounted { public: - Token(Slice token, Timestamp expiration) - : token_(std::move(token)), expiration_(expiration) {} + Token(Slice token, Timestamp expiration); // Returns the token's expiration time. Timestamp ExpirationTime() const { return expiration_; } @@ -73,7 +75,10 @@ class TokenFetcherCredentials : public grpc_call_credentials { // Base class for fetch requests. class FetchRequest : public InternallyRefCounted {}; - TokenFetcherCredentials(); + explicit TokenFetcherCredentials( + std::shared_ptr + event_engine = nullptr, + bool test_only_use_backoff_jitter = true); // Fetches a token. The on_done callback will be invoked when complete. virtual OrphanablePtr FetchToken( @@ -81,11 +86,15 @@ class TokenFetcherCredentials : public grpc_call_credentials { absl::AnyInvocable>)> on_done) = 0; + grpc_event_engine::experimental::EventEngine& event_engine() const { + return *event_engine_; + } + grpc_polling_entity* pollent() { return &pollent_; } private: // A call that is waiting for a token fetch request to complete. - struct PendingCall : public RefCounted { + struct QueuedCall : public RefCounted { std::atomic done{false}; Waker waker; grpc_polling_entity* pollent; @@ -93,20 +102,72 @@ class TokenFetcherCredentials : public grpc_call_credentials { absl::StatusOr> result; }; + class FetchState : public InternallyRefCounted { + public: + explicit FetchState(WeakRefCountedPtr creds) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); + + // Disabling thread safety annotations, since Orphan() is called + // by OrpahanablePtr<>, which does not have the right lock + // annotations. + void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; + + RefCountedPtr QueueCall(ClientMetadataHandle initial_metadata) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); + + private: + class BackoffTimer : public InternallyRefCounted { + public: + explicit BackoffTimer(RefCountedPtr fetch_state) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); + + // Disabling thread safety annotations, since Orphan() is called + // by OrpahanablePtr<>, which does not have the right lock + // annotations. + void Orphan() override ABSL_NO_THREAD_SAFETY_ANALYSIS; + + private: + void OnTimer(); + + RefCountedPtr fetch_state_; + absl::optional + timer_handle_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); + }; + + struct Shutdown {}; + + void StartFetchAttempt() + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); + void TokenFetchComplete(absl::StatusOr> token); + void ResumeQueuedCalls(absl::StatusOr> token) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(&TokenFetcherCredentials::mu_); + + WeakRefCountedPtr creds_; + // Pending token-fetch request or backoff timer, if any. + absl::variant, OrphanablePtr, + Shutdown> + state_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); + // Calls that are queued up waiting for the token. + absl::flat_hash_set> queued_calls_ + ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); + // Backoff state. + BackOff backoff_ ABSL_GUARDED_BY(&TokenFetcherCredentials::mu_); + }; + int cmp_impl(const grpc_call_credentials* other) const override { // TODO(yashykt): Check if we can do something better here return QsortCompare(static_cast(this), other); } - void TokenFetchComplete(absl::StatusOr> token); + std::shared_ptr event_engine_; + const bool test_only_use_backoff_jitter_; Mutex mu_; - // Either the cached token or a pending request to fetch the token. - absl::variant, OrphanablePtr> token_ - ABSL_GUARDED_BY(&mu_); - // Calls that are queued up waiting for the token. - absl::flat_hash_set> pending_calls_ - ABSL_GUARDED_BY(&mu_); + // Cached token, if any. + RefCountedPtr token_ ABSL_GUARDED_BY(&mu_); + // Fetch state, if any. + OrphanablePtr fetch_state_ ABSL_GUARDED_BY(&mu_); + grpc_polling_entity pollent_ ABSL_GUARDED_BY(&mu_); }; diff --git a/test/core/security/credentials_test.cc b/test/core/security/credentials_test.cc index fb93d522a7c9fd..03467770b4dbde 100644 --- a/test/core/security/credentials_test.cc +++ b/test/core/security/credentials_test.cc @@ -50,6 +50,7 @@ #include "src/core/lib/iomgr/error.h" #include "src/core/lib/iomgr/timer_manager.h" #include "src/core/lib/promise/exec_ctx_wakeup_scheduler.h" +#include "src/core/lib/promise/map.h" #include "src/core/lib/promise/promise.h" #include "src/core/lib/promise/seq.h" #include "src/core/lib/security/context/security_context.h" @@ -79,6 +80,7 @@ namespace grpc_core { +using grpc_event_engine::experimental::FuzzingEventEngine; using internal::grpc_flush_cached_google_default_credentials; using internal::set_gce_tenancy_checker_for_testing; @@ -426,16 +428,19 @@ TEST_F(CredentialsTest, class RequestMetadataState : public RefCounted { public: static RefCountedPtr NewInstance( - grpc_error_handle expected_error, std::string expected) { + grpc_error_handle expected_error, std::string expected, + absl::optional expect_delay = absl::nullopt) { return MakeRefCounted( - expected_error, std::move(expected), + expected_error, std::move(expected), expect_delay, grpc_polling_entity_create_from_pollset_set(grpc_pollset_set_create())); } RequestMetadataState(grpc_error_handle expected_error, std::string expected, + absl::optional expect_delay, grpc_polling_entity pollent) : expected_error_(expected_error), expected_(std::move(expected)), + expect_delay_(expect_delay), pollent_(pollent) {} ~RequestMetadataState() override { @@ -453,12 +458,18 @@ class RequestMetadataState : public RefCounted { activity_ = MakeActivity( [this, creds] { return Seq( - creds->GetRequestMetadata( + CheckDelayed(creds->GetRequestMetadata( ClientMetadataHandle(&md_, Arena::PooledDeleter(nullptr)), - &get_request_metadata_args_), - [this](absl::StatusOr metadata) { + &get_request_metadata_args_)), + [this](std::tuple, bool> + metadata_and_delayed) { + auto& metadata = std::get<0>(metadata_and_delayed); + const bool delayed = std::get<1>(metadata_and_delayed); + if (expect_delay_.has_value()) { + EXPECT_EQ(delayed, *expect_delay_); + } if (metadata.ok()) { - CHECK(metadata->get() == &md_); + EXPECT_EQ(metadata->get(), &md_); } return metadata.status(); }); @@ -523,6 +534,7 @@ class RequestMetadataState : public RefCounted { grpc_error_handle expected_error_; std::string expected_; + absl::optional expect_delay_; RefCountedPtr arena_ = SimpleArenaAllocator()->MakeArena(); grpc_metadata_batch md_; grpc_call_credentials::GetRequestMetadataArgs get_request_metadata_args_; @@ -2369,6 +2381,315 @@ int aws_external_account_creds_httpcli_post_success( return 1; } +class TokenFetcherCredentialsTest : public ::testing::Test { + protected: + class TestTokenFetcherCredentials final : public TokenFetcherCredentials { + public: + explicit TestTokenFetcherCredentials( + std::shared_ptr + event_engine = nullptr) + : TokenFetcherCredentials(std::move(event_engine), + /*test_only_use_backoff_jitter=*/false) {} + + ~TestTokenFetcherCredentials() override { CHECK_EQ(queue_.size(), 0); } + + void AddResult(absl::StatusOr> result) { + MutexLock lock(&mu_); + queue_.push_front(std::move(result)); + } + + size_t num_fetches() const { return num_fetches_; } + + private: + class TestFetchRequest final : public FetchRequest { + public: + TestFetchRequest( + grpc_event_engine::experimental::EventEngine& event_engine, + absl::AnyInvocable>)> + on_done, + absl::StatusOr> result) { + event_engine.Run([on_done = std::move(on_done), + result = std::move(result)]() mutable { + ApplicationCallbackExecCtx application_exec_ctx; + ExecCtx exec_ctx; + std::exchange(on_done, nullptr)(std::move(result)); + }); + } + + void Orphan() override { Unref(); } + }; + + OrphanablePtr FetchToken( + Timestamp deadline, + absl::AnyInvocable>)> on_done) + override { + absl::StatusOr> result; + { + MutexLock lock(&mu_); + CHECK(!queue_.empty()); + result = std::move(queue_.back()); + queue_.pop_back(); + } + num_fetches_.fetch_add(1); + return MakeOrphanable( + event_engine(), std::move(on_done), std::move(result)); + } + + std::string debug_string() override { + return "TestTokenFetcherCredentials"; + } + + UniqueTypeName type() const override { + static UniqueTypeName::Factory kFactory("TestTokenFetcherCredentials"); + return kFactory.Create(); + } + + Mutex mu_; + std::deque>> queue_ + ABSL_GUARDED_BY(&mu_); + + std::atomic num_fetches_{0}; + }; + + void SetUp() override { + grpc_timer_manager_set_start_threaded(false); + grpc_init(); + } + + void TearDown() override { + event_engine_->FuzzingDone(); + event_engine_->TickUntilIdle(); + event_engine_->UnsetGlobalHooks(); + creds_.reset(); + grpc_event_engine::experimental::WaitForSingleOwner( + std::move(event_engine_)); + grpc_shutdown_blocking(); + } + + static RefCountedPtr MakeToken( + absl::string_view token, Timestamp expiration = Timestamp::InfFuture()) { + return MakeRefCounted( + Slice::FromCopiedString(token), expiration); + } + + std::shared_ptr event_engine_ = + std::make_shared(FuzzingEventEngine::Options(), + fuzzing_event_engine::Actions()); + RefCountedPtr creds_ = + MakeRefCounted(event_engine_); +}; + +TEST_F(TokenFetcherCredentialsTest, Basic) { + const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1); + ExecCtx exec_ctx; + creds_->AddResult(MakeToken("foo", kExpirationTime)); + // First request will trigger a fetch. + auto state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + // Second request while fetch is still outstanding will be delayed but + // will not trigger a new fetch. + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + // Now tick to finish the fetch. + event_engine_->TickUntilIdle(); + // Next request will be served from cache with no delay. + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + // Advance time to expiration minus expiration adjustment and prefetch time. + exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(90)); + // No new fetch yet. + EXPECT_EQ(creds_->num_fetches(), 1); + // Next request will trigger a new fetch but will still use the + // cached token. + creds_->AddResult(MakeToken("bar")); + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 2); + event_engine_->TickUntilIdle(); + // Next request will use the new data. + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: bar", /*expect_delay=*/false); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 2); +} + +TEST_F(TokenFetcherCredentialsTest, Expires30SecondsEarly) { + const auto kExpirationTime = Timestamp::Now() + Duration::Hours(1); + ExecCtx exec_ctx; + creds_->AddResult(MakeToken("foo", kExpirationTime)); + // First request will trigger a fetch. + auto state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + event_engine_->TickUntilIdle(); + // Advance time to expiration minus 30 seconds. + exec_ctx.TestOnlySetNow(kExpirationTime - Duration::Seconds(30)); + // No new fetch yet. + EXPECT_EQ(creds_->num_fetches(), 1); + // Next request will trigger a new fetch and will delay the call until + // the fetch completes. + creds_->AddResult(MakeToken("bar")); + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: bar", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 2); + event_engine_->TickUntilIdle(); +} + +TEST_F(TokenFetcherCredentialsTest, FetchFails) { + const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); + absl::optional run_after_duration; + event_engine_->SetRunAfterDurationCallback( + [&](FuzzingEventEngine::Duration duration) { + run_after_duration = duration; + }); + ExecCtx exec_ctx; + creds_->AddResult(kExpectedError); + // First request will trigger a fetch, which will fail. + auto state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + while (!run_after_duration.has_value()) event_engine_->Tick(); + // Make sure backoff was set for the right period. + // This is 1 second (initial backoff) minus 1ms for the tick needed above. + EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); + run_after_duration.reset(); + // Start a new call now, which will be queued and then eventually + // resumed when the next fetch happens. + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + // Tick until the next fetch starts. + creds_->AddResult(MakeToken("foo")); + event_engine_->TickUntilIdle(); + EXPECT_EQ(creds_->num_fetches(), 2); + // A call started now should use the new cached data. + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/false); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 2); +} + +TEST_F(TokenFetcherCredentialsTest, Backoff) { + const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); + absl::optional run_after_duration; + event_engine_->SetRunAfterDurationCallback( + [&](FuzzingEventEngine::Duration duration) { + run_after_duration = duration; + }); + ExecCtx exec_ctx; + creds_->AddResult(kExpectedError); + // First request will trigger a fetch, which will fail. + auto state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + while (!run_after_duration.has_value()) event_engine_->Tick(); + // Make sure backoff was set for the right period. + EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); + run_after_duration.reset(); + // Start a new call now, which will be queued and then eventually + // resumed when the next fetch happens. + state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + // Tick until the next fetch fails and the backoff timer starts again. + creds_->AddResult(kExpectedError); + while (!run_after_duration.has_value()) event_engine_->Tick(); + EXPECT_EQ(creds_->num_fetches(), 2); + // The backoff time should be longer now. We account for jitter here. + EXPECT_EQ(run_after_duration, std::chrono::milliseconds(1600)) + << "actual: " << run_after_duration->count(); + run_after_duration.reset(); + // Start another new call to trigger another new fetch once the + // backoff expires. + state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + // Tick until the next fetch starts. + creds_->AddResult(kExpectedError); + while (!run_after_duration.has_value()) event_engine_->Tick(); + EXPECT_EQ(creds_->num_fetches(), 3); + // Check backoff time again. + EXPECT_EQ(run_after_duration, std::chrono::milliseconds(2560)) + << "actual: " << run_after_duration->count(); +} + +TEST_F(TokenFetcherCredentialsTest, FetchNotStartedAfterBackoffWithoutRpc) { + const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); + absl::optional run_after_duration; + event_engine_->SetRunAfterDurationCallback( + [&](FuzzingEventEngine::Duration duration) { + run_after_duration = duration; + }); + ExecCtx exec_ctx; + creds_->AddResult(kExpectedError); + // First request will trigger a fetch, which will fail. + auto state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + while (!run_after_duration.has_value()) event_engine_->Tick(); + // Make sure backoff was set for the right period. + EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); + run_after_duration.reset(); + // Tick until the backoff expires. No new fetch should be started. + event_engine_->TickUntilIdle(); + EXPECT_EQ(creds_->num_fetches(), 1); + // Now start a new request, which will trigger a new fetch. + creds_->AddResult(MakeToken("foo")); + state = RequestMetadataState::NewInstance( + absl::OkStatus(), "authorization: foo", /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 2); +} + +TEST_F(TokenFetcherCredentialsTest, ShutdownWhileBackoffTimerPending) { + const absl::Status kExpectedError = absl::UnavailableError("bummer, dude"); + absl::optional run_after_duration; + event_engine_->SetRunAfterDurationCallback( + [&](FuzzingEventEngine::Duration duration) { + run_after_duration = duration; + }); + ExecCtx exec_ctx; + creds_->AddResult(kExpectedError); + // First request will trigger a fetch, which will fail. + auto state = RequestMetadataState::NewInstance(kExpectedError, "", + /*expect_delay=*/true); + state->RunRequestMetadataTest(creds_.get(), kTestUrlScheme, kTestAuthority, + kTestPath); + EXPECT_EQ(creds_->num_fetches(), 1); + while (!run_after_duration.has_value()) event_engine_->Tick(); + // Make sure backoff was set for the right period. + EXPECT_EQ(run_after_duration, std::chrono::seconds(1)); + run_after_duration.reset(); + // Do nothing else. Make sure the creds shut down correctly. +} + // The subclass of ExternalAccountCredentials for testing. // ExternalAccountCredentials is an abstract class so we can't directly test // against it. @@ -2484,8 +2805,6 @@ TEST_F(CredentialsTest, grpc_version_string())); } -using grpc_event_engine::experimental::FuzzingEventEngine; - class ExternalAccountCredentialsTest : public ::testing::Test { protected: void SetUp() override {