diff --git a/nativelink-scheduler/src/scheduler_state/awaited_action.rs b/nativelink-scheduler/src/scheduler_state/awaited_action.rs index 285f740d7..47f0bcfaf 100644 --- a/nativelink-scheduler/src/scheduler_state/awaited_action.rs +++ b/nativelink-scheduler/src/scheduler_state/awaited_action.rs @@ -68,7 +68,7 @@ pub struct AwaitedAction { action_info: Arc, // The unique identifier of the operation. - // TODO(operation_id should be stored here). + // TODO!(operation_id should be stored here). // operation_id: OperationId, /// The data that is used to sort the action in the queue. /// The first item in the tuple is the current priority, diff --git a/nativelink-scheduler/src/scheduler_state/state_manager.rs b/nativelink-scheduler/src/scheduler_state/state_manager.rs index 0a69c9fcb..0edd7a312 100644 --- a/nativelink-scheduler/src/scheduler_state/state_manager.rs +++ b/nativelink-scheduler/src/scheduler_state/state_manager.rs @@ -29,9 +29,9 @@ use nativelink_util::action_messages::{ ClientOperationId, ExecutionMetadata, OperationId, WorkerId, }; use nativelink_util::evicting_map::{EvictingMap, LenEntry}; +use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; -use nativelink_util::{background_spawn, spawn}; -use tokio::sync::{watch, Notify}; +use tokio::sync::{mpsc, watch, Notify}; use tracing::{event, Level}; use super::awaited_action::AwaitedActionSortKey; @@ -85,55 +85,50 @@ struct SortedAwaitedActions { completed_from_cache: BTreeSet, } +/// Represents a client that is currently listening to an action. +/// When the client is dropped, it will send the [`AwaitedAction`] to the +/// `client_operation_drop_tx` if there are other cleanups needed. #[derive(Debug)] struct ClientAwaitedAction { - /// A weak reference to the owning StateManagerImpl. - state_manager_impl: Weak>, - - /// The client operation id that is listening to the action. + /// The awaited action that the client is listening to. // Note: This is an Option because it is taken when the // ClientAwaitedAction is dropped, but will never actually be // None except during the drop. - client_operation_id: Option, + awaited_action: Option>, - /// The awaited action that the client is listening to. - awaited_action: Arc, + /// The sender to notify of this struct being dropped. + client_operation_drop_tx: mpsc::UnboundedSender>, } impl ClientAwaitedAction { fn new( - state_manager_impl: Weak>, - client_operation_id: Option, awaited_action: Arc, + client_operation_drop_tx: mpsc::UnboundedSender>, ) -> Self { awaited_action.inc_listening_clients(); Self { - state_manager_impl, - client_operation_id, - awaited_action, + awaited_action: Some(awaited_action), + client_operation_drop_tx, } } + + /// Returns the awaited action that the client is listening to. + fn awaited_action(&self) -> &Arc { + self.awaited_action + .as_ref() + .expect("AwaitedAction should be present") + } } impl Drop for ClientAwaitedAction { fn drop(&mut self) { - let Some(inner) = self.state_manager_impl.upgrade() else { - return; // Nothing to do, since the StateManagerImpl is already dropped. - }; - let client_operation_id = self - .client_operation_id + let awaited_action = self + .awaited_action .take() - .expect("Operation Id should be present"); - let awaited_action = self.awaited_action.clone(); - - // We must spawn the cleanup in the background so we can use await. - background_spawn!("client_awaited_action_drop", async move { - let mut inner = inner.lock().await; - awaited_action.dec_listening_clients(); - inner - .on_client_struct_dropped(&client_operation_id, awaited_action) - .await; - }); + .expect("AwaitedAction should be present"); + awaited_action.dec_listening_clients(); + // If we failed to send it means noone is listening. + let _ = self.client_operation_drop_tx.send(awaited_action); } } @@ -194,90 +189,75 @@ impl AwaitedActionDb { .await } - /// Removes the client operation id from the database and cleanup entry from other maps. - /// Returns `true` if the client operation id was found and removed. - async fn remove_client_operation_id( + /// When a client operation is dropped, we need to remove it from the + /// other maps and update the listening clients count on the [`AwaitedAction`]. + fn on_client_operations_drop( &mut self, - client_operation_id: &ClientOperationId, - awaited_action: Arc, - ) -> bool { - let did_remove = self - .client_operation_to_awaited_action - .remove(client_operation_id) - .await; - if !did_remove { - // Note: This might be very noisy, but we will leave it in for now to help - // with debugging. In the event it is too noisy we can downgrade the level. - event!( - Level::ERROR, - ?client_operation_id, - ?awaited_action, - "Client operation id not found in StateManager::remove_client_operation_id" - ); - } - if awaited_action.get_listening_clients() != 0 { - // We still have other clients listening to this action. - return did_remove; - } + awaited_actions: impl IntoIterator>, + ) { + for awaited_action in awaited_actions.into_iter() { + if awaited_action.get_listening_clients() != 0 { + // We still have other clients listening to this action. + continue; + } - let operation_id = awaited_action.get_operation_id(); + let operation_id = awaited_action.get_operation_id(); - // Cleanup operation_id_to_awaited_action. - if self - .operation_id_to_awaited_action - .remove(&operation_id) - .is_none() - { - event!( - Level::ERROR, - ?client_operation_id, - ?operation_id, - ?awaited_action, - "operation_id_to_awaited_action and client_operation_to_awaited_action are out of sync", - ); - } + // Cleanup operation_id_to_awaited_action. + if self + .operation_id_to_awaited_action + .remove(&operation_id) + .is_none() + { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + "operation_id_to_awaited_action and client_operation_to_awaited_action are out of sync", + ); + } - // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. - let action_info = awaited_action.get_action_info(); - match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(action_key) => { - let maybe_awaited_action = self - .action_info_hash_key_to_awaited_action - .remove(action_key); - if maybe_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?awaited_action, - ?action_key, - "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", - ); + // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. + let action_info = awaited_action.get_action_info(); + match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = self + .action_info_hash_key_to_awaited_action + .remove(action_key); + if maybe_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // This Operation should not be in the hash_key map. } } - ActionUniqueQualifier::Uncachable(_action_key) => { - // This Operation should not be in the hash_key map. - } - } - // Cleanup sorted_awaited_action. - let sort_info = awaited_action.get_sort_info(); - let sort_key = sort_info.get_previous_sort_key(); - let sort_map_for_state = - self.get_sort_map_for_state(&awaited_action.get_current_state().stage); - drop(sort_info); - let maybe_sorted_awaited_action = sort_map_for_state.take(&SortedAwaitedAction { - sort_key, - awaited_action, - }); - if maybe_sorted_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?sort_key, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", - ); + // Cleanup sorted_awaited_action. + let sort_info = awaited_action.get_sort_info(); + let sort_key = sort_info.get_previous_sort_key(); + let sort_map_for_state = + self.get_sort_map_for_state(&awaited_action.get_current_state().stage); + drop(sort_info); + let maybe_sorted_awaited_action = sort_map_for_state.take(&SortedAwaitedAction { + sort_key, + awaited_action, + }); + if maybe_sorted_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?sort_key, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", + ); + } } - did_remove } fn get_all_awaited_actions(&self) -> impl Iterator> { @@ -398,21 +378,22 @@ impl AwaitedActionDb { async fn subscribe_or_add_action( &mut self, - state_manager_impl: &Weak>, client_operation_id: ClientOperationId, action_info: Arc, + client_operation_drop_tx: &mpsc::UnboundedSender>, ) -> watch::Receiver> { // Check to see if the action is already known and subscribe if it is. let subscription_result = self .try_subscribe( - state_manager_impl, &client_operation_id, &action_info.unique_qualifier, action_info.priority, + client_operation_drop_tx, ) .await; let action_info = match subscription_result { Ok(subscription) => return subscription, + // TODO!(we should not ignore the error here.) Err(_) => action_info, }; @@ -425,11 +406,10 @@ impl AwaitedActionDb { let awaited_action = Arc::new(awaited_action); self.client_operation_to_awaited_action .insert( - client_operation_id.clone(), + client_operation_id, Arc::new(ClientAwaitedAction::new( - state_manager_impl.clone(), - Some(client_operation_id), awaited_action.clone(), + client_operation_drop_tx.clone(), )), ) .await; @@ -453,10 +433,10 @@ impl AwaitedActionDb { async fn try_subscribe( &mut self, - state_manager_impl: &Weak>, client_operation_id: &ClientOperationId, unique_qualifier: &ActionUniqueQualifier, priority: i32, + client_operation_drop_tx: &mpsc::UnboundedSender>, ) -> Result>, Error> { let unique_key = match unique_qualifier { ActionUniqueQualifier::Cachable(unique_key) => unique_key, @@ -493,7 +473,7 @@ impl AwaitedActionDb { awaited_action: awaited_action.clone(), }); let Some(mut sorted_awaited_action) = maybe_sorted_awaited_action else { - // TODO: Either use event on all of the above error here, but both is overkill. + // TODO!(Either use event on all of the above error here, but both is overkill). let err = make_err!( Code::Internal, "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync"); @@ -509,11 +489,10 @@ impl AwaitedActionDb { self.client_operation_to_awaited_action .insert( client_operation_id.clone(), - Arc::new(ClientAwaitedAction { - state_manager_impl: state_manager_impl.clone(), - client_operation_id: Some(client_operation_id.clone()), + Arc::new(ClientAwaitedAction::new( awaited_action, - }), + client_operation_drop_tx.clone(), + )), ) .await; @@ -533,9 +512,30 @@ impl StateManager { max_job_retries: usize, ) -> Self { Self { - inner: Arc::new_cyclic(move |weak_self| { + inner: Arc::new_cyclic(move |weak_self| -> Mutex { + let weak_inner = weak_self.clone(); + let (client_operation_drop_tx, mut client_operation_drop_rx) = + mpsc::unbounded_channel(); + let client_operation_cleanup_spawn = + spawn!("state_manager_client_drop_rx", async move { + /// Number of events to pull from the stream at a time. + const MAX_DROP_HANDLES_PER_CYCLE: usize = 1024; + let mut dropped_client_ids = Vec::with_capacity(MAX_DROP_HANDLES_PER_CYCLE); + loop { + dropped_client_ids.clear(); + client_operation_drop_rx + .recv_many(&mut dropped_client_ids, MAX_DROP_HANDLES_PER_CYCLE) + .await; + let Some(inner) = weak_inner.upgrade() else { + return; // Nothing to cleanup, our struct is dropped. + }; + let mut inner_mux = inner.lock().await; + inner_mux + .action_db + .on_client_operations_drop(dropped_client_ids.drain(..)); + } + }); Mutex::new(StateManagerImpl { - weak_self: weak_self.clone(), action_db: AwaitedActionDb { client_operation_to_awaited_action: EvictingMap::new( config, @@ -547,6 +547,8 @@ impl StateManager { }, tasks_change_notify, max_job_retries, + client_operation_drop_tx, + _client_operation_cleanup_spawn: client_operation_cleanup_spawn, }) }), } @@ -594,11 +596,11 @@ impl StateManager { .get_by_client_operation_id(client_operation_id) .await .filter(|client_awaited_action| { - filter_check(client_awaited_action.awaited_action.as_ref(), filter) + filter_check(client_awaited_action.awaited_action().as_ref(), filter) }) .map(|client_awaited_action| -> ActionStateResultStream { Box::pin(stream::once(async move { - to_action_state_result(client_awaited_action.awaited_action.clone()) + to_action_state_result(client_awaited_action.awaited_action().clone()) })) }) .unwrap_or_else(|| Box::pin(stream::empty()))); @@ -699,14 +701,22 @@ impl StateManager { /// includes the actions that are queued, active, and recently completed. It also includes the /// workers that are available to execute actions based on allocation strategy. pub(crate) struct StateManagerImpl { - weak_self: Weak>, - + /// Database for storing the state of all actions. action_db: AwaitedActionDb, /// Notify task<->worker matching engine that work needs to be done. tasks_change_notify: Arc, + /// Maximum number of times a job can be retried. max_job_retries: usize, + + /// Channel to notify when a client operation id is dropped. + client_operation_drop_tx: mpsc::UnboundedSender>, + + /// Task to cleanup client operation ids that are no longer being listened to. + // Note: This has a custom Drop function on it. It should stay alive only while + // the StateManager is alive. + _client_operation_cleanup_spawn: JoinHandleDropGuard<()>, } fn filter_check(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { @@ -888,33 +898,15 @@ impl StateManagerImpl { ) -> Result>, Error> { let rx = self .action_db - .subscribe_or_add_action(&self.weak_self, new_client_operation_id, action_info) + .subscribe_or_add_action( + new_client_operation_id, + action_info, + &self.client_operation_drop_tx, + ) .await; self.tasks_change_notify.notify_one(); Ok(rx) } - - /// Called when the client struct is dropped. This will remove the client operation id - /// from the database and cleanup the entry from other maps. - /// This is not called a client disconnects, but rather when the EvictionMap drops the - /// struct. - async fn on_client_struct_dropped( - &mut self, - client_operation_id: &ClientOperationId, - awaited_action: Arc, - ) { - let did_remove_operation_id = self - .action_db - .remove_client_operation_id(client_operation_id, awaited_action) - .await; - if !did_remove_operation_id { - event!( - Level::ERROR, - ?client_operation_id, - "Client operation id not found in StateManager::on_client_struct_dropped" - ); - } - } } /// Utility struct to create a background task that keeps the client operation id alive.