diff --git a/Cargo.lock b/Cargo.lock index 808c553fb..c0c945860 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1924,7 +1924,6 @@ version = "0.4.0" dependencies = [ "async-lock", "async-trait", - "bitflags 2.5.0", "blake3", "futures", "hashbrown 0.14.5", @@ -2040,6 +2039,7 @@ version = "0.4.0" dependencies = [ "async-lock", "async-trait", + "bitflags 2.5.0", "blake3", "bytes", "console-subscriber", diff --git a/nativelink-scheduler/Cargo.toml b/nativelink-scheduler/Cargo.toml index 13890d628..9c766f300 100644 --- a/nativelink-scheduler/Cargo.toml +++ b/nativelink-scheduler/Cargo.toml @@ -28,7 +28,6 @@ tokio = { version = "1.37.0", features = ["sync", "rt", "parking_lot"] } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.11.0", features = ["gzip", "tls"] } tracing = "0.1.40" -bitflags = "2.5.0" redis = { version = "0.25.2", features = ["aio", "tokio", "json"] } serde = "1.0.203" redis-macros = "0.3.0" diff --git a/nativelink-scheduler/src/scheduler_state/workers.rs b/nativelink-scheduler/src/api_worker_scheduler.rs similarity index 97% rename from nativelink-scheduler/src/scheduler_state/workers.rs rename to nativelink-scheduler/src/api_worker_scheduler.rs index c98a8cb22..1239e42df 100644 --- a/nativelink-scheduler/src/scheduler_state/workers.rs +++ b/nativelink-scheduler/src/api_worker_scheduler.rs @@ -20,12 +20,12 @@ use nativelink_config::schedulers::WorkerAllocationStrategy; use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt}; use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId, WorkerId}; use nativelink_util::metrics_utils::Registry; +use nativelink_util::operation_state_manager::WorkerStateManager; use nativelink_util::platform_properties::PlatformProperties; use tokio::sync::Notify; use tonic::async_trait; use tracing::{event, Level}; -use crate::operation_state_manager::WorkerStateManager; use crate::platform_property_manager::PlatformPropertyManager; use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate}; use crate::worker_scheduler::WorkerScheduler; @@ -45,7 +45,7 @@ struct ApiWorkerSchedulerImpl { impl ApiWorkerSchedulerImpl { /// Refreshes the lifetime of the worker with the given timestamp. - pub(crate) fn refresh_lifetime( + fn refresh_lifetime( &mut self, worker_id: &WorkerId, timestamp: WorkerTimestamp, @@ -68,7 +68,7 @@ impl ApiWorkerSchedulerImpl { /// Adds a worker to the pool. /// Note: This function will not do any task matching. - pub(crate) fn add_worker(&mut self, worker: Worker) -> Result<(), Error> { + fn add_worker(&mut self, worker: Worker) -> Result<(), Error> { let worker_id = worker.id; self.workers.put(worker_id, worker); @@ -94,14 +94,14 @@ impl ApiWorkerSchedulerImpl { /// Removes worker from pool. /// Note: The caller is responsible for any rescheduling of any tasks that might be /// running. - pub(crate) fn remove_worker(&mut self, worker_id: &WorkerId) -> Option { + fn remove_worker(&mut self, worker_id: &WorkerId) -> Option { let result = self.workers.pop(worker_id); self.worker_change_notify.notify_one(); result } /// Sets if the worker is draining or not. - pub async fn set_drain_worker( + async fn set_drain_worker( &mut self, worker_id: &WorkerId, is_draining: bool, @@ -134,7 +134,7 @@ impl ApiWorkerSchedulerImpl { workers_iter.map(|(_, w)| &w.id).copied() } - pub async fn update_action( + async fn update_action( &mut self, worker_id: &WorkerId, operation_id: &OperationId, @@ -245,7 +245,7 @@ impl ApiWorkerSchedulerImpl { } /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it. - pub async fn immediate_evict_worker( + async fn immediate_evict_worker( &mut self, worker_id: &WorkerId, err: Error, diff --git a/nativelink-scheduler/src/lib.rs b/nativelink-scheduler/src/lib.rs index 27bc8f8a9..c60675621 100644 --- a/nativelink-scheduler/src/lib.rs +++ b/nativelink-scheduler/src/lib.rs @@ -13,16 +13,16 @@ // limitations under the License. pub mod action_scheduler; +pub mod api_worker_scheduler; pub mod cache_lookup_scheduler; pub mod default_action_listener; pub mod default_scheduler_factory; pub mod grpc_scheduler; -pub mod operation_state_manager; +pub mod memory_scheduler_state; pub mod platform_property_manager; pub mod property_modifier_scheduler; pub mod redis_action_stage; pub mod redis_operation_state; -pub mod scheduler_state; pub mod simple_scheduler; pub mod worker; pub mod worker_scheduler; diff --git a/nativelink-scheduler/src/scheduler_state/awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs similarity index 98% rename from nativelink-scheduler/src/scheduler_state/awaited_action.rs rename to nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs index 47f0bcfaf..33de3f2ba 100644 --- a/nativelink-scheduler/src/scheduler_state/awaited_action.rs +++ b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs @@ -238,8 +238,7 @@ impl AwaitedAction { /// Sets the current state of the action and notifies subscribers. /// Returns true if the state was set, false if there are no subscribers. #[must_use] - // TODO(allada) IMPORTANT: This function should only be visible to ActionActionDb. - pub(crate) fn set_current_state(&self, state: Arc) -> bool { + pub fn set_current_state(&self, state: Arc) -> bool { self.update_worker_timestamp(); // Note: Use `send_replace()`. Using `send()` will not change the value if // there are no subscribers. diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs new file mode 100644 index 000000000..fb305d5c3 --- /dev/null +++ b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs @@ -0,0 +1,408 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; +use std::time::SystemTime; + +use nativelink_config::stores::EvictionPolicy; +use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, +}; +use nativelink_util::evicting_map::EvictingMap; +use tokio::sync::{mpsc, watch}; +use tracing::{event, Level}; + +use super::client_awaited_action::ClientAwaitedAction; +use super::{AwaitedAction, SortedAwaitedAction}; + +#[derive(Default)] +struct SortedAwaitedActions { + unknown: BTreeSet, + cache_check: BTreeSet, + queued: BTreeSet, + executing: BTreeSet, + completed: BTreeSet, + completed_from_cache: BTreeSet, +} + +/// The database for storing the state of all actions. +/// IMPORTANT: Any time an item is removed from +/// [`AwaitedActionDb::client_operation_to_awaited_action`], it must +/// also remove the entries from all the other maps. +pub struct AwaitedActionDb { + /// A lookup table to lookup the state of an action by its client operation id. + client_operation_to_awaited_action: + EvictingMap, SystemTime>, + + /// A lookup table to lookup the state of an action by its worker operation id. + operation_id_to_awaited_action: HashMap>, + + /// A lookup table to lookup the state of an action by its unique qualifier. + action_info_hash_key_to_awaited_action: HashMap>, + + /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting + /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. + /// + /// See [`AwaitedActionSortKey`] for more information on the ordering. + sorted_action_info_hash_keys: SortedAwaitedActions, +} + +#[allow(clippy::mutable_key_type)] +impl AwaitedActionDb { + pub fn new(eviction_config: &EvictionPolicy) -> Self { + Self { + client_operation_to_awaited_action: EvictingMap::new( + eviction_config, + SystemTime::now(), + ), + operation_id_to_awaited_action: HashMap::new(), + action_info_hash_key_to_awaited_action: HashMap::new(), + sorted_action_info_hash_keys: SortedAwaitedActions::default(), + } + } + + /// Refreshes/Updates the time to live of the [`ClientOperationId`] in + /// the [`EvictingMap`] by touching the key. + pub async fn refresh_client_operation_id( + &self, + client_operation_id: &ClientOperationId, + ) -> bool { + self.client_operation_to_awaited_action + .size_for_key(client_operation_id) + .await + .is_some() + } + + pub async fn get_by_client_operation_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Option> { + self.client_operation_to_awaited_action + .get(client_operation_id) + .await + } + + /// When a client operation is dropped, we need to remove it from the + /// other maps and update the listening clients count on the [`AwaitedAction`]. + pub fn on_client_operations_drop( + &mut self, + awaited_actions: impl IntoIterator>, + ) { + for awaited_action in awaited_actions.into_iter() { + if awaited_action.get_listening_clients() != 0 { + // We still have other clients listening to this action. + continue; + } + + let operation_id = awaited_action.get_operation_id(); + + // Cleanup operation_id_to_awaited_action. + if self + .operation_id_to_awaited_action + .remove(&operation_id) + .is_none() + { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + "operation_id_to_awaited_action and client_operation_to_awaited_action are out of sync", + ); + } + + // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. + let action_info = awaited_action.get_action_info(); + match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = self + .action_info_hash_key_to_awaited_action + .remove(action_key); + if maybe_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // This Operation should not be in the hash_key map. + } + } + + // Cleanup sorted_awaited_action. + let sort_info = awaited_action.get_sort_info(); + let sort_key = sort_info.get_previous_sort_key(); + let sort_map_for_state = + self.get_sort_map_for_state(&awaited_action.get_current_state().stage); + drop(sort_info); + let maybe_sorted_awaited_action = sort_map_for_state.take(&SortedAwaitedAction { + sort_key, + awaited_action, + }); + if maybe_sorted_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?sort_key, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", + ); + } + } + } + + pub fn get_all_awaited_actions(&self) -> impl Iterator> { + self.operation_id_to_awaited_action.values() + } + + pub fn get_by_operation_id(&self, operation_id: &OperationId) -> Option<&Arc> { + self.operation_id_to_awaited_action.get(operation_id) + } + + pub fn get_cache_check_actions(&self) -> &BTreeSet { + &self.sorted_action_info_hash_keys.cache_check + } + + pub fn get_queued_actions(&self) -> &BTreeSet { + &self.sorted_action_info_hash_keys.queued + } + + pub fn get_executing_actions(&self) -> &BTreeSet { + &self.sorted_action_info_hash_keys.executing + } + + pub fn get_completed_actions(&self) -> &BTreeSet { + &self.sorted_action_info_hash_keys.completed + } + + fn get_sort_map_for_state( + &mut self, + state: &ActionStage, + ) -> &mut BTreeSet { + match state { + ActionStage::Unknown => &mut self.sorted_action_info_hash_keys.unknown, + ActionStage::CacheCheck => &mut self.sorted_action_info_hash_keys.cache_check, + ActionStage::Queued => &mut self.sorted_action_info_hash_keys.queued, + ActionStage::Executing => &mut self.sorted_action_info_hash_keys.executing, + ActionStage::Completed(_) => &mut self.sorted_action_info_hash_keys.completed, + ActionStage::CompletedFromCache(_) => { + &mut self.sorted_action_info_hash_keys.completed_from_cache + } + } + } + + fn insert_sort_map_for_stage( + &mut self, + stage: &ActionStage, + sorted_awaited_action: SortedAwaitedAction, + ) { + let newly_inserted = match stage { + ActionStage::Unknown => self + .sorted_action_info_hash_keys + .unknown + .insert(sorted_awaited_action), + ActionStage::CacheCheck => self + .sorted_action_info_hash_keys + .cache_check + .insert(sorted_awaited_action), + ActionStage::Queued => self + .sorted_action_info_hash_keys + .queued + .insert(sorted_awaited_action), + ActionStage::Executing => self + .sorted_action_info_hash_keys + .executing + .insert(sorted_awaited_action), + ActionStage::Completed(_) => self + .sorted_action_info_hash_keys + .completed + .insert(sorted_awaited_action), + ActionStage::CompletedFromCache(_) => self + .sorted_action_info_hash_keys + .completed_from_cache + .insert(sorted_awaited_action), + }; + if !newly_inserted { + event!( + Level::ERROR, + "Tried to insert an action that was already in the sorted map. This should never happen.", + ); + } + } + + /// Sets the state of the action to the provided `action_state` and notifies all listeners. + /// If the action has no more listeners, returns `false`. + pub fn set_action_state( + &mut self, + awaited_action: Arc, + new_action_state: Arc, + ) -> bool { + // We need to first get a lock on the awaited action to ensure + // another operation doesn't update it while we are looking up + // the sorted key. + let sort_info = awaited_action.get_sort_info(); + let old_state = awaited_action.get_current_state(); + + let has_listeners = awaited_action.set_current_state(new_action_state.clone()); + + if !old_state.stage.is_same_stage(&new_action_state.stage) { + let sort_key = sort_info.get_previous_sort_key(); + let btree = self.get_sort_map_for_state(&old_state.stage); + drop(sort_info); + let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { + sort_key, + awaited_action, + }); + + let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { + event!( + Level::ERROR, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", + ); + return false; + }; + + self.insert_sort_map_for_stage(&new_action_state.stage, sorted_awaited_action); + } + has_listeners + } + + pub async fn subscribe_or_add_action( + &mut self, + client_operation_id: ClientOperationId, + action_info: Arc, + client_operation_drop_tx: &mpsc::UnboundedSender>, + ) -> watch::Receiver> { + // Check to see if the action is already known and subscribe if it is. + let subscription_result = self + .try_subscribe( + &client_operation_id, + &action_info.unique_qualifier, + action_info.priority, + client_operation_drop_tx, + ) + .await; + let action_info = match subscription_result { + Ok(subscription) => return subscription, + // TODO!(we should not ignore the error here.) + Err(_) => action_info, + }; + + let maybe_unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), + ActionUniqueQualifier::Uncachable(_unique_key) => None, + }; + let (awaited_action, sort_key, subscription) = + AwaitedAction::new_with_subscription(action_info); + let awaited_action = Arc::new(awaited_action); + self.client_operation_to_awaited_action + .insert( + client_operation_id, + Arc::new(ClientAwaitedAction::new( + awaited_action.clone(), + client_operation_drop_tx.clone(), + )), + ) + .await; + // Note: We only put items in the map that are cachable. + if let Some(unique_key) = maybe_unique_key { + self.action_info_hash_key_to_awaited_action + .insert(unique_key, awaited_action.clone()); + } + self.operation_id_to_awaited_action + .insert(awaited_action.get_operation_id(), awaited_action.clone()); + + self.insert_sort_map_for_stage( + &awaited_action.get_current_state().stage, + SortedAwaitedAction { + sort_key, + awaited_action, + }, + ); + subscription + } + + async fn try_subscribe( + &mut self, + client_operation_id: &ClientOperationId, + unique_qualifier: &ActionUniqueQualifier, + priority: i32, + client_operation_drop_tx: &mpsc::UnboundedSender>, + ) -> Result>, Error> { + let unique_key = match unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(_unique_key) => { + return Err(make_err!( + Code::InvalidArgument, + "Cannot subscribe to an existing item when skip_cache_lookup is true." + )); + } + }; + + let awaited_action = self + .action_info_hash_key_to_awaited_action + .get(unique_key) + .ok_or(make_input_err!( + "Could not find existing action with name: {unique_qualifier}" + )) + .err_tip(|| "In state_manager::try_subscribe")?; + + // Do not subscribe if the action is already completed, + // this is the responsibility of the CacheLookupScheduler. + if awaited_action.get_current_state().stage.is_finished() { + return Err(make_input_err!( + "Subscribing an item that is already completed should be handled by CacheLookupScheduler." + )); + } + let awaited_action = awaited_action.clone(); + if let Some(sort_info_lock) = awaited_action.upgrade_priority(priority) { + let state = awaited_action.get_current_state(); + let maybe_sorted_awaited_action = + self.get_sort_map_for_state(&state.stage) + .take(&SortedAwaitedAction { + sort_key: sort_info_lock.get_previous_sort_key(), + awaited_action: awaited_action.clone(), + }); + let Some(mut sorted_awaited_action) = maybe_sorted_awaited_action else { + // TODO!(Either use event on all of the above error here, but both is overkill). + let err = make_err!( + Code::Internal, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync"); + event!(Level::ERROR, ?unique_qualifier, ?awaited_action, "{err:?}",); + return Err(err); + }; + sorted_awaited_action.sort_key = sort_info_lock.get_new_sort_key(); + self.insert_sort_map_for_stage(&state.stage, sorted_awaited_action); + } + + let subscription = awaited_action.subscribe(); + + self.client_operation_to_awaited_action + .insert( + client_operation_id.clone(), + Arc::new(ClientAwaitedAction::new( + awaited_action, + client_operation_drop_tx.clone(), + )), + ) + .await; + + Ok(subscription) + } +} diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs new file mode 100644 index 000000000..6c75419b2 --- /dev/null +++ b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs @@ -0,0 +1,82 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; + +use nativelink_util::evicting_map::LenEntry; +use tokio::sync::mpsc; + +use super::AwaitedAction; + +/// Represents a client that is currently listening to an action. +/// When the client is dropped, it will send the [`AwaitedAction`] to the +/// `client_operation_drop_tx` if there are other cleanups needed. +#[derive(Debug)] +pub(crate) struct ClientAwaitedAction { + /// The awaited action that the client is listening to. + // Note: This is an Option because it is taken when the + // ClientAwaitedAction is dropped, but will never actually be + // None except during the drop. + awaited_action: Option>, + + /// The sender to notify of this struct being dropped. + client_operation_drop_tx: mpsc::UnboundedSender>, +} + +impl ClientAwaitedAction { + pub fn new( + awaited_action: Arc, + client_operation_drop_tx: mpsc::UnboundedSender>, + ) -> Self { + awaited_action.inc_listening_clients(); + Self { + awaited_action: Some(awaited_action), + client_operation_drop_tx, + } + } + + /// Returns the awaited action that the client is listening to. + pub fn awaited_action(&self) -> &Arc { + self.awaited_action + .as_ref() + .expect("AwaitedAction should be present") + } +} + +impl Drop for ClientAwaitedAction { + fn drop(&mut self) { + let awaited_action = self + .awaited_action + .take() + .expect("AwaitedAction should be present"); + awaited_action.dec_listening_clients(); + // If we failed to send it means noone is listening. + let _ = self.client_operation_drop_tx.send(awaited_action); + } +} + +/// Trait to be able to use the [`EvictingMap`] with [`ClientAwaitedAction`]. +/// Note: We only use [`EvictingMap`] for a time based evictiong, which is +/// why the implementation has fixed default values in it. +impl LenEntry for ClientAwaitedAction { + #[inline] + fn len(&self) -> usize { + 0 + } + + #[inline] + fn is_empty(&self) -> bool { + true + } +} diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs new file mode 100644 index 000000000..0c3da873f --- /dev/null +++ b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs @@ -0,0 +1,23 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub(crate) use awaited_action::AwaitedAction; +pub(crate) use awaited_action_db::AwaitedActionDb; +pub(crate) use sorted_awaited_action::SortedAwaitedAction; + +mod awaited_action; +#[allow(clippy::module_inception)] +mod awaited_action_db; +mod client_awaited_action; +mod sorted_awaited_action; diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs new file mode 100644 index 000000000..18a7d607a --- /dev/null +++ b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs @@ -0,0 +1,46 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp; +use std::sync::Arc; + +use super::awaited_action::{AwaitedAction, AwaitedActionSortKey}; + +#[derive(Debug, Clone)] +pub struct SortedAwaitedAction { + pub sort_key: AwaitedActionSortKey, + pub awaited_action: Arc, +} + +impl PartialEq for SortedAwaitedAction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.awaited_action, &other.awaited_action) && self.sort_key == other.sort_key + } +} + +impl Eq for SortedAwaitedAction {} + +impl PartialOrd for SortedAwaitedAction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SortedAwaitedAction { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.sort_key.cmp(&other.sort_key).then_with(|| { + Arc::as_ptr(&self.awaited_action).cmp(&Arc::as_ptr(&other.awaited_action)) + }) + } +} diff --git a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs b/nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs similarity index 96% rename from nativelink-scheduler/src/scheduler_state/client_action_state_result.rs rename to nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs index 43c479187..b0469c3f7 100644 --- a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs +++ b/nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs @@ -18,11 +18,10 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::Error; use nativelink_util::action_messages::{ActionInfo, ActionState}; +use nativelink_util::operation_state_manager::ActionStateResult; use nativelink_util::task::JoinHandleDropGuard; use tokio::sync::watch::Receiver; -use crate::operation_state_manager::ActionStateResult; - pub(crate) struct ClientActionStateResult { /// The action info for the action. action_info: Arc, @@ -38,7 +37,7 @@ pub(crate) struct ClientActionStateResult { } impl ClientActionStateResult { - pub(crate) fn new( + pub fn new( action_info: Arc, mut rx: Receiver>, maybe_keepalive_spawn: Option>, diff --git a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs b/nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs similarity index 86% rename from nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs rename to nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs index de95ab233..4b5948169 100644 --- a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs +++ b/nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs @@ -18,16 +18,16 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::Error; use nativelink_util::action_messages::{ActionInfo, ActionState}; +use nativelink_util::operation_state_manager::ActionStateResult; use tokio::sync::watch; -use super::awaited_action::AwaitedAction; -use crate::operation_state_manager::ActionStateResult; +use super::awaited_action_db::AwaitedAction; -pub struct MatchingEngineActionStateResult { +pub(crate) struct MatchingEngineActionStateResult { awaited_action: Arc, } impl MatchingEngineActionStateResult { - pub(crate) fn new(awaited_action: Arc) -> Self { + pub fn new(awaited_action: Arc) -> Self { Self { awaited_action } } } diff --git a/nativelink-scheduler/src/scheduler_state/state_manager.rs b/nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs similarity index 50% rename from nativelink-scheduler/src/scheduler_state/state_manager.rs rename to nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs index 0edd7a312..0cfdb4b15 100644 --- a/nativelink-scheduler/src/scheduler_state/state_manager.rs +++ b/nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs @@ -12,712 +12,36 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp; use std::collections::{BTreeSet, VecDeque}; use std::ops::Bound; use std::sync::{Arc, Weak}; -use std::time::{Duration, SystemTime}; +use std::time::Duration; use async_lock::Mutex; use async_trait::async_trait; use futures::stream::{self, unfold}; -use hashbrown::HashMap; use nativelink_config::stores::EvictionPolicy; -use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; +use nativelink_error::{make_err, Code, Error}; use nativelink_util::action_messages::{ - ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, - ClientOperationId, ExecutionMetadata, OperationId, WorkerId, + ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueQualifier, ClientOperationId, + ExecutionMetadata, OperationId, WorkerId, }; -use nativelink_util::evicting_map::{EvictingMap, LenEntry}; -use nativelink_util::spawn; -use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::{mpsc, watch, Notify}; -use tracing::{event, Level}; - -use super::awaited_action::AwaitedActionSortKey; -use crate::operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, - OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager, -}; -use crate::scheduler_state::awaited_action::AwaitedAction; -use crate::scheduler_state::client_action_state_result::ClientActionStateResult; -use crate::scheduler_state::matching_engine_action_state_result::MatchingEngineActionStateResult; - -/// How often the owning database will have the AwaitedAction touched -/// to keep it from being evicted. -const KEEPALIVE_DURATION: Duration = Duration::from_secs(10); - -#[derive(Debug, Clone)] -struct SortedAwaitedAction { - sort_key: AwaitedActionSortKey, - awaited_action: Arc, -} - -impl PartialEq for SortedAwaitedAction { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.awaited_action, &other.awaited_action) && self.sort_key == other.sort_key - } -} - -impl Eq for SortedAwaitedAction {} - -impl PartialOrd for SortedAwaitedAction { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for SortedAwaitedAction { - fn cmp(&self, other: &Self) -> cmp::Ordering { - self.sort_key.cmp(&other.sort_key).then_with(|| { - Arc::as_ptr(&self.awaited_action).cmp(&Arc::as_ptr(&other.awaited_action)) - }) - } -} - -#[derive(Default)] -struct SortedAwaitedActions { - unknown: BTreeSet, - cache_check: BTreeSet, - queued: BTreeSet, - executing: BTreeSet, - completed: BTreeSet, - completed_from_cache: BTreeSet, -} - -/// Represents a client that is currently listening to an action. -/// When the client is dropped, it will send the [`AwaitedAction`] to the -/// `client_operation_drop_tx` if there are other cleanups needed. -#[derive(Debug)] -struct ClientAwaitedAction { - /// The awaited action that the client is listening to. - // Note: This is an Option because it is taken when the - // ClientAwaitedAction is dropped, but will never actually be - // None except during the drop. - awaited_action: Option>, - - /// The sender to notify of this struct being dropped. - client_operation_drop_tx: mpsc::UnboundedSender>, -} - -impl ClientAwaitedAction { - fn new( - awaited_action: Arc, - client_operation_drop_tx: mpsc::UnboundedSender>, - ) -> Self { - awaited_action.inc_listening_clients(); - Self { - awaited_action: Some(awaited_action), - client_operation_drop_tx, - } - } - - /// Returns the awaited action that the client is listening to. - fn awaited_action(&self) -> &Arc { - self.awaited_action - .as_ref() - .expect("AwaitedAction should be present") - } -} - -impl Drop for ClientAwaitedAction { - fn drop(&mut self) { - let awaited_action = self - .awaited_action - .take() - .expect("AwaitedAction should be present"); - awaited_action.dec_listening_clients(); - // If we failed to send it means noone is listening. - let _ = self.client_operation_drop_tx.send(awaited_action); - } -} - -/// Trait to be able to use the [`EvictingMap`] with [`ClientAwaitedAction`]. -/// Note: We only use [`EvictingMap`] for a time based evictiong, which is -/// why the implementation has fixed default values in it. -impl LenEntry for ClientAwaitedAction { - #[inline] - fn len(&self) -> usize { - 0 - } - - #[inline] - fn is_empty(&self) -> bool { - true - } -} - -/// The database for storing the state of all actions. -/// IMPORTANT: Any time an item is removed from -/// [`AwaitedActionDb::client_operation_to_awaited_action`], it must -/// also remove the entries from all the other maps. -pub struct AwaitedActionDb { - /// A lookup table to lookup the state of an action by its client operation id. - client_operation_to_awaited_action: - EvictingMap, SystemTime>, - - /// A lookup table to lookup the state of an action by its worker operation id. - operation_id_to_awaited_action: HashMap>, - - /// A lookup table to lookup the state of an action by its unique qualifier. - action_info_hash_key_to_awaited_action: HashMap>, - - /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting - /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. - /// - /// See [`AwaitedActionSortKey`] for more information on the ordering. - sorted_action_info_hash_keys: SortedAwaitedActions, -} - -#[allow(clippy::mutable_key_type)] -impl AwaitedActionDb { - /// Refreshes/Updates the time to live of the [`ClientOperationId`] in - /// the [`EvictingMap`] by touching the key. - async fn refresh_client_operation_id(&self, client_operation_id: &ClientOperationId) -> bool { - self.client_operation_to_awaited_action - .size_for_key(client_operation_id) - .await - .is_some() - } - - async fn get_by_client_operation_id( - &self, - client_operation_id: &ClientOperationId, - ) -> Option> { - self.client_operation_to_awaited_action - .get(client_operation_id) - .await - } - - /// When a client operation is dropped, we need to remove it from the - /// other maps and update the listening clients count on the [`AwaitedAction`]. - fn on_client_operations_drop( - &mut self, - awaited_actions: impl IntoIterator>, - ) { - for awaited_action in awaited_actions.into_iter() { - if awaited_action.get_listening_clients() != 0 { - // We still have other clients listening to this action. - continue; - } - - let operation_id = awaited_action.get_operation_id(); - - // Cleanup operation_id_to_awaited_action. - if self - .operation_id_to_awaited_action - .remove(&operation_id) - .is_none() - { - event!( - Level::ERROR, - ?operation_id, - ?awaited_action, - "operation_id_to_awaited_action and client_operation_to_awaited_action are out of sync", - ); - } - - // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. - let action_info = awaited_action.get_action_info(); - match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(action_key) => { - let maybe_awaited_action = self - .action_info_hash_key_to_awaited_action - .remove(action_key); - if maybe_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?awaited_action, - ?action_key, - "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", - ); - } - } - ActionUniqueQualifier::Uncachable(_action_key) => { - // This Operation should not be in the hash_key map. - } - } - - // Cleanup sorted_awaited_action. - let sort_info = awaited_action.get_sort_info(); - let sort_key = sort_info.get_previous_sort_key(); - let sort_map_for_state = - self.get_sort_map_for_state(&awaited_action.get_current_state().stage); - drop(sort_info); - let maybe_sorted_awaited_action = sort_map_for_state.take(&SortedAwaitedAction { - sort_key, - awaited_action, - }); - if maybe_sorted_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?sort_key, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", - ); - } - } - } - - fn get_all_awaited_actions(&self) -> impl Iterator> { - self.operation_id_to_awaited_action.values() - } - - fn get_by_operation_id(&self, operation_id: &OperationId) -> Option<&Arc> { - self.operation_id_to_awaited_action.get(operation_id) - } - - fn get_cache_check_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.cache_check - } - - fn get_queued_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.queued - } - - fn get_executing_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.executing - } - - fn get_completed_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.completed - } - - fn get_sort_map_for_state( - &mut self, - state: &ActionStage, - ) -> &mut BTreeSet { - match state { - ActionStage::Unknown => &mut self.sorted_action_info_hash_keys.unknown, - ActionStage::CacheCheck => &mut self.sorted_action_info_hash_keys.cache_check, - ActionStage::Queued => &mut self.sorted_action_info_hash_keys.queued, - ActionStage::Executing => &mut self.sorted_action_info_hash_keys.executing, - ActionStage::Completed(_) => &mut self.sorted_action_info_hash_keys.completed, - ActionStage::CompletedFromCache(_) => { - &mut self.sorted_action_info_hash_keys.completed_from_cache - } - } - } - - fn insert_sort_map_for_stage( - &mut self, - stage: &ActionStage, - sorted_awaited_action: SortedAwaitedAction, - ) { - let newly_inserted = match stage { - ActionStage::Unknown => self - .sorted_action_info_hash_keys - .unknown - .insert(sorted_awaited_action), - ActionStage::CacheCheck => self - .sorted_action_info_hash_keys - .cache_check - .insert(sorted_awaited_action), - ActionStage::Queued => self - .sorted_action_info_hash_keys - .queued - .insert(sorted_awaited_action), - ActionStage::Executing => self - .sorted_action_info_hash_keys - .executing - .insert(sorted_awaited_action), - ActionStage::Completed(_) => self - .sorted_action_info_hash_keys - .completed - .insert(sorted_awaited_action), - ActionStage::CompletedFromCache(_) => self - .sorted_action_info_hash_keys - .completed_from_cache - .insert(sorted_awaited_action), - }; - if !newly_inserted { - event!( - Level::ERROR, - "Tried to insert an action that was already in the sorted map. This should never happen.", - ); - } - } - - /// Sets the state of the action to the provided `action_state` and notifies all listeners. - /// If the action has no more listeners, returns `false`. - fn set_action_state( - &mut self, - awaited_action: Arc, - new_action_state: Arc, - ) -> bool { - // We need to first get a lock on the awaited action to ensure - // another operation doesn't update it while we are looking up - // the sorted key. - let sort_info = awaited_action.get_sort_info(); - let old_state = awaited_action.get_current_state(); - - let has_listeners = awaited_action.set_current_state(new_action_state.clone()); - - if !old_state.stage.is_same_stage(&new_action_state.stage) { - let sort_key = sort_info.get_previous_sort_key(); - let btree = self.get_sort_map_for_state(&old_state.stage); - drop(sort_info); - let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { - sort_key, - awaited_action, - }); - - let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { - event!( - Level::ERROR, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", - ); - return false; - }; - - self.insert_sort_map_for_stage(&new_action_state.stage, sorted_awaited_action); - } - has_listeners - } - - async fn subscribe_or_add_action( - &mut self, - client_operation_id: ClientOperationId, - action_info: Arc, - client_operation_drop_tx: &mpsc::UnboundedSender>, - ) -> watch::Receiver> { - // Check to see if the action is already known and subscribe if it is. - let subscription_result = self - .try_subscribe( - &client_operation_id, - &action_info.unique_qualifier, - action_info.priority, - client_operation_drop_tx, - ) - .await; - let action_info = match subscription_result { - Ok(subscription) => return subscription, - // TODO!(we should not ignore the error here.) - Err(_) => action_info, - }; - - let maybe_unique_key = match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), - ActionUniqueQualifier::Uncachable(_unique_key) => None, - }; - let (awaited_action, sort_key, subscription) = - AwaitedAction::new_with_subscription(action_info); - let awaited_action = Arc::new(awaited_action); - self.client_operation_to_awaited_action - .insert( - client_operation_id, - Arc::new(ClientAwaitedAction::new( - awaited_action.clone(), - client_operation_drop_tx.clone(), - )), - ) - .await; - // Note: We only put items in the map that are cachable. - if let Some(unique_key) = maybe_unique_key { - self.action_info_hash_key_to_awaited_action - .insert(unique_key, awaited_action.clone()); - } - self.operation_id_to_awaited_action - .insert(awaited_action.get_operation_id(), awaited_action.clone()); - - self.insert_sort_map_for_stage( - &awaited_action.get_current_state().stage, - SortedAwaitedAction { - sort_key, - awaited_action, - }, - ); - subscription - } - - async fn try_subscribe( - &mut self, - client_operation_id: &ClientOperationId, - unique_qualifier: &ActionUniqueQualifier, - priority: i32, - client_operation_drop_tx: &mpsc::UnboundedSender>, - ) -> Result>, Error> { - let unique_key = match unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => unique_key, - ActionUniqueQualifier::Uncachable(_unique_key) => { - return Err(make_err!( - Code::InvalidArgument, - "Cannot subscribe to an existing item when skip_cache_lookup is true." - )); - } - }; - - let awaited_action = self - .action_info_hash_key_to_awaited_action - .get(unique_key) - .ok_or(make_input_err!( - "Could not find existing action with name: {unique_qualifier}" - )) - .err_tip(|| "In state_manager::try_subscribe")?; - - // Do not subscribe if the action is already completed, - // this is the responsibility of the CacheLookupScheduler. - if awaited_action.get_current_state().stage.is_finished() { - return Err(make_input_err!( - "Subscribing an item that is already completed should be handled by CacheLookupScheduler." - )); - } - let awaited_action = awaited_action.clone(); - if let Some(sort_info_lock) = awaited_action.upgrade_priority(priority) { - let state = awaited_action.get_current_state(); - let maybe_sorted_awaited_action = - self.get_sort_map_for_state(&state.stage) - .take(&SortedAwaitedAction { - sort_key: sort_info_lock.get_previous_sort_key(), - awaited_action: awaited_action.clone(), - }); - let Some(mut sorted_awaited_action) = maybe_sorted_awaited_action else { - // TODO!(Either use event on all of the above error here, but both is overkill). - let err = make_err!( - Code::Internal, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync"); - event!(Level::ERROR, ?unique_qualifier, ?awaited_action, "{err:?}",); - return Err(err); - }; - sorted_awaited_action.sort_key = sort_info_lock.get_new_sort_key(); - self.insert_sort_map_for_stage(&state.stage, sorted_awaited_action); - } - - let subscription = awaited_action.subscribe(); - - self.client_operation_to_awaited_action - .insert( - client_operation_id.clone(), - Arc::new(ClientAwaitedAction::new( - awaited_action, - client_operation_drop_tx.clone(), - )), - ) - .await; - - Ok(subscription) - } -} - -#[repr(transparent)] -pub struct StateManager { - inner: Arc>, -} - -impl StateManager { - pub fn new( - config: &EvictionPolicy, - tasks_change_notify: Arc, - max_job_retries: usize, - ) -> Self { - Self { - inner: Arc::new_cyclic(move |weak_self| -> Mutex { - let weak_inner = weak_self.clone(); - let (client_operation_drop_tx, mut client_operation_drop_rx) = - mpsc::unbounded_channel(); - let client_operation_cleanup_spawn = - spawn!("state_manager_client_drop_rx", async move { - /// Number of events to pull from the stream at a time. - const MAX_DROP_HANDLES_PER_CYCLE: usize = 1024; - let mut dropped_client_ids = Vec::with_capacity(MAX_DROP_HANDLES_PER_CYCLE); - loop { - dropped_client_ids.clear(); - client_operation_drop_rx - .recv_many(&mut dropped_client_ids, MAX_DROP_HANDLES_PER_CYCLE) - .await; - let Some(inner) = weak_inner.upgrade() else { - return; // Nothing to cleanup, our struct is dropped. - }; - let mut inner_mux = inner.lock().await; - inner_mux - .action_db - .on_client_operations_drop(dropped_client_ids.drain(..)); - } - }); - Mutex::new(StateManagerImpl { - action_db: AwaitedActionDb { - client_operation_to_awaited_action: EvictingMap::new( - config, - SystemTime::now(), - ), - operation_id_to_awaited_action: HashMap::new(), - action_info_hash_key_to_awaited_action: HashMap::new(), - sorted_action_info_hash_keys: SortedAwaitedActions::default(), - }, - tasks_change_notify, - max_job_retries, - client_operation_drop_tx, - _client_operation_cleanup_spawn: client_operation_cleanup_spawn, - }) - }), - } - } - - async fn inner_filter_operations( - &self, - filter: &OperationFilter, - to_action_state_result: F, - ) -> Result - where - F: Fn(Arc) -> Arc + Send + Sync + 'static, - { - fn get_tree_for_stage( - action_db: &AwaitedActionDb, - stage: OperationStageFlags, - ) -> Option<&BTreeSet> { - match stage { - OperationStageFlags::CacheCheck => Some(action_db.get_cache_check_actions()), - OperationStageFlags::Queued => Some(action_db.get_queued_actions()), - OperationStageFlags::Executing => Some(action_db.get_executing_actions()), - OperationStageFlags::Completed => Some(action_db.get_completed_actions()), - _ => None, - } - } - - let inner = self.inner.lock().await; - - if let Some(operation_id) = &filter.operation_id { - return Ok(inner - .action_db - .get_by_operation_id(operation_id) - .filter(|awaited_action| filter_check(awaited_action.as_ref(), filter)) - .cloned() - .map(|awaited_action| -> ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(awaited_action) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); - } - if let Some(client_operation_id) = &filter.client_operation_id { - return Ok(inner - .action_db - .get_by_client_operation_id(client_operation_id) - .await - .filter(|client_awaited_action| { - filter_check(client_awaited_action.awaited_action().as_ref(), filter) - }) - .map(|client_awaited_action| -> ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(client_awaited_action.awaited_action().clone()) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); - } - - if get_tree_for_stage(&inner.action_db, filter.stages).is_none() { - let mut all_items: Vec> = inner - .action_db - .get_all_awaited_actions() - .filter(|awaited_action| filter_check(awaited_action.as_ref(), filter)) - .cloned() - .collect(); - match filter.order_by_priority_direction { - Some(OrderDirection::Asc) => all_items.sort_unstable_by(|a, b| { - a.get_sort_info() - .get_new_sort_key() - .cmp(&b.get_sort_info().get_new_sort_key()) - }), - Some(OrderDirection::Desc) => all_items.sort_unstable_by(|a, b| { - b.get_sort_info() - .get_new_sort_key() - .cmp(&a.get_sort_info().get_new_sort_key()) - }), - None => {} - } - return Ok(Box::pin(stream::iter( - all_items.into_iter().map(to_action_state_result), - ))); - } - - drop(inner); - - struct State< - F: Fn(Arc) -> Arc + Send + Sync + 'static, - > { - inner: Arc>, - filter: OperationFilter, - buffer: VecDeque, - start_key: Bound, - to_action_state_result: F, - } - let state = State { - inner: self.inner.clone(), - filter: filter.clone(), - buffer: VecDeque::new(), - start_key: Bound::Unbounded, - to_action_state_result, - }; - - const STREAM_BUFF_SIZE: usize = 64; - - Ok(Box::pin(unfold(state, move |mut state| async move { - if let Some(sorted_awaited_action) = state.buffer.pop_front() { - if state.buffer.is_empty() { - state.start_key = Bound::Excluded(sorted_awaited_action.clone()); - } - return Some(( - (state.to_action_state_result)(sorted_awaited_action.awaited_action), - state, - )); - } - - let inner = state.inner.lock().await; - - #[allow(clippy::mutable_key_type)] - let btree = get_tree_for_stage(&inner.action_db, state.filter.stages) - .expect("get_tree_for_stage() should have already returned Some but in iteration it returned None"); - - let range = (state.start_key.as_ref(), Bound::Unbounded); - if state.filter.order_by_priority_direction == Some(OrderDirection::Asc) { - btree - .range(range) - .filter(|item| filter_check(item.awaited_action.as_ref(), &state.filter)) - .take(STREAM_BUFF_SIZE) - .for_each(|item| state.buffer.push_back(item.clone())); - } else { - btree - .range(range) - .rev() - .filter(|item| filter_check(item.awaited_action.as_ref(), &state.filter)) - .take(STREAM_BUFF_SIZE) - .for_each(|item| state.buffer.push_back(item.clone())); - } - drop(inner); - let sorted_awaited_action = state.buffer.pop_front()?; - if state.buffer.is_empty() { - state.start_key = Bound::Excluded(sorted_awaited_action.clone()); - } - Some(( - (state.to_action_state_result)(sorted_awaited_action.awaited_action), - state, - )) - }))) - } -} - -/// StateManager is responsible for maintaining the state of the scheduler. Scheduler state -/// includes the actions that are queued, active, and recently completed. It also includes the -/// workers that are available to execute actions based on allocation strategy. -pub(crate) struct StateManagerImpl { - /// Database for storing the state of all actions. - action_db: AwaitedActionDb, - - /// Notify task<->worker matching engine that work needs to be done. - tasks_change_notify: Arc, - - /// Maximum number of times a job can be retried. - max_job_retries: usize, +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, + OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager, +}; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; +use tokio::sync::{mpsc, watch, Notify}; +use tracing::{event, Level}; - /// Channel to notify when a client operation id is dropped. - client_operation_drop_tx: mpsc::UnboundedSender>, +use super::awaited_action_db::{AwaitedAction, AwaitedActionDb, SortedAwaitedAction}; +use super::client_action_state_result::ClientActionStateResult; +use super::matching_engine_action_state_result::MatchingEngineActionStateResult; - /// Task to cleanup client operation ids that are no longer being listened to. - // Note: This has a custom Drop function on it. It should stay alive only while - // the StateManager is alive. - _client_operation_cleanup_spawn: JoinHandleDropGuard<()>, -} +/// How often the owning database will have the AwaitedAction touched +/// to keep it from being evicted. +const KEEPALIVE_DURATION: Duration = Duration::from_secs(10); fn filter_check(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { // Note: The caller must filter `client_operation_id`. @@ -784,7 +108,57 @@ fn filter_check(awaited_action: &AwaitedAction, filter: &OperationFilter) -> boo true } -impl StateManagerImpl { +/// Utility struct to create a background task that keeps the client operation id alive. +fn make_client_keepalive_spawn( + client_operation_id: ClientOperationId, + inner_weak: Weak>, +) -> JoinHandleDropGuard<()> { + spawn!("client_action_state_result_keepalive", async move { + loop { + tokio::time::sleep(KEEPALIVE_DURATION).await; + let Some(inner) = inner_weak.upgrade() else { + return; // Nothing to do. + }; + let inner = inner.lock().await; + let refresh_success = inner + .action_db + .refresh_client_operation_id(&client_operation_id) + .await; + if !refresh_success { + event! { + Level::ERROR, + ?client_operation_id, + "Client operation id not found in MemorySchedulerStateManager::add_action keepalive" + }; + } + } + }) +} + +/// MemorySchedulerStateManager is responsible for maintaining the state of the scheduler. +/// Scheduler state includes the actions that are queued, active, and recently completed. +/// It also includes the workers that are available to execute actions based on allocation +/// strategy. +struct MemorySchedulerStateManagerImpl { + /// Database for storing the state of all actions. + action_db: AwaitedActionDb, + + /// Notify task<->worker matching engine that work needs to be done. + tasks_change_notify: Arc, + + /// Maximum number of times a job can be retried. + max_job_retries: usize, + + /// Channel to notify when a client operation id is dropped. + client_operation_drop_tx: mpsc::UnboundedSender>, + + /// Task to cleanup client operation ids that are no longer being listened to. + // Note: This has a custom Drop function on it. It should stay alive only while + // the MemorySchedulerStateManager is alive. + _client_operation_cleanup_spawn: JoinHandleDropGuard<()>, +} + +impl MemorySchedulerStateManagerImpl { fn inner_update_operation( &mut self, operation_id: &OperationId, @@ -797,7 +171,7 @@ impl StateManagerImpl { .ok_or_else(|| { make_err!( Code::Internal, - "Could not find action info StateManager::update_operation" + "Could not find action info MemorySchedulerStateManager::update_operation" ) })? .clone(); @@ -909,35 +283,197 @@ impl StateManagerImpl { } } -/// Utility struct to create a background task that keeps the client operation id alive. -fn make_client_keepalive_spawn( - client_operation_id: ClientOperationId, - inner_weak: Weak>, -) -> JoinHandleDropGuard<()> { - spawn!("client_action_state_result_keepalive", async move { - loop { - tokio::time::sleep(KEEPALIVE_DURATION).await; - let Some(inner) = inner_weak.upgrade() else { - return; // Nothing to do. - }; - let inner = inner.lock().await; - let refresh_success = inner +#[repr(transparent)] +pub struct MemorySchedulerStateManager { + inner: Arc>, +} + +impl MemorySchedulerStateManager { + pub fn new( + eviction_config: &EvictionPolicy, + tasks_change_notify: Arc, + max_job_retries: usize, + ) -> Self { + Self { + inner: Arc::new_cyclic(move |weak_self| -> Mutex { + let weak_inner = weak_self.clone(); + let (client_operation_drop_tx, mut client_operation_drop_rx) = + mpsc::unbounded_channel(); + let client_operation_cleanup_spawn = + spawn!("state_manager_client_drop_rx", async move { + /// Number of events to pull from the stream at a time. + const MAX_DROP_HANDLES_PER_CYCLE: usize = 1024; + let mut dropped_client_ids = Vec::with_capacity(MAX_DROP_HANDLES_PER_CYCLE); + loop { + dropped_client_ids.clear(); + client_operation_drop_rx + .recv_many(&mut dropped_client_ids, MAX_DROP_HANDLES_PER_CYCLE) + .await; + let Some(inner) = weak_inner.upgrade() else { + return; // Nothing to cleanup, our struct is dropped. + }; + let mut inner_mux = inner.lock().await; + inner_mux + .action_db + .on_client_operations_drop(dropped_client_ids.drain(..)); + } + }); + Mutex::new(MemorySchedulerStateManagerImpl { + action_db: AwaitedActionDb::new(eviction_config), + tasks_change_notify, + max_job_retries, + client_operation_drop_tx, + _client_operation_cleanup_spawn: client_operation_cleanup_spawn, + }) + }), + } + } + + async fn inner_filter_operations( + &self, + filter: &OperationFilter, + to_action_state_result: F, + ) -> Result + where + F: Fn(Arc) -> Arc + Send + Sync + 'static, + { + fn get_tree_for_stage( + action_db: &AwaitedActionDb, + stage: OperationStageFlags, + ) -> Option<&BTreeSet> { + match stage { + OperationStageFlags::CacheCheck => Some(action_db.get_cache_check_actions()), + OperationStageFlags::Queued => Some(action_db.get_queued_actions()), + OperationStageFlags::Executing => Some(action_db.get_executing_actions()), + OperationStageFlags::Completed => Some(action_db.get_completed_actions()), + _ => None, + } + } + + let inner = self.inner.lock().await; + + if let Some(operation_id) = &filter.operation_id { + return Ok(inner .action_db - .refresh_client_operation_id(&client_operation_id) - .await; - if !refresh_success { - event! { - Level::ERROR, - ?client_operation_id, - "Client operation id not found in StateManager::add_action keepalive" - }; + .get_by_operation_id(operation_id) + .filter(|awaited_action| filter_check(awaited_action.as_ref(), filter)) + .cloned() + .map(|awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(awaited_action) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + if let Some(client_operation_id) = &filter.client_operation_id { + return Ok(inner + .action_db + .get_by_client_operation_id(client_operation_id) + .await + .filter(|client_awaited_action| { + filter_check(client_awaited_action.awaited_action().as_ref(), filter) + }) + .map(|client_awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(client_awaited_action.awaited_action().clone()) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + + if get_tree_for_stage(&inner.action_db, filter.stages).is_none() { + let mut all_items: Vec> = inner + .action_db + .get_all_awaited_actions() + .filter(|awaited_action| filter_check(awaited_action.as_ref(), filter)) + .cloned() + .collect(); + match filter.order_by_priority_direction { + Some(OrderDirection::Asc) => all_items.sort_unstable_by(|a, b| { + a.get_sort_info() + .get_new_sort_key() + .cmp(&b.get_sort_info().get_new_sort_key()) + }), + Some(OrderDirection::Desc) => all_items.sort_unstable_by(|a, b| { + b.get_sort_info() + .get_new_sort_key() + .cmp(&a.get_sort_info().get_new_sort_key()) + }), + None => {} } + return Ok(Box::pin(stream::iter( + all_items.into_iter().map(to_action_state_result), + ))); } - }) + + drop(inner); + + struct State< + F: Fn(Arc) -> Arc + Send + Sync + 'static, + > { + inner: Arc>, + filter: OperationFilter, + buffer: VecDeque, + start_key: Bound, + to_action_state_result: F, + } + let state = State { + inner: self.inner.clone(), + filter: filter.clone(), + buffer: VecDeque::new(), + start_key: Bound::Unbounded, + to_action_state_result, + }; + + const STREAM_BUFF_SIZE: usize = 64; + + Ok(Box::pin(unfold(state, move |mut state| async move { + if let Some(sorted_awaited_action) = state.buffer.pop_front() { + if state.buffer.is_empty() { + state.start_key = Bound::Excluded(sorted_awaited_action.clone()); + } + return Some(( + (state.to_action_state_result)(sorted_awaited_action.awaited_action), + state, + )); + } + + let inner = state.inner.lock().await; + + #[allow(clippy::mutable_key_type)] + let btree = get_tree_for_stage(&inner.action_db, state.filter.stages) + .expect("get_tree_for_stage() should have already returned Some but in iteration it returned None"); + + let range = (state.start_key.as_ref(), Bound::Unbounded); + if state.filter.order_by_priority_direction == Some(OrderDirection::Asc) { + btree + .range(range) + .filter(|item| filter_check(item.awaited_action.as_ref(), &state.filter)) + .take(STREAM_BUFF_SIZE) + .for_each(|item| state.buffer.push_back(item.clone())); + } else { + btree + .range(range) + .rev() + .filter(|item| filter_check(item.awaited_action.as_ref(), &state.filter)) + .take(STREAM_BUFF_SIZE) + .for_each(|item| state.buffer.push_back(item.clone())); + } + drop(inner); + let sorted_awaited_action = state.buffer.pop_front()?; + if state.buffer.is_empty() { + state.start_key = Bound::Excluded(sorted_awaited_action.clone()); + } + Some(( + (state.to_action_state_result)(sorted_awaited_action.awaited_action), + state, + )) + }))) + } } #[async_trait] -impl ClientStateManager for StateManager { +impl ClientStateManager for MemorySchedulerStateManager { async fn add_action( &self, client_operation_id: ClientOperationId, @@ -978,7 +514,7 @@ impl ClientStateManager for StateManager { } #[async_trait] -impl WorkerStateManager for StateManager { +impl WorkerStateManager for MemorySchedulerStateManager { async fn update_operation( &self, operation_id: &OperationId, @@ -991,7 +527,7 @@ impl WorkerStateManager for StateManager { } #[async_trait] -impl MatchingEngineStateManager for StateManager { +impl MatchingEngineStateManager for MemorySchedulerStateManager { async fn filter_operations( &self, filter: &OperationFilter, diff --git a/nativelink-scheduler/src/scheduler_state/mod.rs b/nativelink-scheduler/src/memory_scheduler_state/mod.rs similarity index 74% rename from nativelink-scheduler/src/scheduler_state/mod.rs rename to nativelink-scheduler/src/memory_scheduler_state/mod.rs index 01fb387bb..a06f17897 100644 --- a/nativelink-scheduler/src/scheduler_state/mod.rs +++ b/nativelink-scheduler/src/memory_scheduler_state/mod.rs @@ -12,8 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -pub(crate) mod awaited_action; -pub(crate) mod client_action_state_result; -pub(crate) mod matching_engine_action_state_result; -pub mod state_manager; -pub mod workers; +mod awaited_action_db; +mod client_action_state_result; +mod matching_engine_action_state_result; +mod memory_scheduler_state_manager; + +pub(crate) use memory_scheduler_state_manager::MemorySchedulerStateManager; diff --git a/nativelink-scheduler/src/redis_action_stage.rs b/nativelink-scheduler/src/redis_action_stage.rs index 3176c7324..24b30bbd0 100644 --- a/nativelink-scheduler/src/redis_action_stage.rs +++ b/nativelink-scheduler/src/redis_action_stage.rs @@ -14,10 +14,9 @@ use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_util::action_messages::{ActionResult, ActionStage}; +use nativelink_util::operation_state_manager::OperationStageFlags; use serde::{Deserialize, Serialize}; -use crate::operation_state_manager::OperationStageFlags; - #[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] pub enum RedisOperationStage { CacheCheck, diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 9a90ed510..622037362 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -23,6 +23,10 @@ use nativelink_util::action_messages::{ ActionInfo, ActionStage, ActionState, ClientOperationId, OperationId, WorkerId, }; use nativelink_util::metrics_utils::Registry; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter, + OperationStageFlags, +}; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; use tokio::sync::{watch, Notify}; @@ -31,13 +35,9 @@ use tokio_stream::StreamExt; use tracing::{event, Level}; use crate::action_scheduler::{ActionListener, ActionScheduler}; -use crate::operation_state_manager::{ - ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter, - OperationStageFlags, -}; +use crate::api_worker_scheduler::ApiWorkerScheduler; +use crate::memory_scheduler_state::MemorySchedulerStateManager; use crate::platform_property_manager::PlatformPropertyManager; -use crate::scheduler_state::state_manager::StateManager; -use crate::scheduler_state::workers::ApiWorkerScheduler; use crate::worker::{Worker, WorkerTimestamp}; use crate::worker_scheduler::WorkerScheduler; @@ -308,7 +308,7 @@ impl SimpleScheduler { } let tasks_or_worker_change_notify = Arc::new(Notify::new()); - let state_manager = Arc::new(StateManager::new( + let state_manager = Arc::new(MemorySchedulerStateManager::new( &EvictionPolicy { max_seconds: CLIENT_EVICTION_SECONDS, ..Default::default() diff --git a/nativelink-service/tests/worker_api_server_test.rs b/nativelink-service/tests/worker_api_server_test.rs index 8ca1ec2cc..e0c28cef5 100644 --- a/nativelink-service/tests/worker_api_server_test.rs +++ b/nativelink-service/tests/worker_api_server_test.rs @@ -31,9 +31,8 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: execute_result, update_for_worker, ExecuteResult, KeepAliveRequest, SupportedProperties, }; use nativelink_proto::google::rpc::Status as ProtoStatus; -use nativelink_scheduler::operation_state_manager::WorkerStateManager; +use nativelink_scheduler::api_worker_scheduler::ApiWorkerScheduler; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; -use nativelink_scheduler::scheduler_state::workers::ApiWorkerScheduler; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_service::worker_api_server::{ConnectWorkerStream, NowFn, WorkerApiServer}; use nativelink_util::action_messages::{ @@ -41,6 +40,7 @@ use nativelink_util::action_messages::{ }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::operation_state_manager::WorkerStateManager; use nativelink_util::platform_properties::PlatformProperties; use pretty_assertions::assert_eq; use tokio::join; diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index 679bc51fc..2adfe7409 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -10,6 +10,7 @@ nativelink-proto = { path = "../nativelink-proto" } async-lock = "3.3.0" async-trait = "0.1.80" +bitflags = "2.5.0" blake3 = { version = "1.5.1", features = ["mmap"] } bytes = "1.6.0" console-subscriber = { version = "0.3.0" } diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 717985274..e735c86d0 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -23,6 +23,7 @@ pub mod fastcdc; pub mod fs; pub mod health_utils; pub mod metrics_utils; +pub mod operation_state_manager; pub mod origin_context; pub mod platform_properties; pub mod proto_stream_utils; diff --git a/nativelink-scheduler/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs similarity index 98% rename from nativelink-scheduler/src/operation_state_manager.rs rename to nativelink-util/src/operation_state_manager.rs index 2ee6b8cde..090b6e8b0 100644 --- a/nativelink-scheduler/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -17,14 +17,14 @@ use std::pin::Pin; use std::sync::Arc; use std::time::SystemTime; +use crate::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, +}; +use crate::common::DigestInfo; use async_trait::async_trait; use bitflags::bitflags; use futures::Stream; use nativelink_error::Error; -use nativelink_util::action_messages::{ - ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, -}; -use nativelink_util::common::DigestInfo; use tokio::sync::watch; bitflags! {