From 713bb5b036d24e860c1bb6bff89b4b6f113d60b4 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Fri, 2 Jun 2023 09:28:43 +1000 Subject: [PATCH] Use a trait for remote action result caching (#19097) This separates the `remote::remote_cache` coordination code from the gRPC REAPI implementation by: - adding a `remote::remote_cache::ActionCacheProvider` trait - moving the REAPI implementation into `remote::remote_cache::reapi` and implementing that trait This is, in theory, just a lift-and-shift, with no functionality change. This is preparation work for supporting more remote stores like GHA cache or S3, for https://github.com/pantsbuild/pants/issues/11149, and is specifically broken out of https://github.com/pantsbuild/pants/pull/17840. It continues #19050. Additional work required to actually solve https://github.com/pantsbuild/pants/issues/11149: - implementing other byte store and action cache providers - dynamically choosing the right providers for `store::remote::ByteStore` and `remote::remote_cache` The commits are individually reviewable: 1. preparatory breaking out of gRPC code 2. defining the trait 3. moving the REAPI code and implementing the trait, as close to naively as possible: - https://gist.github.com/huonw/a60ad807b05ecea98387294c22de67cb has a whitespace-ignoring diff between `remote_cache.rs` after commit 1 and the `reapi.rs` file in this commit (it's less useful than #19050, since most of the code is deleted, but buried in there are chunks of completely unchanged code) 4. 
minor clean-up --- .../remote/src/remote_cache.rs | 186 +++++++----------- .../remote/src/remote_cache/reapi.rs | 123 ++++++++++++ 2 files changed, 193 insertions(+), 116 deletions(-) create mode 100644 src/rust/engine/process_execution/remote/src/remote_cache/reapi.rs diff --git a/src/rust/engine/process_execution/remote/src/remote_cache.rs b/src/rust/engine/process_execution/remote/src/remote_cache.rs index d09aa35b6b9..f62fe39435a 100644 --- a/src/rust/engine/process_execution/remote/src/remote_cache.rs +++ b/src/rust/engine/process_execution/remote/src/remote_cache.rs @@ -1,7 +1,6 @@ // Copyright 2022 Pants project contributors (see CONTRIBUTORS.md). // Licensed under the Apache License, Version 2.0 (see LICENSE). use std::collections::{BTreeMap, HashSet}; -use std::convert::TryInto; use std::fmt::{self, Debug}; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -10,27 +9,24 @@ use async_trait::async_trait; use fs::{directory, DigestTrie, RelativePath, SymlinkBehavior}; use futures::future::{BoxFuture, TryFutureExt}; use futures::FutureExt; -use grpc_util::retry::{retry_call, status_is_retryable}; -use grpc_util::{headers_to_http_header_map, layered_service, status_to_str, LayeredService}; use hashing::Digest; use parking_lot::Mutex; use protos::gen::build::bazel::remote::execution::v2 as remexec; use protos::require_digest; -use remexec::action_cache_client::ActionCacheClient; use remexec::{ActionResult, Command, Tree}; use store::{Store, StoreError}; use workunit_store::{ in_workunit, Level, Metric, ObservationMetric, RunningWorkunit, WorkunitMetadata, }; -use crate::remote::apply_headers; use process_execution::{ check_cache_content, populate_fallible_execution_result, CacheContentBehavior, Context, FallibleProcessResultWithPlatform, Process, ProcessCacheScope, ProcessError, ProcessExecutionEnvironment, ProcessResultSource, }; use process_execution::{make_execute_request, EntireExecuteRequest}; -use tonic::{Code, Request, Status}; + +mod reapi; 
#[derive(Clone, Copy, Debug, strum_macros::EnumString)] #[strum(serialize_all = "snake_case")] @@ -40,6 +36,23 @@ pub enum RemoteCacheWarningsBehavior { Backoff, } +/// This `ActionCacheProvider` trait captures the operations required to be able to cache command +/// executions remotely. +#[async_trait] +trait ActionCacheProvider: Sync + Send + 'static { + async fn update_action_result( + &self, + action_digest: Digest, + action_result: ActionResult, + ) -> Result<(), String>; + + async fn get_action_result( + &self, + action_digest: Digest, + context: &Context, + ) -> Result, String>; +} + /// This `CommandRunner` implementation caches results remotely using the Action Cache service /// of the Remote Execution API. /// @@ -56,7 +69,7 @@ pub struct CommandRunner { append_only_caches_base_path: Option, executor: task_executor::Executor, store: Store, - action_cache_client: Arc>, + provider: Arc, cache_read: bool, cache_write: bool, cache_content_behavior: CacheContentBehavior, @@ -74,7 +87,7 @@ impl CommandRunner { store: Store, action_cache_address: &str, root_ca_certs: Option>, - mut headers: BTreeMap, + headers: BTreeMap, cache_read: bool, cache_write: bool, warnings_behavior: RemoteCacheWarningsBehavior, @@ -83,25 +96,14 @@ impl CommandRunner { rpc_timeout: Duration, append_only_caches_base_path: Option, ) -> Result { - let tls_client_config = if action_cache_address.starts_with("https://") { - Some(grpc_util::tls::Config::new_without_mtls(root_ca_certs).try_into()?) 
- } else { - None - }; - - let endpoint = grpc_util::create_endpoint( + let provider = Arc::new(reapi::Provider::new( + instance_name.clone(), action_cache_address, - tls_client_config.as_ref(), - &mut headers, - )?; - let http_headers = headers_to_http_header_map(&headers)?; - let channel = layered_service( - tonic::transport::Channel::balance_list(vec![endpoint].into_iter()), + root_ca_certs, + headers, concurrency_limit, - http_headers, - Some((rpc_timeout, Metric::RemoteCacheRequestTimeouts)), - ); - let action_cache_client = Arc::new(ActionCacheClient::new(channel)); + rpc_timeout, + )?); Ok(CommandRunner { inner, @@ -110,7 +112,7 @@ impl CommandRunner { append_only_caches_base_path, executor, store, - action_cache_client, + provider, cache_read, cache_write, cache_content_behavior, @@ -280,10 +282,9 @@ impl CommandRunner { let response = check_action_cache( action_digest, &request.description, - self.instance_name.clone(), request.execution_environment.clone(), &context, - self.action_cache_client.clone(), + self.provider.clone(), self.store.clone(), self.cache_content_behavior, ) @@ -384,7 +385,6 @@ impl CommandRunner { async fn update_action_cache( &self, result: &FallibleProcessResultWithPlatform, - instance_name: Option, command: &Command, action_digest: Digest, command_digest: Digest, @@ -405,28 +405,10 @@ impl CommandRunner { .ensure_remote_has_recursive(digests_for_action_result) .await?; - let client = self.action_cache_client.as_ref().clone(); - retry_call( - client, - move |mut client| { - let update_action_cache_request = remexec::UpdateActionResultRequest { - instance_name: instance_name.clone().unwrap_or_else(|| "".to_owned()), - action_digest: Some(action_digest.into()), - action_result: Some(action_result.clone()), - ..remexec::UpdateActionResultRequest::default() - }; - - async move { - client - .update_action_result(update_action_cache_request) - .await - } - }, - status_is_retryable, - ) - .await - .map_err(status_to_str)?; - + self + 
.provider + .update_action_result(action_digest, action_result) + .await?; Ok(()) } @@ -532,13 +514,7 @@ impl process_execution::CommandRunner for CommandRunner { let write_fut = in_workunit!("remote_cache_write", Level::Trace, |workunit| async move { workunit.increment_counter(Metric::RemoteCacheWriteAttempts, 1); let write_result = command_runner - .update_action_cache( - &result, - command_runner.instance_name.clone(), - &command, - action_digest, - command_digest, - ) + .update_action_cache(&result, &command, action_digest, command_digest) .await; match write_result { Ok(_) => workunit.increment_counter(Metric::RemoteCacheWriteSuccesses, 1), @@ -574,10 +550,9 @@ impl process_execution::CommandRunner for CommandRunner { async fn check_action_cache( action_digest: Digest, command_description: &str, - instance_name: Option, environment: ProcessExecutionEnvironment, context: &Context, - action_cache_client: Arc>, + provider: Arc, store: Store, cache_content_behavior: CacheContentBehavior, ) -> Result, ProcessError> { @@ -589,69 +564,48 @@ async fn check_action_cache( workunit.increment_counter(Metric::RemoteCacheRequests, 1); let start = Instant::now(); - let client = action_cache_client.as_ref().clone(); - let response = retry_call( - client, - move |mut client| { - let request = remexec::GetActionResultRequest { - action_digest: Some(action_digest.into()), - instance_name: instance_name.clone().unwrap_or_default(), - ..remexec::GetActionResultRequest::default() - }; - let request = apply_headers(Request::new(request), &context.build_id); - async move { client.get_action_result(request).await } - }, - status_is_retryable, - ) - .and_then(|action_result| async move { - let action_result = action_result.into_inner(); - let response = populate_fallible_execution_result( - store.clone(), - context.run_id, - &action_result, - false, - ProcessResultSource::HitRemotely, - environment, - ) - .await - .map_err(|e| Status::unavailable(format!("Output roots could not be 
loaded: {e}")))?; - - let cache_content_valid = check_cache_content(&response, &store, cache_content_behavior) + let response = provider + .get_action_result(action_digest, context) + .and_then(|action_result| async move { + let Some(action_result) = action_result else { return Ok(None) }; + + let response = populate_fallible_execution_result( + store.clone(), + context.run_id, + &action_result, + false, + ProcessResultSource::HitRemotely, + environment, + ) .await - .map_err(|e| { - Status::unavailable(format!("Output content could not be validated: {e}")) - })?; - - if cache_content_valid { - Ok(response) - } else { - Err(Status::not_found("")) - } - }) - .await; + .map_err(|e| format!("Output roots could not be loaded: {e}"))?; + + let cache_content_valid = check_cache_content(&response, &store, cache_content_behavior) + .await + .map_err(|e| format!("Output content could not be validated: {e}"))?; + + if cache_content_valid { + Ok(Some(response)) + } else { + Ok(None) + } + }) + .await; workunit.record_observation( ObservationMetric::RemoteCacheGetActionResultTimeMicros, start.elapsed().as_micros() as u64, ); - match response { - Ok(response) => { - workunit.increment_counter(Metric::RemoteCacheRequestsCached, 1); - Ok(Some(response)) - } - Err(status) => match status.code() { - Code::NotFound => { - workunit.increment_counter(Metric::RemoteCacheRequestsUncached, 1); - Ok(None) - } - _ => { - workunit.increment_counter(Metric::RemoteCacheReadErrors, 1); - // TODO: Ensure that we're catching missing digests. - Err(status_to_str(status).into()) - } - }, - } + let counter = match response { + Ok(Some(_)) => Metric::RemoteCacheRequestsCached, + Ok(None) => Metric::RemoteCacheRequestsUncached, + // TODO: Ensure that we're catching missing digests. 
+ Err(_) => Metric::RemoteCacheReadErrors, + }; + workunit.increment_counter(counter, 1); + + response.map_err(ProcessError::from) } ) .await diff --git a/src/rust/engine/process_execution/remote/src/remote_cache/reapi.rs b/src/rust/engine/process_execution/remote/src/remote_cache/reapi.rs new file mode 100644 index 00000000000..84d8c28d1dd --- /dev/null +++ b/src/rust/engine/process_execution/remote/src/remote_cache/reapi.rs @@ -0,0 +1,123 @@ +// Copyright 2023 Pants project contributors (see CONTRIBUTORS.md). +// Licensed under the Apache License, Version 2.0 (see LICENSE). +use std::collections::BTreeMap; +use std::convert::TryInto; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use grpc_util::retry::{retry_call, status_is_retryable}; +use grpc_util::{headers_to_http_header_map, layered_service, status_to_str, LayeredService}; +use hashing::Digest; +use protos::gen::build::bazel::remote::execution::v2 as remexec; +use remexec::action_cache_client::ActionCacheClient; +use remexec::ActionResult; +use workunit_store::Metric; + +use crate::remote::apply_headers; +use process_execution::Context; +use tonic::{Code, Request}; + +use super::ActionCacheProvider; + +pub struct Provider { + instance_name: Option, + action_cache_client: Arc>, +} + +impl Provider { + pub fn new( + instance_name: Option, + action_cache_address: &str, + root_ca_certs: Option>, + mut headers: BTreeMap, + concurrency_limit: usize, + rpc_timeout: Duration, + ) -> Result { + let tls_client_config = if action_cache_address.starts_with("https://") { + Some(grpc_util::tls::Config::new_without_mtls(root_ca_certs).try_into()?) 
+ } else { + None + }; + + let endpoint = grpc_util::create_endpoint( + action_cache_address, + tls_client_config.as_ref(), + &mut headers, + )?; + let http_headers = headers_to_http_header_map(&headers)?; + let channel = layered_service( + tonic::transport::Channel::balance_list(vec![endpoint].into_iter()), + concurrency_limit, + http_headers, + Some((rpc_timeout, Metric::RemoteCacheRequestTimeouts)), + ); + let action_cache_client = Arc::new(ActionCacheClient::new(channel)); + + Ok(Provider { + instance_name, + action_cache_client, + }) + } +} + +#[async_trait] +impl ActionCacheProvider for Provider { + async fn update_action_result( + &self, + action_digest: Digest, + action_result: ActionResult, + ) -> Result<(), String> { + let client = self.action_cache_client.as_ref().clone(); + retry_call( + client, + move |mut client| { + let update_action_cache_request = remexec::UpdateActionResultRequest { + instance_name: self.instance_name.clone().unwrap_or_else(|| "".to_owned()), + action_digest: Some(action_digest.into()), + action_result: Some(action_result.clone()), + ..remexec::UpdateActionResultRequest::default() + }; + + async move { + client + .update_action_result(update_action_cache_request) + .await + } + }, + status_is_retryable, + ) + .await + .map_err(status_to_str)?; + + Ok(()) + } + + async fn get_action_result( + &self, + action_digest: Digest, + context: &Context, + ) -> Result, String> { + let client = self.action_cache_client.as_ref().clone(); + let response = retry_call( + client, + move |mut client| { + let request = remexec::GetActionResultRequest { + action_digest: Some(action_digest.into()), + instance_name: self.instance_name.clone().unwrap_or_default(), + ..remexec::GetActionResultRequest::default() + }; + let request = apply_headers(Request::new(request), &context.build_id); + async move { client.get_action_result(request).await } + }, + status_is_retryable, + ) + .await; + + match response { + Ok(response) => 
Ok(Some(response.into_inner())), + Err(status) if status.code() == Code::NotFound => Ok(None), + Err(status) => Err(status_to_str(status)), + } + } +}