Use a trait for remote action result caching (#19097)
This separates the `remote::remote_cache` coordination code from the
gRPC REAPI implementation by:

- adding a `remote::remote_cache::ActionCacheProvider` trait
- moving the REAPI implementation into `remote::remote_cache::reapi` and
  implementing that trait (see the sketch after this list)
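
As a rough, self-contained sketch of the seam this introduces (not the real
crate code: `Digest`, `ActionResult`, and `Context` below are simplified
stand-ins, the actual trait is in the diff further down, and `InMemoryProvider`
is invented here purely for illustration; the `async_trait` and `futures`
crates are assumed available), a provider only has to supply a get/update
pair, so a non-gRPC cache backend can slot in behind the same interface:

```rust
use std::collections::HashMap;
use std::sync::Mutex;

use async_trait::async_trait;

// Stand-ins for the real `hashing::Digest`, `remexec::ActionResult`, and
// `process_execution::Context` types used by the actual trait.
type Digest = [u8; 32];
type ActionResult = Vec<u8>;
struct Context;

#[async_trait]
trait ActionCacheProvider: Sync + Send + 'static {
    async fn update_action_result(
        &self,
        action_digest: Digest,
        action_result: ActionResult,
    ) -> Result<(), String>;

    async fn get_action_result(
        &self,
        action_digest: Digest,
        context: &Context,
    ) -> Result<Option<ActionResult>, String>;
}

// A toy in-memory provider, purely for illustration: a GHA-cache or S3
// provider would implement the same two methods.
#[derive(Default)]
struct InMemoryProvider {
    entries: Mutex<HashMap<Digest, ActionResult>>,
}

#[async_trait]
impl ActionCacheProvider for InMemoryProvider {
    async fn update_action_result(
        &self,
        action_digest: Digest,
        action_result: ActionResult,
    ) -> Result<(), String> {
        self.entries.lock().unwrap().insert(action_digest, action_result);
        Ok(())
    }

    async fn get_action_result(
        &self,
        action_digest: Digest,
        _context: &Context,
    ) -> Result<Option<ActionResult>, String> {
        Ok(self.entries.lock().unwrap().get(&action_digest).cloned())
    }
}

fn main() {
    let provider = InMemoryProvider::default();
    futures::executor::block_on(async {
        provider
            .update_action_result([1; 32], b"cached result".to_vec())
            .await
            .unwrap();
        let hit = provider.get_action_result([1; 32], &Context).await.unwrap();
        assert_eq!(hit, Some(b"cached result".to_vec()));
    });
}
```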

This is, in theory, just a lift-and-shift, with no functionality change.

This is preparation work for supporting more remote stores such as the GHA
cache or S3 (#11149), and is specifically broken out of #17840. It continues
#19050. Additional work required to actually solve #11149:

- implementing other byte store and action cache providers
- dynamically choosing the right providers for `store::remote::ByteStore`
  and `remote::remote_cache` (a hypothetical sketch follows this list)
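
Purely as a hypothetical sketch of that "dynamically choosing" step (nothing
below exists in this commit, and the address schemes are invented
placeholders, not options Pants actually accepts), the choice could dispatch
on the configured cache address:

```rust
// Hypothetical only: the scheme names are placeholders for the follow-up
// work in #11149.
#[derive(Debug, PartialEq)]
enum ActionCacheProviderKind {
    Reapi, // the gRPC REAPI provider this commit moves into `reapi.rs`
    GithubActionsCache,
    S3,
}

fn choose_provider(address: &str) -> Result<ActionCacheProviderKind, String> {
    if address.starts_with("grpc://") || address.starts_with("grpcs://") {
        Ok(ActionCacheProviderKind::Reapi)
    } else if address.starts_with("github-actions-cache+") {
        Ok(ActionCacheProviderKind::GithubActionsCache)
    } else if address.starts_with("s3://") {
        Ok(ActionCacheProviderKind::S3)
    } else {
        Err(format!("unsupported remote cache address: {address}"))
    }
}

fn main() {
    assert_eq!(
        choose_provider("grpcs://cache.example.com"),
        Ok(ActionCacheProviderKind::Reapi)
    );
    assert_eq!(
        choose_provider("s3://my-bucket/pants-cache"),
        Ok(ActionCacheProviderKind::S3)
    );
    assert!(choose_provider("ftp://nope").is_err());
}
```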

The commits are individually reviewable:
1. preparatory breaking out of gRPC code
2. defining the trait
3. moving the REAPI code and implementing the trait, as close to naively as
   possible:
   - https://gist.github.com/huonw/a60ad807b05ecea98387294c22de67cb has a
     white-space-ignoring diff between `remote_cache.rs` after commit 1 and the
     `reapi.rs` file in this commit (it's less useful than the corresponding
     diff for #19050, since most of the code is deleted, but buried in there
     are chunks of completely unchanged code)
4. minor clean-up
huonw authored Jun 1, 2023
1 parent d97e4ef commit 713bb5b
Showing 2 changed files with 193 additions and 116 deletions.
186 changes: 70 additions & 116 deletions src/rust/engine/process_execution/remote/src/remote_cache.rs
@@ -1,7 +1,6 @@
// Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
// Licensed under the Apache License, Version 2.0 (see LICENSE).
use std::collections::{BTreeMap, HashSet};
use std::convert::TryInto;
use std::fmt::{self, Debug};
use std::sync::Arc;
use std::time::{Duration, Instant};
@@ -10,27 +9,24 @@ use async_trait::async_trait;
use fs::{directory, DigestTrie, RelativePath, SymlinkBehavior};
use futures::future::{BoxFuture, TryFutureExt};
use futures::FutureExt;
use grpc_util::retry::{retry_call, status_is_retryable};
use grpc_util::{headers_to_http_header_map, layered_service, status_to_str, LayeredService};
use hashing::Digest;
use parking_lot::Mutex;
use protos::gen::build::bazel::remote::execution::v2 as remexec;
use protos::require_digest;
use remexec::action_cache_client::ActionCacheClient;
use remexec::{ActionResult, Command, Tree};
use store::{Store, StoreError};
use workunit_store::{
in_workunit, Level, Metric, ObservationMetric, RunningWorkunit, WorkunitMetadata,
};

use crate::remote::apply_headers;
use process_execution::{
check_cache_content, populate_fallible_execution_result, CacheContentBehavior, Context,
FallibleProcessResultWithPlatform, Process, ProcessCacheScope, ProcessError,
ProcessExecutionEnvironment, ProcessResultSource,
};
use process_execution::{make_execute_request, EntireExecuteRequest};
use tonic::{Code, Request, Status};

mod reapi;

#[derive(Clone, Copy, Debug, strum_macros::EnumString)]
#[strum(serialize_all = "snake_case")]
@@ -40,6 +36,23 @@ pub enum RemoteCacheWarningsBehavior {
Backoff,
}

/// This `ActionCacheProvider` trait captures the operations required to be able to cache command
/// executions remotely.
#[async_trait]
trait ActionCacheProvider: Sync + Send + 'static {
async fn update_action_result(
&self,
action_digest: Digest,
action_result: ActionResult,
) -> Result<(), String>;

async fn get_action_result(
&self,
action_digest: Digest,
context: &Context,
) -> Result<Option<ActionResult>, String>;
}

/// This `CommandRunner` implementation caches results remotely using the Action Cache service
/// of the Remote Execution API.
///
@@ -56,7 +69,7 @@ pub struct CommandRunner {
append_only_caches_base_path: Option<String>,
executor: task_executor::Executor,
store: Store,
action_cache_client: Arc<ActionCacheClient<LayeredService>>,
provider: Arc<dyn ActionCacheProvider>,
cache_read: bool,
cache_write: bool,
cache_content_behavior: CacheContentBehavior,
@@ -74,7 +87,7 @@ impl CommandRunner {
store: Store,
action_cache_address: &str,
root_ca_certs: Option<Vec<u8>>,
mut headers: BTreeMap<String, String>,
headers: BTreeMap<String, String>,
cache_read: bool,
cache_write: bool,
warnings_behavior: RemoteCacheWarningsBehavior,
@@ -83,25 +96,14 @@ impl CommandRunner {
rpc_timeout: Duration,
append_only_caches_base_path: Option<String>,
) -> Result<Self, String> {
let tls_client_config = if action_cache_address.starts_with("https://") {
Some(grpc_util::tls::Config::new_without_mtls(root_ca_certs).try_into()?)
} else {
None
};

let endpoint = grpc_util::create_endpoint(
let provider = Arc::new(reapi::Provider::new(
instance_name.clone(),
action_cache_address,
tls_client_config.as_ref(),
&mut headers,
)?;
let http_headers = headers_to_http_header_map(&headers)?;
let channel = layered_service(
tonic::transport::Channel::balance_list(vec![endpoint].into_iter()),
root_ca_certs,
headers,
concurrency_limit,
http_headers,
Some((rpc_timeout, Metric::RemoteCacheRequestTimeouts)),
);
let action_cache_client = Arc::new(ActionCacheClient::new(channel));
rpc_timeout,
)?);

Ok(CommandRunner {
inner,
@@ -110,7 +112,7 @@
append_only_caches_base_path,
executor,
store,
action_cache_client,
provider,
cache_read,
cache_write,
cache_content_behavior,
@@ -280,10 +282,9 @@ impl CommandRunner {
let response = check_action_cache(
action_digest,
&request.description,
self.instance_name.clone(),
request.execution_environment.clone(),
&context,
self.action_cache_client.clone(),
self.provider.clone(),
self.store.clone(),
self.cache_content_behavior,
)
@@ -384,7 +385,6 @@ impl CommandRunner {
async fn update_action_cache(
&self,
result: &FallibleProcessResultWithPlatform,
instance_name: Option<String>,
command: &Command,
action_digest: Digest,
command_digest: Digest,
@@ -405,28 +405,10 @@ impl CommandRunner {
.ensure_remote_has_recursive(digests_for_action_result)
.await?;

let client = self.action_cache_client.as_ref().clone();
retry_call(
client,
move |mut client| {
let update_action_cache_request = remexec::UpdateActionResultRequest {
instance_name: instance_name.clone().unwrap_or_else(|| "".to_owned()),
action_digest: Some(action_digest.into()),
action_result: Some(action_result.clone()),
..remexec::UpdateActionResultRequest::default()
};

async move {
client
.update_action_result(update_action_cache_request)
.await
}
},
status_is_retryable,
)
.await
.map_err(status_to_str)?;

self
.provider
.update_action_result(action_digest, action_result)
.await?;
Ok(())
}

@@ -532,13 +514,7 @@ impl process_execution::CommandRunner for CommandRunner {
let write_fut = in_workunit!("remote_cache_write", Level::Trace, |workunit| async move {
workunit.increment_counter(Metric::RemoteCacheWriteAttempts, 1);
let write_result = command_runner
.update_action_cache(
&result,
command_runner.instance_name.clone(),
&command,
action_digest,
command_digest,
)
.update_action_cache(&result, &command, action_digest, command_digest)
.await;
match write_result {
Ok(_) => workunit.increment_counter(Metric::RemoteCacheWriteSuccesses, 1),
@@ -574,10 +550,9 @@ impl process_execution::CommandRunner for CommandRunner {
async fn check_action_cache(
action_digest: Digest,
command_description: &str,
instance_name: Option<String>,
environment: ProcessExecutionEnvironment,
context: &Context,
action_cache_client: Arc<ActionCacheClient<LayeredService>>,
provider: Arc<dyn ActionCacheProvider>,
store: Store,
cache_content_behavior: CacheContentBehavior,
) -> Result<Option<FallibleProcessResultWithPlatform>, ProcessError> {
@@ -589,69 +564,48 @@ async fn check_action_cache(
workunit.increment_counter(Metric::RemoteCacheRequests, 1);

let start = Instant::now();
let client = action_cache_client.as_ref().clone();
let response = retry_call(
client,
move |mut client| {
let request = remexec::GetActionResultRequest {
action_digest: Some(action_digest.into()),
instance_name: instance_name.clone().unwrap_or_default(),
..remexec::GetActionResultRequest::default()
};
let request = apply_headers(Request::new(request), &context.build_id);
async move { client.get_action_result(request).await }
},
status_is_retryable,
)
.and_then(|action_result| async move {
let action_result = action_result.into_inner();
let response = populate_fallible_execution_result(
store.clone(),
context.run_id,
&action_result,
false,
ProcessResultSource::HitRemotely,
environment,
)
.await
.map_err(|e| Status::unavailable(format!("Output roots could not be loaded: {e}")))?;

let cache_content_valid = check_cache_content(&response, &store, cache_content_behavior)
let response = provider
.get_action_result(action_digest, context)
.and_then(|action_result| async move {
let Some(action_result) = action_result else { return Ok(None) };

let response = populate_fallible_execution_result(
store.clone(),
context.run_id,
&action_result,
false,
ProcessResultSource::HitRemotely,
environment,
)
.await
.map_err(|e| {
Status::unavailable(format!("Output content could not be validated: {e}"))
})?;

if cache_content_valid {
Ok(response)
} else {
Err(Status::not_found(""))
}
})
.await;
.map_err(|e| format!("Output roots could not be loaded: {e}"))?;

let cache_content_valid = check_cache_content(&response, &store, cache_content_behavior)
.await
.map_err(|e| format!("Output content could not be validated: {e}"))?;

if cache_content_valid {
Ok(Some(response))
} else {
Ok(None)
}
})
.await;

workunit.record_observation(
ObservationMetric::RemoteCacheGetActionResultTimeMicros,
start.elapsed().as_micros() as u64,
);

match response {
Ok(response) => {
workunit.increment_counter(Metric::RemoteCacheRequestsCached, 1);
Ok(Some(response))
}
Err(status) => match status.code() {
Code::NotFound => {
workunit.increment_counter(Metric::RemoteCacheRequestsUncached, 1);
Ok(None)
}
_ => {
workunit.increment_counter(Metric::RemoteCacheReadErrors, 1);
// TODO: Ensure that we're catching missing digests.
Err(status_to_str(status).into())
}
},
}
let counter = match response {
Ok(Some(_)) => Metric::RemoteCacheRequestsCached,
Ok(None) => Metric::RemoteCacheRequestsUncached,
// TODO: Ensure that we're catching missing digests.
Err(_) => Metric::RemoteCacheReadErrors,
};
workunit.increment_counter(counter, 1);

response.map_err(ProcessError::from)
}
)
.await