From 3b8c3a583b7df12bddba188fe2df221523c6b0f5 Mon Sep 17 00:00:00 2001 From: "Nathan (Blaise) Bruer" Date: Wed, 17 Jul 2024 11:13:20 -0500 Subject: [PATCH] [Refactor] Overhaul of scheduler component (#1169) This is a significant overhaul to Nativelink's scheduler component. This new scheduler design is to enable a distributed scheduling system. The new components & definitions: * AwaitedActionDb - An interface that is easier to work with when dealing with key-value storage systems. * MemoryAwaitedActionDb - An in-memory set of hashmaps & btrees used to satisfy the requirements of AwaitedActionDb interface. * ClientStateManager - A minimal interface required to satisfy the requirements of a client-facing scheduler. * WorkerStateManager - A minimal interface required to satisfy the requirements of a worker-facing scheduler. * MatchingEngineStateManager - A minimal interface required to satisfy a engine that matches queued jobs to workers. * SimpleSchedulerStateManager - An implements that satisfies ClientStateManager, WorkerStateManager & MatchingEngineStateManager with all the logic of the previous "SimpleScheduler" logic moved behind each interface. * ApiWorkerScheduler - A component that handles all knowledge about workers state and implmenets the WorkerScheduler interface and translates them into the WorkerStateManager interface. * SimpleScheduler - Translation calls of the ClientScheduler interface into ClientStateManager & MatchingEngineStateManager. This component is currently always forwards calls to SimpleSchedulerStateManager then to MemoryAwaitedActionDb. Future changes will make these inner components dynamic via config. In addition we have hardened the interactions of different kind of IDs in NativeLink. Most relevant is the separation & introduction of: * OperationId - Represents an individual operation being requested to be executed that is unique across all of time. * ClientOperationId - An ID issued to the client when the client requests to execute a job. This ID will point to an OperationId internally, but the client is never exposed to the OperationId. * AwaitedActionHashKey - A key used to uniquely identify an action that is not unique across time. This means that this key might have multiple OperationId's that have executed it across different points in time. This key is used as a "fingerprint" of an operation that the client wants to execute and the scheduler may decide to join the stream onto an existing operation if this key has a hit. Overall these changes pave the way for more robust scheduler implementations, most notably, distributed scheduler implementations will be easier to implement and will be introduced in followup PRs. This commit was developed on a side branch and consisted of the following commits with corresponding code reviews: 54ed73cc6fcaf9236a544e45a90b9071b5df6ad8 Add scheduler metrics back (#1171) 50fdbd73983dfdbd2620ac6bfcd5c7b642b5705e fix formatting (#1170) 89262366544d2310e0e9a5bd91c99297112d4e27 Merge in main and format (#1168) 9c2c7b9d3a2bfd71615fe46893ac937e6b23974c key as u64 (#1166) 0192051fca6e3509a6b2953ab53f77c12b4c7092 Cleanup unused code and comments (#1165) 080df5d0c559055d12ac53275024d32e35077675 Add versioning to AwaitedAction (#1163) 73c19c41e31fef55a3c6e81beb19a7233f0c2e8d Fix sequence bug in new memory store manager (#1162) 6e50d2c16fad92f2f8fb856b199f7798c30d5967 New AwaitedActionDb implementation (#1157) 18db991e3a414c3f95c2dd5a14bfafce80358de9 Fix test on running_actions_manager_test (#1141) e50ef3c27d16fc840fd63b917f39f1c317f48398 Rename workers to `worker_scheduler` 1fdd5056f357124d6aff63fe145821b795cf9780 SimpleScheduler now uses config for action pruning (#1137) eaaa872bc62a952fb926214a6fb11bd4ffd25bae Change encoding for items that are cachable (#1136) d64705627c045d108340546a3ea135edb984163a Errors are now properly handles in subscription (#1135) 7c3e730e30e96fb5d08e4d43cca8d27593bcea03 Restructure files to be more appropriate (#1131) 5e98ec9f67f56744376e060b6aa423f1ff401a47 ClientAwaitedAction now uses a channel to notify drops happened (#1130) 52beaf9c727cd396ce4b5b4095bc8dc7be8ce940 Cleanup unused structs (#1128) e86fe0826b3512e78804744a8f1db2ec4c5b21ea Remove all uses of salt and put under ActionUniqueQualifier (#1126) 3b860367ba131c918127b6426e791b9ca0ca8b6b Remove all need for workers to know about ActionId (#1125) 5482d7f5e5f84c2ea80d79c703624c73b6ee36cc Fix bazel build and test on dev (#1123) ba52c7fb76bba56ff5485b91cd516712f867dc64 Implement get_action_info to all ActionStateResult impls (#1118) 2fa4fee48374ed0cf12dcd88225d6d208c0eeeb2 Remove MatchingEngineStateManager::remove_operation (#1119) 34dea0633f8b5198d9c51904beee8f4c617fc822 Remove unused proto field (#1117) 3070a40935c1a490b530af8551329c44bb4500ff Remove metrics from new scheduler (#1116) e95adfcd6035d3cecac74e67829b83e44a598228 StateManager will now cleanup actions on client disconnect (#1107) 6f8c001d1a96a0dc74a5c22716ba2f54e4c9f1b1 Fix worker execution issues (#1114) d353c3003437609547c5dcc8738637b13151a9b5 rename set_priority to upgrade_priority (#1112) 0d93671ba51be41a1a6b0297260cfc5ca0766ebc StateManager can now be notified of noone listeneing (#1093) cfc0cf62c673dc399938247b0c3abe55863c2687 ActionScheduler will now use ActionListener instead of tokio::watch (#1091) d70d31d74cdc337861ebddd8d19201b088b8faa6 QA fixes for scheduler-v2 (#1092) f2cea0cc3fcc259b9a4b7257f2eae807d7e99d05 [Refactor] Complete rewrite of SimpleScheduler 34d93b7d1cda7c0f4a430364b7588002ab612494 [Refactor] Move worker notification in SimpleScheduler under Workers b9d970264b59201e13053844a277092164c8416a [Refactor] Moves worker logic back to SimpleScheduler 7a16e2e6043b17e7813e41450b4de9c40de435f4 [Refactor] Move scheduler state behind mute --- Cargo.lock | 40 +- nativelink-config/src/schedulers.rs | 2 +- .../remote_execution/worker_api.proto | 41 +- ..._machina.nativelink.remote_execution.pb.rs | 46 +- nativelink-scheduler/BUILD.bazel | 19 +- nativelink-scheduler/Cargo.toml | 2 +- nativelink-scheduler/src/action_scheduler.rs | 29 +- .../src/api_worker_scheduler.rs | 477 ++++++++ .../src/awaited_action_db/awaited_action.rs | 191 ++++ .../src/awaited_action_db/mod.rs | 121 ++ .../src/cache_lookup_scheduler.rs | 299 +++-- .../src/default_action_listener.rs | 77 ++ .../src/default_scheduler_factory.rs | 24 +- nativelink-scheduler/src/grpc_scheduler.rs | 61 +- nativelink-scheduler/src/lib.rs | 9 +- .../src/memory_awaited_action_db.rs | 987 ++++++++++++++++ .../src/property_modifier_scheduler.rs | 29 +- .../src/redis_action_stage.rs | 78 -- .../src/redis_operation_state.rs | 465 -------- .../src/scheduler_state/awaited_action.rs | 67 -- .../client_action_state_result.rs | 51 - .../src/scheduler_state/completed_action.rs | 72 -- .../matching_engine_action_state_result.rs | 53 - .../src/scheduler_state/metrics.rs | 143 --- .../src/scheduler_state/mod.rs | 21 - .../src/scheduler_state/state_manager.rs | 742 ------------ .../src/scheduler_state/workers.rs | 114 -- nativelink-scheduler/src/simple_scheduler.rs | 1015 +++++------------ .../src/simple_scheduler_state_manager.rs | 480 ++++++++ nativelink-scheduler/src/worker.rs | 45 +- nativelink-scheduler/src/worker_scheduler.rs | 8 +- .../tests/action_messages_test.rs | 144 +-- .../tests/cache_lookup_scheduler_test.rs | 51 +- .../tests/property_modifier_scheduler_test.rs | 100 +- .../tests/simple_scheduler_test.rs | 719 ++++++------ .../tests/utils/mock_scheduler.rs | 47 +- .../tests/utils/scheduler_utils.rs | 15 +- nativelink-service/BUILD.bazel | 2 + nativelink-service/Cargo.toml | 2 + nativelink-service/src/execution_server.rs | 146 ++- nativelink-service/src/worker_api_server.rs | 36 +- .../tests/worker_api_server_test.rs | 323 ++++-- nativelink-store/tests/cas_utils_test.rs | 8 +- nativelink-util/BUILD.bazel | 4 + nativelink-util/Cargo.toml | 2 + nativelink-util/src/action_messages.rs | 461 ++++---- nativelink-util/src/chunked_stream.rs | 110 ++ nativelink-util/src/lib.rs | 2 + .../src/operation_state_manager.rs | 115 +- nativelink-util/tests/operation_id_tests.rs | 68 +- nativelink-worker/src/local_worker.rs | 43 +- .../src/running_actions_manager.rs | 101 +- nativelink-worker/tests/local_worker_test.rs | 63 +- .../tests/running_actions_manager_test.rs | 297 ++--- .../utils/mock_running_actions_manager.rs | 30 +- src/bin/nativelink.rs | 2 +- 56 files changed, 4535 insertions(+), 4164 deletions(-) create mode 100644 nativelink-scheduler/src/api_worker_scheduler.rs create mode 100644 nativelink-scheduler/src/awaited_action_db/awaited_action.rs create mode 100644 nativelink-scheduler/src/awaited_action_db/mod.rs create mode 100644 nativelink-scheduler/src/default_action_listener.rs create mode 100644 nativelink-scheduler/src/memory_awaited_action_db.rs delete mode 100644 nativelink-scheduler/src/redis_action_stage.rs delete mode 100644 nativelink-scheduler/src/redis_operation_state.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/awaited_action.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/client_action_state_result.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/completed_action.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/metrics.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/mod.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/state_manager.rs delete mode 100644 nativelink-scheduler/src/scheduler_state/workers.rs create mode 100644 nativelink-scheduler/src/simple_scheduler_state_manager.rs create mode 100644 nativelink-util/src/chunked_stream.rs rename {nativelink-scheduler => nativelink-util}/src/operation_state_manager.rs (56%) diff --git a/Cargo.lock b/Cargo.lock index 864aed5e5..2405dcc9a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -723,9 +723,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.2" +version = "1.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d08263faac5cde2a4d52b513dadb80846023aade56fcd8fc99ba73ba8050e92" +checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210" dependencies = [ "arrayref", "arrayvec", @@ -825,9 +825,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.2" +version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56" +checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052" [[package]] name = "cfg-if" @@ -1951,7 +1951,6 @@ version = "0.4.0" dependencies = [ "async-lock", "async-trait", - "bitflags 2.6.0", "blake3", "futures", "hashbrown 0.14.5", @@ -1971,6 +1970,7 @@ dependencies = [ "scopeguard", "serde", "serde_json", + "static_assertions", "tokio", "tokio-stream", "tonic", @@ -1982,6 +1982,8 @@ dependencies = [ name = "nativelink-service" version = "0.4.0" dependencies = [ + "async-lock", + "async-trait", "bytes", "futures", "hyper 0.14.30", @@ -2064,6 +2066,7 @@ version = "0.4.0" dependencies = [ "async-lock", "async-trait", + "bitflags 2.6.0", "blake3", "bytes", "console-subscriber", @@ -2078,6 +2081,7 @@ dependencies = [ "nativelink-macro", "nativelink-proto", "parking_lot", + "pin-project", "pin-project-lite", "pretty_assertions", "prometheus-client", @@ -2259,7 +2263,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.2", + "redox_syscall 0.5.3", "smallvec", "windows-targets 0.52.6", ] @@ -2660,9 +2664,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ "bitflags 2.6.0", ] @@ -2962,9 +2966,9 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "scc" -version = "2.1.2" +version = "2.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af947d0ca10a2f3e00c7ec1b515b7c83e5cb3fa62d4c11a64301d9eec54440e9" +checksum = "a4465c22496331e20eb047ff46e7366455bc01c0c02015c4a376de0b2cd3a1af" dependencies = [ "sdd", ] @@ -2996,9 +3000,9 @@ dependencies = [ [[package]] name = "sdd" -version = "0.2.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84345e4c9bd703274a082fb80caaa99b7612be48dfaa1dd9266577ec412309d" +checksum = "1e806d6633ef141556fef75e345275e35652e9c045bbbc21e6ecfce3e9aa2638" [[package]] name = "seahash" @@ -3022,9 +3026,9 @@ dependencies = [ [[package]] name = "security-framework" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.6.0", "core-foundation", @@ -3035,9 +3039,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.11.0" +version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ "core-foundation-sys", "libc", @@ -3410,9 +3414,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.0" +version = "1.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" dependencies = [ "backtrace", "bytes", diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 1ff707b32..27b7d1728 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -98,7 +98,7 @@ pub struct SimpleScheduler { /// a WaitExecution is called after the action has completed. /// Default: 60 (seconds) #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] - pub retain_completed_for_s: u64, + pub retain_completed_for_s: u32, /// Remove workers from pool once the worker has not responded in this /// amount of time in seconds. diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto index 51b81b978..f3adbd3e9 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto @@ -99,25 +99,8 @@ message ExecuteResult { /// that initially sent the job as part of the BRE protocol. string instance_name = 6; - /// The original execution digest request for this response. The scheduler knows what it - /// should be, but we do safety checks to ensure it really is the request we expected. - build.bazel.remote.execution.v2.Digest action_digest = 2; - - /// The salt originally sent along with the StartExecute request. This salt is used - /// as a seed for cases where the execution digest should never be cached or merged - /// with other jobs. This salt is added to the hash function used to compute jobs that - /// are running or cached. - uint64 salt = 3; - - // The digest function that was used to compute the action digest - // and all related blobs. - // - // If the digest function used is one of MD5, MURMUR3, SHA1, SHA256, - // SHA384, SHA512, or VSO, the client MAY leave this field unset. In - // that case the server SHOULD infer the digest function using the - // length of the action digest hash and the digest functions announced - // in the server's capabilities. - build.bazel.remote.execution.v2.DigestFunction.Value digest_function = 7; + /// The operation ID that was executed. + string operation_id = 8; /// The actual response data. oneof result { @@ -131,7 +114,7 @@ message ExecuteResult { google.rpc.Status internal_error = 5; } - reserved 8; // NextId. + reserved 9; // NextId. } /// Result sent back from the server when a node connects. @@ -141,10 +124,10 @@ message ConnectionResult { reserved 2; // NextId. } -/// Request to kill a running action sent from the scheduler to a worker. -message KillActionRequest { - /// The the hex encoded unique qualifier for the action to be killed. - string action_id = 1; +/// Request to kill a running operation sent from the scheduler to a worker. +message KillOperationRequest { + /// The the operation id for the operation to be killed. + string operation_id = 1; reserved 2; // NextId. } /// Communication from the scheduler to the worker. @@ -169,8 +152,8 @@ message UpdateForWorker { /// The worker may discard any outstanding work that is being executed. google.protobuf.Empty disconnect = 4; - /// Instructs the worker to kill a specific running action. - KillActionRequest kill_action_request = 5; + /// Instructs the worker to kill a specific running operation. + KillOperationRequest kill_operation_request = 5; } reserved 6; // NextId. } @@ -179,14 +162,14 @@ message StartExecute { /// The action information used to execute job. build.bazel.remote.execution.v2.ExecuteRequest execute_request = 1; - /// See documentation in ExecuteResult::salt. - uint64 salt = 2; + /// Id of the operation. + string operation_id = 4; /// The time at which the command was added to the queue to allow population /// of the ActionResult. google.protobuf.Timestamp queued_timestamp = 3; - reserved 4; // NextId. + reserved 5; // NextId. } /// This is a special message used to save actions into the CAS that can be used diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs index d4e9eae70..268b5e3ce 100644 --- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs +++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs @@ -60,31 +60,9 @@ pub struct ExecuteResult { /// / that initially sent the job as part of the BRE protocol. #[prost(string, tag = "6")] pub instance_name: ::prost::alloc::string::String, - /// / The original execution digest request for this response. The scheduler knows what it - /// / should be, but we do safety checks to ensure it really is the request we expected. - #[prost(message, optional, tag = "2")] - pub action_digest: ::core::option::Option< - super::super::super::super::super::build::bazel::remote::execution::v2::Digest, - >, - /// / The salt originally sent along with the StartExecute request. This salt is used - /// / as a seed for cases where the execution digest should never be cached or merged - /// / with other jobs. This salt is added to the hash function used to compute jobs that - /// / are running or cached. - #[prost(uint64, tag = "3")] - pub salt: u64, - /// The digest function that was used to compute the action digest - /// and all related blobs. - /// - /// If the digest function used is one of MD5, MURMUR3, SHA1, SHA256, - /// SHA384, SHA512, or VSO, the client MAY leave this field unset. In - /// that case the server SHOULD infer the digest function using the - /// length of the action digest hash and the digest functions announced - /// in the server's capabilities. - #[prost( - enumeration = "super::super::super::super::super::build::bazel::remote::execution::v2::digest_function::Value", - tag = "7" - )] - pub digest_function: i32, + /// / The operation ID that was executed. + #[prost(string, tag = "8")] + pub operation_id: ::prost::alloc::string::String, /// / The actual response data. #[prost(oneof = "execute_result::Result", tags = "4, 5")] pub result: ::core::option::Option, @@ -116,13 +94,13 @@ pub struct ConnectionResult { #[prost(string, tag = "1")] pub worker_id: ::prost::alloc::string::String, } -/// / Request to kill a running action sent from the scheduler to a worker. +/// / Request to kill a running operation sent from the scheduler to a worker. #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] -pub struct KillActionRequest { - /// / The the hex encoded unique qualifier for the action to be killed. +pub struct KillOperationRequest { + /// / The the operation id for the operation to be killed. #[prost(string, tag = "1")] - pub action_id: ::prost::alloc::string::String, + pub operation_id: ::prost::alloc::string::String, } /// / Communication from the scheduler to the worker. #[allow(clippy::derive_partial_eq_without_eq)] @@ -155,9 +133,9 @@ pub mod update_for_worker { /// / The worker may discard any outstanding work that is being executed. #[prost(message, tag = "4")] Disconnect(()), - /// / Instructs the worker to kill a specific running action. + /// / Instructs the worker to kill a specific running operation. #[prost(message, tag = "5")] - KillActionRequest(super::KillActionRequest), + KillOperationRequest(super::KillOperationRequest), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -168,9 +146,9 @@ pub struct StartExecute { pub execute_request: ::core::option::Option< super::super::super::super::super::build::bazel::remote::execution::v2::ExecuteRequest, >, - /// / See documentation in ExecuteResult::salt. - #[prost(uint64, tag = "2")] - pub salt: u64, + /// / Id of the operation. + #[prost(string, tag = "4")] + pub operation_id: ::prost::alloc::string::String, /// / The time at which the command was added to the queue to allow population /// / of the ActionResult. #[prost(message, optional, tag = "3")] diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index ba4fc8f89..e7251c78e 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -10,24 +10,19 @@ rust_library( name = "nativelink-scheduler", srcs = [ "src/action_scheduler.rs", + "src/api_worker_scheduler.rs", + "src/awaited_action_db/awaited_action.rs", + "src/awaited_action_db/mod.rs", "src/cache_lookup_scheduler.rs", + "src/default_action_listener.rs", "src/default_scheduler_factory.rs", "src/grpc_scheduler.rs", "src/lib.rs", - "src/operation_state_manager.rs", + "src/memory_awaited_action_db.rs", "src/platform_property_manager.rs", "src/property_modifier_scheduler.rs", - "src/redis_action_stage.rs", - "src/redis_operation_state.rs", - "src/scheduler_state/awaited_action.rs", - "src/scheduler_state/client_action_state_result.rs", - "src/scheduler_state/completed_action.rs", - "src/scheduler_state/matching_engine_action_state_result.rs", - "src/scheduler_state/metrics.rs", - "src/scheduler_state/mod.rs", - "src/scheduler_state/state_manager.rs", - "src/scheduler_state/workers.rs", "src/simple_scheduler.rs", + "src/simple_scheduler_state_manager.rs", "src/worker.rs", "src/worker_scheduler.rs", ], @@ -42,7 +37,6 @@ rust_library( "//nativelink-store", "//nativelink-util", "@crates//:async-lock", - "@crates//:bitflags", "@crates//:blake3", "@crates//:futures", "@crates//:hashbrown", @@ -55,6 +49,7 @@ rust_library( "@crates//:scopeguard", "@crates//:serde", "@crates//:serde_json", + "@crates//:static_assertions", "@crates//:tokio", "@crates//:tokio-stream", "@crates//:tonic", diff --git a/nativelink-scheduler/Cargo.toml b/nativelink-scheduler/Cargo.toml index 109eb8eb5..9c766f300 100644 --- a/nativelink-scheduler/Cargo.toml +++ b/nativelink-scheduler/Cargo.toml @@ -28,11 +28,11 @@ tokio = { version = "1.37.0", features = ["sync", "rt", "parking_lot"] } tokio-stream = { version = "0.1.15", features = ["sync"] } tonic = { version = "0.11.0", features = ["gzip", "tls"] } tracing = "0.1.40" -bitflags = "2.5.0" redis = { version = "0.25.2", features = ["aio", "tokio", "json"] } serde = "1.0.203" redis-macros = "0.3.0" serde_json = "1.0.117" +static_assertions = "1.1.0" [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index 16fd7b873..5a2b9be81 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -12,16 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; +use futures::Future; use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; +use nativelink_util::action_messages::{ActionInfo, ActionState, ClientOperationId}; use nativelink_util::metrics_utils::Registry; -use tokio::sync::watch; use crate::platform_property_manager::PlatformPropertyManager; +/// ActionListener interface is responsible for interfacing with clients +/// that are interested in the state of an action. +pub trait ActionListener: Sync + Send + Unpin { + /// Returns the client operation id. + fn client_operation_id(&self) -> &ClientOperationId; + + /// Waits for the action state to change. + fn changed( + &mut self, + ) -> Pin, Error>> + Send + '_>>; +} + /// ActionScheduler interface is responsible for interactions between the scheduler /// and action related operations. #[async_trait] @@ -35,17 +48,15 @@ pub trait ActionScheduler: Sync + Send + Unpin { /// Adds an action to the scheduler for remote execution. async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error>; + ) -> Result>, Error>; /// Find an existing action by its name. - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>>; - - /// Cleans up the cache of recently completed actions. - async fn clean_recently_completed_actions(&self); + client_operation_id: &ClientOperationId, + ) -> Result>>, Error>; /// Register the metrics for the action scheduler. fn register_metrics(self: Arc, _registry: &mut Registry) {} diff --git a/nativelink-scheduler/src/api_worker_scheduler.rs b/nativelink-scheduler/src/api_worker_scheduler.rs new file mode 100644 index 000000000..9eda38012 --- /dev/null +++ b/nativelink-scheduler/src/api_worker_scheduler.rs @@ -0,0 +1,477 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_lock::Mutex; +use lru::LruCache; +use nativelink_config::schedulers::WorkerAllocationStrategy; +use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId, WorkerId}; +use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent, Registry}; +use nativelink_util::operation_state_manager::WorkerStateManager; +use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; +use tokio::sync::Notify; +use tonic::async_trait; +use tracing::{event, Level}; + +use crate::platform_property_manager::PlatformPropertyManager; +use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate}; +use crate::worker_scheduler::WorkerScheduler; + +/// A collection of workers that are available to run tasks. +struct ApiWorkerSchedulerImpl { + /// A `LruCache` of workers availabled based on `allocation_strategy`. + workers: LruCache, + + /// The worker state manager. + worker_state_manager: Arc, + /// The allocation strategy for workers. + allocation_strategy: WorkerAllocationStrategy, + /// A channel to notify the matching engine that the worker pool has changed. + worker_change_notify: Arc, +} + +impl ApiWorkerSchedulerImpl { + /// Refreshes the lifetime of the worker with the given timestamp. + fn refresh_lifetime( + &mut self, + worker_id: &WorkerId, + timestamp: WorkerTimestamp, + ) -> Result<(), Error> { + let worker = self.workers.peek_mut(worker_id).ok_or_else(|| { + make_input_err!( + "Worker not found in worker map in refresh_lifetime() {}", + worker_id + ) + })?; + error_if!( + worker.last_update_timestamp > timestamp, + "Worker already had a timestamp of {}, but tried to update it with {}", + worker.last_update_timestamp, + timestamp + ); + worker.last_update_timestamp = timestamp; + Ok(()) + } + + /// Adds a worker to the pool. + /// Note: This function will not do any task matching. + fn add_worker(&mut self, worker: Worker) -> Result<(), Error> { + let worker_id = worker.id; + self.workers.put(worker_id, worker); + + // Worker is not cloneable, and we do not want to send the initial connection results until + // we have added it to the map, or we might get some strange race conditions due to the way + // the multi-threaded runtime works. + let worker = self.workers.peek_mut(&worker_id).unwrap(); + let res = worker + .send_initial_connection_result() + .err_tip(|| "Failed to send initial connection result to worker"); + if let Err(err) = &res { + event!( + Level::ERROR, + ?worker_id, + ?err, + "Worker connection appears to have been closed while adding to pool" + ); + } + self.worker_change_notify.notify_one(); + res + } + + /// Removes worker from pool. + /// Note: The caller is responsible for any rescheduling of any tasks that might be + /// running. + fn remove_worker(&mut self, worker_id: &WorkerId) -> Option { + let result = self.workers.pop(worker_id); + self.worker_change_notify.notify_one(); + result + } + + /// Sets if the worker is draining or not. + async fn set_drain_worker( + &mut self, + worker_id: &WorkerId, + is_draining: bool, + ) -> Result<(), Error> { + let worker = self + .workers + .get_mut(worker_id) + .err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?; + worker.is_draining = is_draining; + self.worker_change_notify.notify_one(); + Ok(()) + } + + fn inner_find_worker_for_action( + &self, + platform_properties: &PlatformProperties, + ) -> Option { + let mut workers_iter = self.workers.iter(); + let workers_iter = match self.allocation_strategy { + // Use rfind to get the least recently used that satisfies the properties. + WorkerAllocationStrategy::least_recently_used => workers_iter.rfind(|(_, w)| { + w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties) + }), + // Use find to get the most recently used that satisfies the properties. + WorkerAllocationStrategy::most_recently_used => workers_iter.find(|(_, w)| { + w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties) + }), + }; + workers_iter.map(|(_, w)| &w.id).copied() + } + + async fn update_action( + &mut self, + worker_id: &WorkerId, + operation_id: &OperationId, + action_stage: Result, + ) -> Result<(), Error> { + let worker = self.workers.get_mut(worker_id).err_tip(|| { + format!("Worker {worker_id} does not exist in SimpleScheduler::update_action") + })?; + + // Ensure the worker is supposed to be running the operation. + if !worker.running_action_infos.contains_key(operation_id) { + let err = make_err!( + Code::Internal, + "Operation {operation_id} should not be running on worker {worker_id} in SimpleScheduler::update_action" + ); + return Result::<(), _>::Err(err.clone()) + .merge(self.immediate_evict_worker(worker_id, err).await); + } + + // Update the operation in the worker state manager. + { + let update_operation_res = self + .worker_state_manager + .update_operation(operation_id, worker_id, action_stage.clone()) + .await + .err_tip(|| "in update_operation on SimpleScheduler::update_action"); + if let Err(err) = update_operation_res { + event!( + Level::ERROR, + ?operation_id, + ?worker_id, + ?err, + "Failed to update_operation on update_action" + ); + return Err(err); + } + } + + // We are done if the action is not finished or there was an error. + let is_finished = action_stage + .as_ref() + .map_or_else(|_| true, |action_stage| action_stage.is_finished()); + if !is_finished { + return Ok(()); + } + + // Clear this action from the current worker if finished. + let complete_action_res = { + let was_paused = !worker.can_accept_work(); + + // Note: We need to run this before dealing with backpressure logic. + let complete_action_res = worker.complete_action(operation_id); + + let due_to_backpressure = action_stage + .as_ref() + .map_or_else(|e| e.code == Code::ResourceExhausted, |_| false); + // Only pause if there's an action still waiting that will unpause. + if (was_paused || due_to_backpressure) && worker.has_actions() { + worker.is_paused = true; + } + complete_action_res + }; + + self.worker_change_notify.notify_one(); + + complete_action_res + } + + /// Notifies the specified worker to run the given action and handles errors by evicting + /// the worker if the notification fails. + async fn worker_notify_run_action( + &mut self, + worker_id: WorkerId, + operation_id: OperationId, + action_info: Arc, + ) -> Result<(), Error> { + if let Some(worker) = self.workers.get_mut(&worker_id) { + let notify_worker_result = + worker.notify_update(WorkerUpdate::RunAction((operation_id, action_info.clone()))); + + if notify_worker_result.is_err() { + event!( + Level::WARN, + ?worker_id, + ?action_info, + ?notify_worker_result, + "Worker command failed, removing worker", + ); + + let err = make_err!( + Code::Internal, + "Worker command failed, removing worker {worker_id} -- {notify_worker_result:?}", + ); + + return Result::<(), _>::Err(err.clone()) + .merge(self.immediate_evict_worker(&worker_id, err).await); + } + } else { + event!( + Level::WARN, + ?worker_id, + ?operation_id, + ?action_info, + "Worker not found in worker map in worker_notify_run_action" + ); + } + Ok(()) + } + + /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it. + async fn immediate_evict_worker( + &mut self, + worker_id: &WorkerId, + err: Error, + ) -> Result<(), Error> { + let mut result = Ok(()); + if let Some(mut worker) = self.remove_worker(worker_id) { + // We don't care if we fail to send message to worker, this is only a best attempt. + let _ = worker.notify_update(WorkerUpdate::Disconnect); + for (operation_id, _) in worker.running_action_infos.drain() { + result = result.merge( + self.worker_state_manager + .update_operation(&operation_id, worker_id, Err(err.clone())) + .await, + ); + } + } + // Note: Calling this many time is very cheap, it'll only trigger `do_try_match` once. + // TODO(allada) This should be moved to inside the Workers struct. + self.worker_change_notify.notify_one(); + result + } +} + +pub struct ApiWorkerScheduler { + inner: Mutex, + platform_property_manager: Arc, + + /// Timeout of how long to evict workers if no response in this given amount of time in seconds. + worker_timeout_s: u64, +} + +impl ApiWorkerScheduler { + pub fn new( + worker_state_manager: Arc, + platform_property_manager: Arc, + allocation_strategy: WorkerAllocationStrategy, + worker_change_notify: Arc, + worker_timeout_s: u64, + ) -> Arc { + Arc::new(Self { + inner: Mutex::new(ApiWorkerSchedulerImpl { + workers: LruCache::unbounded(), + worker_state_manager, + allocation_strategy, + worker_change_notify, + }), + platform_property_manager, + worker_timeout_s, + }) + } + + pub async fn worker_notify_run_action( + &self, + worker_id: WorkerId, + operation_id: OperationId, + action_info: Arc, + ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner + .worker_notify_run_action(worker_id, operation_id, action_info) + .await + } + + /// Attempts to find a worker that is capable of running this action. + // TODO(blaise.bruer) This algorithm is not very efficient. Simple testing using a tree-like + // structure showed worse performance on a 10_000 worker * 7 properties * 1000 queued tasks + // simulation of worst cases in a single threaded environment. + pub async fn find_worker_for_action( + &self, + platform_properties: &PlatformProperties, + ) -> Option { + let inner = self.inner.lock().await; + inner.inner_find_worker_for_action(platform_properties) + } + + /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests. + #[must_use] + pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool { + let inner = self.inner.lock().await; + inner.workers.contains(worker_id) + } + + /// A unit test function used to send the keep alive message to the worker from the server. + pub async fn send_keep_alive_to_worker_for_test( + &self, + worker_id: &WorkerId, + ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + let worker = inner.workers.get_mut(worker_id).ok_or_else(|| { + make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) + })?; + worker.keep_alive() + } +} + +#[async_trait] +impl WorkerScheduler for ApiWorkerScheduler { + fn get_platform_property_manager(&self) -> &PlatformPropertyManager { + self.platform_property_manager.as_ref() + } + + async fn add_worker(&self, worker: Worker) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + let worker_id = worker.id; + let result = inner + .add_worker(worker) + .err_tip(|| "Error while adding worker, removing from pool"); + if let Err(err) = result { + return Result::<(), _>::Err(err.clone()) + .merge(inner.immediate_evict_worker(&worker_id, err).await); + } + Ok(()) + } + + async fn update_action( + &self, + worker_id: &WorkerId, + operation_id: &OperationId, + action_stage: Result, + ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner + .update_action(worker_id, operation_id, action_stage) + .await + } + + async fn worker_keep_alive_received( + &self, + worker_id: &WorkerId, + timestamp: WorkerTimestamp, + ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner + .refresh_lifetime(worker_id, timestamp) + .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()") + } + + async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner + .immediate_evict_worker( + worker_id, + make_err!(Code::Internal, "Received request to remove worker"), + ) + .await + } + + async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + + let mut result = Ok(()); + // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire + // map most of the time. + let worker_ids_to_remove: Vec = inner + .workers + .iter() + .rev() + .map_while(|(worker_id, worker)| { + if worker.last_update_timestamp <= now_timestamp - self.worker_timeout_s { + Some(*worker_id) + } else { + None + } + }) + .collect(); + for worker_id in &worker_ids_to_remove { + event!( + Level::WARN, + ?worker_id, + "Worker timed out, removing from pool" + ); + result = result.merge( + inner + .immediate_evict_worker( + worker_id, + make_err!( + Code::Internal, + "Worker {worker_id} timed out, removing from pool" + ), + ) + .await, + ); + } + + result + } + + async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> { + let mut inner = self.inner.lock().await; + inner.set_drain_worker(worker_id, is_draining).await + } + + fn register_metrics(self: Arc, registry: &mut Registry) { + self.inner + .lock_blocking() + .worker_state_manager + .clone() + .register_metrics(registry); + registry.register_collector(Box::new(Collector::new(&self))); + } +} + +impl MetricsComponent for ApiWorkerScheduler { + fn gather_metrics(&self, c: &mut CollectorState) { + let inner = self.inner.lock_blocking(); + let mut props = HashMap::<&String, u64>::new(); + for (_worker_id, worker) in inner.workers.iter() { + c.publish_with_labels( + "workers", + worker, + "", + vec![("worker_id".into(), worker.id.to_string().into())], + ); + for (property, prop_value) in &worker.platform_properties.properties { + let current_value = props.get(&property).unwrap_or(&0); + if let PlatformPropertyValue::Minimum(worker_value) = prop_value { + props.insert(property, *current_value + *worker_value); + } + } + } + for (property, prop_value) in props { + c.publish( + &format!("{property}_available_properties"), + &prop_value, + format!("Total sum of available properties for {property}"), + ); + } + } +} diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs new file mode 100644 index 000000000..eff7b3e01 --- /dev/null +++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs @@ -0,0 +1,191 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, OperationId, WorkerId, +}; +use static_assertions::{assert_eq_size, const_assert, const_assert_eq}; + +/// The version of the awaited action. +/// This number will always increment by one each time +/// the action is updated. +#[derive(Debug, Clone, Copy)] +struct AwaitedActionVersion(u64); + +/// An action that is being awaited on and last known state. +#[derive(Debug, Clone)] +pub struct AwaitedAction { + /// The current version of the action. + version: AwaitedActionVersion, + + /// The action that is being awaited on. + action_info: Arc, + + /// The operation id of the action. + operation_id: OperationId, + + /// The currentsort key used to order the actions. + sort_key: AwaitedActionSortKey, + + /// The time the action was last updated. + last_worker_updated_timestamp: SystemTime, + + /// Worker that is currently running this action, None if unassigned. + worker_id: Option, + + /// The current state of the action. + state: Arc, + + /// Number of attempts the job has been tried. + pub attempts: usize, +} + +impl AwaitedAction { + pub fn new(operation_id: OperationId, action_info: Arc) -> Self { + let stage = ActionStage::Queued; + let sort_key = AwaitedActionSortKey::new_with_unique_key( + action_info.priority, + &action_info.insert_timestamp, + ); + let state = Arc::new(ActionState { + stage, + id: operation_id.clone(), + }); + Self { + version: AwaitedActionVersion(0), + action_info, + operation_id, + sort_key, + attempts: 0, + last_worker_updated_timestamp: SystemTime::now(), + worker_id: None, + state, + } + } + + pub fn version(&self) -> u64 { + self.version.0 + } + + pub fn increment_version(&mut self) { + self.version = AwaitedActionVersion(self.version.0 + 1); + } + + pub fn action_info(&self) -> &Arc { + &self.action_info + } + + pub fn operation_id(&self) -> &OperationId { + &self.operation_id + } + + pub fn sort_key(&self) -> AwaitedActionSortKey { + self.sort_key + } + + pub fn state(&self) -> &Arc { + &self.state + } + + pub fn worker_id(&self) -> Option { + self.worker_id + } + + pub fn last_worker_updated_timestamp(&self) -> SystemTime { + self.last_worker_updated_timestamp + } + + /// Sets the worker id that is currently processing this action. + pub fn set_worker_id(&mut self, new_maybe_worker_id: Option) { + if self.worker_id != new_maybe_worker_id { + self.worker_id = new_maybe_worker_id; + self.last_worker_updated_timestamp = SystemTime::now(); + } + } + + /// Sets the current state of the action and notifies subscribers. + /// Returns true if the state was set, false if there are no subscribers. + pub fn set_state(&mut self, mut state: Arc) { + std::mem::swap(&mut self.state, &mut state); + self.last_worker_updated_timestamp = SystemTime::now(); + } +} + +/// The key used to sort the awaited actions. +/// +/// The rules for sorting are as follows: +/// 1. priority of the action +/// 2. insert order of the action (lower = higher priority) +/// 3. (mostly random hash based on the action info) +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct AwaitedActionSortKey(u64); + +impl AwaitedActionSortKey { + #[rustfmt::skip] + const fn new(priority: i32, insert_timestamp: u32) -> Self { + // Shift `new_priority` so [`i32::MIN`] is represented by zero. + // This makes it so any nagative values are positive, but + // maintains ordering. + const MIN_I32: i64 = (i32::MIN as i64).abs(); + let priority = ((priority as i64 + MIN_I32) as u32).to_be_bytes(); + + // Invert our timestamp so the larger the timestamp the lower the number. + // This makes timestamp descending order instead of ascending. + let timestamp = (insert_timestamp ^ u32::MAX).to_be_bytes(); + + AwaitedActionSortKey(u64::from_be_bytes([ + priority[0], priority[1], priority[2], priority[3], + timestamp[0], timestamp[1], timestamp[2], timestamp[3], + ])) + } + + fn new_with_unique_key(priority: i32, insert_timestamp: &SystemTime) -> Self { + let timestamp = insert_timestamp + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as u32; + Self::new(priority, timestamp) + } +} + +// Ensure the size of the sort key is the same as a `u64`. +assert_eq_size!(AwaitedActionSortKey, u64); + +const_assert_eq!( + AwaitedActionSortKey::new(0x1234_5678, 0x9abc_def0).0, + // Note: Result has 0x12345678 + 0x80000000 = 0x92345678 because we need + // to shift the `i32::MIN` value to be represented by zero. + // Note: `6543210f` are the inverted bits of `9abcdef0`. + // This effectively inverts the priority to now have the highest priority + // be the lowest timestamps. + AwaitedActionSortKey(0x9234_5678_6543_210f).0 +); +// Ensure the priority is used as the sort key first. +const_assert!( + AwaitedActionSortKey::new(i32::MAX, 0).0 > AwaitedActionSortKey::new(i32::MAX - 1, 0).0 +); +const_assert!(AwaitedActionSortKey::new(i32::MAX - 1, 0).0 > AwaitedActionSortKey::new(1, 0).0); +const_assert!(AwaitedActionSortKey::new(1, 0).0 > AwaitedActionSortKey::new(0, 0).0); +const_assert!(AwaitedActionSortKey::new(0, 0).0 > AwaitedActionSortKey::new(-1, 0).0); +const_assert!(AwaitedActionSortKey::new(-1, 0).0 > AwaitedActionSortKey::new(i32::MIN + 1, 0).0); +const_assert!( + AwaitedActionSortKey::new(i32::MIN + 1, 0).0 > AwaitedActionSortKey::new(i32::MIN, 0).0 +); + +// Ensure the insert timestamp is used as the sort key second. +const_assert!(AwaitedActionSortKey::new(0, u32::MIN).0 > AwaitedActionSortKey::new(0, u32::MAX).0); diff --git a/nativelink-scheduler/src/awaited_action_db/mod.rs b/nativelink-scheduler/src/awaited_action_db/mod.rs new file mode 100644 index 000000000..1d3cc623d --- /dev/null +++ b/nativelink-scheduler/src/awaited_action_db/mod.rs @@ -0,0 +1,121 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp; +use std::ops::Bound; +use std::sync::Arc; + +pub use awaited_action::{AwaitedAction, AwaitedActionSortKey}; +use futures::{Future, Stream}; +use nativelink_error::Error; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId, OperationId}; +use nativelink_util::metrics_utils::MetricsComponent; + +mod awaited_action; + +/// A simple enum to represent the state of an AwaitedAction. +#[derive(Debug, Clone, Copy)] +pub enum SortedAwaitedActionState { + CacheCheck, + Queued, + Executing, + Completed, +} + +/// A struct pointing to an AwaitedAction that can be sorted. +#[derive(Debug, Clone)] +pub struct SortedAwaitedAction { + pub sort_key: AwaitedActionSortKey, + pub operation_id: OperationId, +} + +impl PartialEq for SortedAwaitedAction { + fn eq(&self, other: &Self) -> bool { + self.sort_key == other.sort_key && self.operation_id == other.operation_id + } +} + +impl Eq for SortedAwaitedAction {} + +impl PartialOrd for SortedAwaitedAction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SortedAwaitedAction { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.sort_key + .cmp(&other.sort_key) + .then_with(|| self.operation_id.cmp(&other.operation_id)) + } +} + +/// Subscriber that can be used to monitor when AwaitedActions change. +pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static { + /// Wait for AwaitedAction to change. + fn changed(&mut self) -> impl Future> + Send; + + /// Get the current awaited action. + fn borrow(&self) -> AwaitedAction; +} + +/// A trait that defines the interface for an AwaitedActionDb. +pub trait AwaitedActionDb: Send + Sync + MetricsComponent + 'static { + type Subscriber: AwaitedActionSubscriber; + + /// Get the AwaitedAction by the client operation id. + fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> impl Future, Error>> + Send + Sync; + + /// Get all AwaitedActions. This call should be avoided as much as possible. + fn get_all_awaited_actions( + &self, + ) -> impl Future> + Send + Sync> + + Send + + Sync; + + /// Get the AwaitedAction by the operation id. + fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> impl Future, Error>> + Send + Sync; + + /// Get a range of AwaitedActions of a specific state in sorted order. + fn get_range_of_actions( + &self, + state: SortedAwaitedActionState, + start: Bound, + end: Bound, + desc: bool, + ) -> impl Future> + Send + Sync> + + Send + + Sync; + + /// Process a change changed AwaitedAction and notify any listeners. + fn update_awaited_action( + &self, + new_awaited_action: AwaitedAction, + ) -> impl Future> + Send + Sync; + + /// Add (or join) an action to the AwaitedActionDb and subscribe + /// to changes. + fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> impl Future> + Send + Sync; +} diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index 4897f5e65..672118a09 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -13,38 +13,44 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; -use futures::stream::StreamExt; -use nativelink_error::Error; +use futures::Future; +use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::{ ActionResult as ProtoActionResult, GetActionResultRequest, }; use nativelink_store::ac_utils::get_and_decode_digest; use nativelink_store::grpc_store::GrpcStore; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId, + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, }; use nativelink_util::background_spawn; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; -use nativelink_util::store_trait::{Store, StoreLike}; +use nativelink_util::store_trait::Store; use parking_lot::{Mutex, MutexGuard}; use scopeguard::guard; -use tokio::select; -use tokio::sync::watch; -use tokio_stream::wrappers::WatchStream; +use tokio::sync::oneshot; use tonic::Request; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::platform_property_manager::PlatformPropertyManager; /// Actions that are having their cache checked or failed cache lookup and are /// being forwarded upstream. Missing the skip_cache_check actions which are /// forwarded directly. -type CheckActions = HashMap>>>; +type CheckActions = HashMap< + ActionUniqueKey, + Vec<( + ClientOperationId, + oneshot::Sender>, Error>>, + )>, +>; pub struct CacheLookupScheduler { /// A reference to the AC to find existing actions in. @@ -54,7 +60,7 @@ pub struct CacheLookupScheduler { /// in the action cache. action_scheduler: Arc, /// Actions that are currently performing a CacheCheck. - cache_check_actions: Arc>, + inflight_cache_checks: Arc>, } async fn get_action_from_store( @@ -62,7 +68,7 @@ async fn get_action_from_store( action_digest: DigestInfo, instance_name: String, digest_function: DigestHasherFunc, -) -> Option { +) -> Result { // If we are a GrpcStore we shortcut here, as this is a special store. if let Some(grpc_store) = ac_store.downcast_ref::(Some(action_digest.into())) { let action_result_request = GetActionResultRequest { @@ -77,27 +83,42 @@ async fn get_action_from_store( .get_action_result(Request::new(action_result_request)) .await .map(|response| response.into_inner()) - .ok() } else { - get_and_decode_digest::(ac_store, action_digest.into()) - .await - .ok() + get_and_decode_digest::(ac_store, action_digest.into()).await } } +/// Future for when ActionListeners are known. +type ActionListenerOneshot = oneshot::Receiver>, Error>>; + fn subscribe_to_existing_action( - cache_check_actions: &MutexGuard, - unique_qualifier: &ActionInfoHashKey, -) -> Option>> { - cache_check_actions.get(unique_qualifier).map(|tx| { - let current_value = tx.borrow(); - // Subscribe marks the current value as seen, so we have to - // re-send it to all receivers. - // TODO: Fix this when fixed upstream tokio-rs/tokio#5871 - let rx = tx.subscribe(); - let _ = tx.send(current_value.clone()); - rx - }) + inflight_cache_checks: &mut MutexGuard, + unique_qualifier: &ActionUniqueKey, + client_operation_id: &ClientOperationId, +) -> Option { + inflight_cache_checks + .get_mut(unique_qualifier) + .map(|oneshots| { + let (tx, rx) = oneshot::channel(); + oneshots.push((client_operation_id.clone(), tx)); + rx + }) +} +struct CachedActionListener { + client_operation_id: ClientOperationId, + action_state: Arc, +} + +impl ActionListener for CachedActionListener { + fn client_operation_id(&self) -> &ClientOperationId { + &self.client_operation_id + } + + fn changed( + &mut self, + ) -> Pin, Error>> + Send + '_>> { + Box::pin(async { Ok(self.action_state.clone()) }) + } } impl CacheLookupScheduler { @@ -105,7 +126,7 @@ impl CacheLookupScheduler { Ok(Self { ac_store, action_scheduler, - cache_check_actions: Default::default(), + inflight_cache_checks: Default::default(), }) } } @@ -123,117 +144,167 @@ impl ActionScheduler for CacheLookupScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { - let id = OperationId::new(action_info.unique_qualifier.clone()); - if action_info.skip_cache_lookup { - // Cache lookup skipped, forward to the upstream. - return self.action_scheduler.add_action(action_info).await; - } - let mut current_state = Arc::new(ActionState { - id, - stage: ActionStage::CacheCheck, - }); - let (tx, rx) = watch::channel(current_state.clone()); - let tx = Arc::new(tx); - let scope_guard = { - let mut cache_check_actions = self.cache_check_actions.lock(); - // Check this isn't a duplicate request first. - if let Some(rx) = - subscribe_to_existing_action(&cache_check_actions, &action_info.unique_qualifier) - { - return Ok(rx); + ) -> Result>, Error> { + let unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key.clone(), + ActionUniqueQualifier::Uncachable(_) => { + // Cache lookup skipped, forward to the upstream. + return self + .action_scheduler + .add_action(client_operation_id, action_info) + .await; } - cache_check_actions.insert(action_info.unique_qualifier.clone(), tx.clone()); - // In the event we loose the reference to our `scope_guard`, it will remove - // the action from the cache_check_actions map. - let cache_check_actions = self.cache_check_actions.clone(); - let unique_qualifier = action_info.unique_qualifier.clone(); - guard((), move |_| { - cache_check_actions.lock().remove(&unique_qualifier); + }; + + let cache_check_result = { + // Check this isn't a duplicate request first. + let mut inflight_cache_checks = self.inflight_cache_checks.lock(); + subscribe_to_existing_action( + &mut inflight_cache_checks, + &unique_key, + &client_operation_id, + ) + .ok_or_else(move || { + let (action_listener_tx, action_listener_rx) = oneshot::channel(); + inflight_cache_checks.insert( + unique_key.clone(), + vec![(client_operation_id, action_listener_tx)], + ); + // In the event we loose the reference to our `scope_guard`, it will remove + // the action from the inflight_cache_checks map. + let inflight_cache_checks = self.inflight_cache_checks.clone(); + ( + action_listener_rx, + guard((), move |_| { + inflight_cache_checks.lock().remove(&unique_key); + }), + ) }) }; + let (action_listener_rx, scope_guard) = match cache_check_result { + Ok(action_listener_fut) => { + let action_listener = action_listener_fut.await.map_err(|_| { + make_err!( + Code::Internal, + "ActionListener tx hung up in CacheLookupScheduler::add_action" + ) + })?; + return action_listener; + } + Err(client_tx_and_scope_guard) => client_tx_and_scope_guard, + }; let ac_store = self.ac_store.clone(); let action_scheduler = self.action_scheduler.clone(); + let inflight_cache_checks = self.inflight_cache_checks.clone(); // We need this spawn because we are returning a stream and this spawn will populate the stream's data. background_spawn!("cache_lookup_scheduler_add_action", async move { - // If our spawn ever dies, we will remove the action from the cache_check_actions map. + // If our spawn ever dies, we will remove the action from the inflight_cache_checks map. let _scope_guard = scope_guard; + let unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(unique_key) => { + event!( + Level::ERROR, + ?action_info, + "ActionInfo::unique_qualifier should be ActionUniqueQualifier::Cachable()" + ); + unique_key + } + }; + // Perform cache check. - let action_digest = current_state.action_digest(); - let instance_name = action_info.instance_name().clone(); - if let Some(action_result) = get_action_from_store( + let instance_name = action_info.unique_qualifier.instance_name().clone(); + let maybe_action_result = get_action_from_store( &ac_store, - *action_digest, + action_info.unique_qualifier.digest(), instance_name, - current_state.id.unique_qualifier.digest_function, + action_info.unique_qualifier.digest_function(), ) - .await - { - match ac_store.has(*action_digest).await { - Ok(Some(_)) => { - Arc::make_mut(&mut current_state).stage = - ActionStage::CompletedFromCache(action_result); - let _ = tx.send(current_state); - return; - } - Err(err) => { - event!( - Level::WARN, - ?err, - "Error while calling `has` on `ac_store` in `CacheLookupScheduler`'s `add_action` function" - ); - } - _ => {} - } - } - // Not in cache, forward to upstream and proxy state. - match action_scheduler.add_action(action_info).await { - Ok(rx) => { - let mut watch_stream = WatchStream::new(rx); - loop { - select!( - Some(action_state) = watch_stream.next() => { - if tx.send(action_state).is_err() { - break; - } - } - _ = tx.closed() => { - break; - } - ) + .await; + match maybe_action_result { + Ok(action_result) => { + let maybe_pending_txs = { + let mut inflight_cache_checks = inflight_cache_checks.lock(); + // We are ready to resolve the in-flight actions. We remove the + // in-flight actions from the map. + inflight_cache_checks.remove(unique_key) + }; + let Some(pending_txs) = maybe_pending_txs else { + return; // Nobody is waiting for this action anymore. + }; + let action_state = Arc::new(ActionState { + id: OperationId::new(action_info.unique_qualifier.clone()), + stage: ActionStage::CompletedFromCache(action_result), + }); + for (client_operation_id, pending_tx) in pending_txs { + // Ignore errors here, as the other end may have hung up. + let _ = pending_tx.send(Ok(Box::pin(CachedActionListener { + client_operation_id, + action_state: action_state.clone(), + }))); } + return; } Err(err) => { - Arc::make_mut(&mut current_state).stage = - ActionStage::Completed(ActionResult { - error: Some(err), - ..Default::default() - }); - let _ = tx.send(current_state); + // NotFound errors just mean we need to execute our action. + if err.code != Code::NotFound { + let err = err.append("In CacheLookupScheduler::add_action"); + let maybe_pending_txs = { + let mut inflight_cache_checks = inflight_cache_checks.lock(); + // We are ready to resolve the in-flight actions. We remove the + // in-flight actions from the map. + inflight_cache_checks.remove(unique_key) + }; + let Some(pending_txs) = maybe_pending_txs else { + return; // Nobody is waiting for this action anymore. + }; + for (_client_operation_id, pending_tx) in pending_txs { + // Ignore errors here, as the other end may have hung up. + let _ = pending_tx.send(Err(err.clone())); + } + return; + } } } + + let maybe_pending_txs = { + let mut inflight_cache_checks = inflight_cache_checks.lock(); + inflight_cache_checks.remove(unique_key) + }; + let Some(pending_txs) = maybe_pending_txs else { + return; // Noone is waiting for this action anymore. + }; + + for (client_operation_id, pending_tx) in pending_txs { + // Ignore errors here, as the other end may have hung up. + let _ = pending_tx.send( + action_scheduler + .add_action(client_operation_id, action_info.clone()) + .await, + ); + } }); - Ok(rx) + action_listener_rx + .await + .map_err(|_| { + make_err!( + Code::Internal, + "ActionListener tx hung up in CacheLookupScheduler::add_action" + ) + })? + .err_tip(|| "In CacheLookupScheduler::add_action") } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - { - let cache_check_actions = self.cache_check_actions.lock(); - if let Some(rx) = subscribe_to_existing_action(&cache_check_actions, unique_qualifier) { - return Some(rx); - } - } - // Cache skipped may be in the upstream scheduler. + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { self.action_scheduler - .find_existing_action(unique_qualifier) + .find_by_client_operation_id(client_operation_id) .await } - - async fn clean_recently_completed_actions(&self) {} } diff --git a/nativelink-scheduler/src/default_action_listener.rs b/nativelink-scheduler/src/default_action_listener.rs new file mode 100644 index 000000000..ec399790a --- /dev/null +++ b/nativelink-scheduler/src/default_action_listener.rs @@ -0,0 +1,77 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::pin::Pin; +use std::sync::Arc; + +use futures::Future; +use nativelink_error::{make_err, Code, Error}; +use nativelink_util::action_messages::{ActionState, ClientOperationId}; +use tokio::sync::watch; + +use crate::action_scheduler::ActionListener; + +/// Simple implementation of ActionListener using tokio's watch. +pub struct DefaultActionListener { + client_operation_id: ClientOperationId, + action_state: watch::Receiver>, +} + +impl DefaultActionListener { + pub fn new( + client_operation_id: ClientOperationId, + mut action_state: watch::Receiver>, + ) -> Self { + action_state.mark_changed(); + Self { + client_operation_id, + action_state, + } + } + + pub async fn changed(&mut self) -> Result, Error> { + self.action_state.changed().await.map_or_else( + |e| { + Err(make_err!( + Code::Internal, + "Sender of ActionState went away unexpectedly - {e:?}" + )) + }, + |()| Ok(self.action_state.borrow_and_update().clone()), + ) + } +} + +impl ActionListener for DefaultActionListener { + fn client_operation_id(&self) -> &ClientOperationId { + &self.client_operation_id + } + + fn changed( + &mut self, + ) -> Pin, Error>> + Send + '_>> { + Box::pin(self.changed()) + } +} + +impl Clone for DefaultActionListener { + fn clone(&self) -> Self { + let mut action_state = self.action_state.clone(); + action_state.mark_changed(); + Self { + client_operation_id: self.client_operation_id.clone(), + action_state, + } + } +} diff --git a/nativelink-scheduler/src/default_scheduler_factory.rs b/nativelink-scheduler/src/default_scheduler_factory.rs index d6f06e3ba..304f8534f 100644 --- a/nativelink-scheduler/src/default_scheduler_factory.rs +++ b/nativelink-scheduler/src/default_scheduler_factory.rs @@ -14,14 +14,11 @@ use std::collections::HashSet; use std::sync::Arc; -use std::time::Duration; use nativelink_config::schedulers::SchedulerConfig; use nativelink_error::{Error, ResultExt}; use nativelink_store::store_manager::StoreManager; -use nativelink_util::background_spawn; use nativelink_util::metrics_utils::Registry; -use tokio::time::interval; use crate::action_scheduler::ActionScheduler; use crate::cache_lookup_scheduler::CacheLookupScheduler; @@ -57,8 +54,8 @@ fn inner_scheduler_factory( ) -> Result { let scheduler: SchedulerFactoryResults = match scheduler_type_cfg { SchedulerConfig::simple(config) => { - let scheduler = Arc::new(SimpleScheduler::new(config)); - (Some(scheduler.clone()), Some(scheduler)) + let (action_scheduler, worker_scheduler) = SimpleScheduler::new(config); + (Some(action_scheduler), Some(worker_scheduler)) } SchedulerConfig::grpc(config) => (Some(Arc::new(GrpcScheduler::new(config)?)), None), SchedulerConfig::cache_lookup(config) => { @@ -88,7 +85,6 @@ fn inner_scheduler_factory( if let Some(scheduler_metrics) = maybe_scheduler_metrics { if let Some(action_scheduler) = &scheduler.0 { - start_cleanup_timer(action_scheduler); // We need a way to prevent our scheduler form having `register_metrics()` called multiple times. // This is the equivalent of grabbing a uintptr_t in C++, storing it in a set, and checking if it's // already been visited. We can't use the Arc's pointer directly because it has two interfaces @@ -109,24 +105,8 @@ fn inner_scheduler_factory( visited_schedulers.insert(worker_scheduler_uintptr); worker_scheduler.clone().register_metrics(scheduler_metrics); } - worker_scheduler.clone().register_metrics(scheduler_metrics); } } Ok(scheduler) } - -fn start_cleanup_timer(action_scheduler: &Arc) { - let weak_scheduler = Arc::downgrade(action_scheduler); - background_spawn!("default_scheduler_factory_cleanup_timer", async move { - let mut ticker = interval(Duration::from_secs(1)); - loop { - ticker.tick().await; - match weak_scheduler.upgrade() { - Some(scheduler) => scheduler.clean_recently_completed_actions().await, - // If we fail to upgrade, our service is probably destroyed, so return. - None => return, - } - } - }); -} diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index db8129aa8..45956e139 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -28,9 +29,12 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ }; use nativelink_proto::google::longrunning::Operation; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionState, DEFAULT_EXECUTION_PRIORITY, + ActionInfo, ActionState, ActionUniqueKey, ActionUniqueQualifier, ClientOperationId, + OperationId, DEFAULT_EXECUTION_PRIORITY, }; +use nativelink_util::common::DigestInfo; use nativelink_util::connection_manager::ConnectionManager; +use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::{background_spawn, tls_utils}; use parking_lot::Mutex; @@ -42,7 +46,8 @@ use tokio::time::sleep; use tonic::{Request, Streaming}; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; +use crate::default_action_listener::DefaultActionListener; use crate::platform_property_manager::PlatformPropertyManager; pub struct GrpcScheduler { @@ -112,13 +117,26 @@ impl GrpcScheduler { async fn stream_state( mut result_stream: Streaming, - ) -> Result>, Error> { + ) -> Result>, Error> { if let Some(initial_response) = result_stream .message() .await .err_tip(|| "Recieving response from upstream scheduler")? { - let (tx, rx) = watch::channel(Arc::new(initial_response.try_into()?)); + let client_operation_id = + ClientOperationId::from_raw_string(initial_response.name.clone()); + // Our operation_id is not needed here is just a place holder to recycle existing object. + // The only thing that actually matters is the operation_id. + let operation_id = + OperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: "dummy_instance_name".to_string(), + digest_function: DigestHasherFunc::Sha256, + digest: DigestInfo::zero_digest(), + })); + let action_state = + ActionState::try_from_operation(initial_response, operation_id.clone()) + .err_tip(|| "In GrpcScheduler::stream_state")?; + let (tx, rx) = watch::channel(Arc::new(action_state)); background_spawn!("grpc_scheduler_stream_state", async move { loop { select!( @@ -135,7 +153,8 @@ impl GrpcScheduler { let Ok(Some(response)) = response else { return; }; - match response.try_into() { + let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone()); + match maybe_action_state { Ok(response) => { if let Err(err) = tx.send(Arc::new(response)) { event!( @@ -158,7 +177,10 @@ impl GrpcScheduler { ) } }); - return Ok(rx); + return Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + rx, + ))); } Err(make_err!( Code::Internal, @@ -218,8 +240,9 @@ impl ActionScheduler for GrpcScheduler { async fn add_action( &self, + _client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY { None } else { @@ -227,16 +250,20 @@ impl ActionScheduler for GrpcScheduler { priority: action_info.priority, }) }; + let skip_cache_lookup = match action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(_) => false, + ActionUniqueQualifier::Uncachable(_) => true, + }; let request = ExecuteRequest { instance_name: action_info.instance_name().clone(), - skip_cache_lookup: action_info.skip_cache_lookup, + skip_cache_lookup, action_digest: Some(action_info.digest().into()), execution_policy, // TODO: Get me from the original request, not very important as we ignore it. results_cache_policy: None, digest_function: action_info .unique_qualifier - .digest_function + .digest_function() .proto_digest_func() .into(), }; @@ -257,12 +284,12 @@ impl ActionScheduler for GrpcScheduler { Self::stream_state(result_stream).await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { let request = WaitExecutionRequest { - name: unique_qualifier.action_name(), + name: client_operation_id.to_string(), }; let result_stream = self .perform_request(request, |request| async move { @@ -270,7 +297,7 @@ impl ActionScheduler for GrpcScheduler { .connection_manager .connection() .await - .err_tip(|| "in find_existing_action()")?; + .err_tip(|| "in find_by_client_operation_id()")?; ExecutionClient::new(channel) .wait_execution(Request::new(request)) .await @@ -279,17 +306,15 @@ impl ActionScheduler for GrpcScheduler { .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; match result_stream { - Ok(result_stream) => Some(result_stream), + Ok(result_stream) => Ok(Some(result_stream)), Err(err) => { event!( Level::WARN, ?err, "Error looking up action with upstream scheduler" ); - None + Ok(None) } } } - - async fn clean_recently_completed_actions(&self) {} } diff --git a/nativelink-scheduler/src/lib.rs b/nativelink-scheduler/src/lib.rs index f8af0bf22..ab9c96c7e 100644 --- a/nativelink-scheduler/src/lib.rs +++ b/nativelink-scheduler/src/lib.rs @@ -13,15 +13,16 @@ // limitations under the License. pub mod action_scheduler; +pub mod api_worker_scheduler; +mod awaited_action_db; pub mod cache_lookup_scheduler; +pub mod default_action_listener; pub mod default_scheduler_factory; pub mod grpc_scheduler; -pub mod operation_state_manager; +mod memory_awaited_action_db; pub mod platform_property_manager; pub mod property_modifier_scheduler; -pub mod redis_action_stage; -pub mod redis_operation_state; -pub mod scheduler_state; pub mod simple_scheduler; +mod simple_scheduler_state_manager; pub mod worker; pub mod worker_scheduler; diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs new file mode 100644 index 000000000..10992108c --- /dev/null +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -0,0 +1,987 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::{Bound, RangeBounds}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; + +use async_lock::Mutex; +use async_trait::async_trait; +use futures::{FutureExt, Stream}; +use nativelink_config::stores::EvictionPolicy; +use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, +}; +use nativelink_util::chunked_stream::ChunkedStream; +use nativelink_util::evicting_map::{EvictingMap, LenEntry}; +use nativelink_util::metrics_utils::{CollectorState, MetricsComponent}; +use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; +use tokio::sync::{mpsc, watch}; +use tracing::{event, Level}; + +use crate::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction, + SortedAwaitedActionState, +}; + +/// Number of events to process per cycle. +const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024; + +/// Duration to wait before sending client keep alive messages. +const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10); + +/// Represents a client that is currently listening to an action. +/// When the client is dropped, it will send the [`AwaitedAction`] to the +/// `drop_tx` if there are other cleanups needed. +#[derive(Debug)] +struct ClientAwaitedAction { + /// The OperationId that the client is listening to. + operation_id: OperationId, + + /// The sender to notify of this struct being dropped. + drop_tx: mpsc::UnboundedSender, +} + +impl ClientAwaitedAction { + pub fn new(operation_id: OperationId, drop_tx: mpsc::UnboundedSender) -> Self { + Self { + operation_id, + drop_tx, + } + } + + pub fn operation_id(&self) -> &OperationId { + &self.operation_id + } +} + +impl Drop for ClientAwaitedAction { + fn drop(&mut self) { + // If we failed to send it means noone is listening. + let _ = self.drop_tx.send(ActionEvent::ClientDroppedOperation( + self.operation_id.clone(), + )); + } +} + +/// Trait to be able to use the EvictingMap with [`ClientAwaitedAction`]. +/// Note: We only use EvictingMap for a time based eviction, which is +/// why the implementation has fixed default values in it. +impl LenEntry for ClientAwaitedAction { + #[inline] + fn len(&self) -> usize { + 0 + } + + #[inline] + fn is_empty(&self) -> bool { + true + } +} + +/// Actions the AwaitedActionsDb needs to process. +pub(crate) enum ActionEvent { + /// A client has sent a keep alive message. + ClientKeepAlive(ClientOperationId), + /// A client has dropped and pointed to OperationId. + ClientDroppedOperation(OperationId), +} + +/// Information required to track an individual client +/// keep alive config and state. +struct ClientKeepAlive { + /// The client operation id. + client_operation_id: ClientOperationId, + /// The last time a keep alive was sent. + last_keep_alive: Instant, + /// The sender to notify of this struct being dropped. + drop_tx: mpsc::UnboundedSender, +} + +/// Subscriber that can be used to monitor when AwaitedActions change. +pub struct MemoryAwaitedActionSubscriber { + /// The receiver to listen for changes. + awaited_action_rx: watch::Receiver, + /// The client operation id and keep alive information. + client_operation_info: Option, +} + +impl MemoryAwaitedActionSubscriber { + pub fn new(mut awaited_action_rx: watch::Receiver) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: None, + } + } + + pub fn new_with_client( + mut awaited_action_rx: watch::Receiver, + client_operation_id: ClientOperationId, + drop_tx: mpsc::UnboundedSender, + ) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: Some(ClientKeepAlive { + client_operation_id, + last_keep_alive: Instant::now(), + drop_tx, + }), + } + } +} + +impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber { + async fn changed(&mut self) -> Result { + { + let changed_fut = self.awaited_action_rx.changed().map(|r| { + r.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + }) + }); + let Some(client_keep_alive) = self.client_operation_info.as_mut() else { + changed_fut.await?; + return Ok(self.awaited_action_rx.borrow().clone()); + }; + tokio::pin!(changed_fut); + loop { + if client_keep_alive.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION { + client_keep_alive.last_keep_alive = Instant::now(); + // Failing to send just means our receiver dropped. + let _ = client_keep_alive.drop_tx.send(ActionEvent::ClientKeepAlive( + client_keep_alive.client_operation_id.clone(), + )); + } + tokio::select! { + result = &mut changed_fut => { + result?; + break; + } + _ = tokio::time::sleep(CLIENT_KEEPALIVE_DURATION) => { + // If we haven't received any updates for a while, we should + // let the database know that we are still listening to prevent + // the action from being dropped. + } + + } + } + } + Ok(self.awaited_action_rx.borrow().clone()) + } + + fn borrow(&self) -> AwaitedAction { + self.awaited_action_rx.borrow().clone() + } +} + +pub struct MatchingEngineActionStateResult { + awaited_action_sub: T, +} +impl MatchingEngineActionStateResult { + pub fn new(awaited_action_sub: T) -> Self { + Self { awaited_action_sub } + } +} + +#[async_trait] +impl ActionStateResult for MatchingEngineActionStateResult { + async fn as_state(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().state().clone()) + } + + async fn changed(&mut self) -> Result, Error> { + let awaited_action = self.awaited_action_sub.changed().await.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + })?; + Ok(awaited_action.state().clone()) + } + + async fn as_action_info(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().action_info().clone()) + } +} + +pub(crate) struct ClientActionStateResult { + inner: MatchingEngineActionStateResult, +} + +impl ClientActionStateResult { + pub fn new(sub: T) -> Self { + Self { + inner: MatchingEngineActionStateResult::new(sub), + } + } +} + +#[async_trait] +impl ActionStateResult for ClientActionStateResult { + async fn as_state(&self) -> Result, Error> { + self.inner.as_state().await + } + + async fn changed(&mut self) -> Result, Error> { + self.inner.changed().await + } + + async fn as_action_info(&self) -> Result, Error> { + self.inner.as_action_info().await + } +} + +/// A struct that is used to keep the devloper from trying to +/// return early from a function. +struct NoEarlyReturn; + +#[derive(Default)] +struct SortedAwaitedActions { + unknown: BTreeSet, + cache_check: BTreeSet, + queued: BTreeSet, + executing: BTreeSet, + completed: BTreeSet, +} + +impl SortedAwaitedActions { + fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet { + match state { + ActionStage::Unknown => &mut self.unknown, + ActionStage::CacheCheck => &mut self.cache_check, + ActionStage::Queued => &mut self.queued, + ActionStage::Executing => &mut self.executing, + ActionStage::Completed(_) => &mut self.completed, + ActionStage::CompletedFromCache(_) => &mut self.completed, + } + } + + fn insert_sort_map_for_stage( + &mut self, + stage: &ActionStage, + sorted_awaited_action: SortedAwaitedAction, + ) -> Result<(), Error> { + let newly_inserted = match stage { + ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()), + ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()), + ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()), + ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()), + ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()), + ActionStage::CompletedFromCache(_) => { + self.completed.insert(sorted_awaited_action.clone()) + } + }; + if !newly_inserted { + return Err(make_err!( + Code::Internal, + "Tried to insert an action that was already in the sorted map. This should never happen. {:?} - {:?}", + stage, + sorted_awaited_action + )); + } + Ok(()) + } + + fn process_state_changes( + &mut self, + old_awaited_action: &AwaitedAction, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + let btree = self.btree_for_state(&old_awaited_action.state().stage); + let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { + sort_key: old_awaited_action.sort_key(), + operation_id: new_awaited_action.operation_id().clone(), + }); + + let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { + return Err(make_err!( + Code::Internal, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}", + new_awaited_action.operation_id(), + new_awaited_action, + )); + }; + + self.insert_sort_map_for_stage(&new_awaited_action.state().stage, sorted_awaited_action) + .err_tip(|| "In AwaitedActionDb::update_awaited_action")?; + Ok(()) + } +} + +/// The database for storing the state of all actions. +pub struct AwaitedActionDbImpl { + /// A lookup table to lookup the state of an action by its client operation id. + client_operation_to_awaited_action: + EvictingMap, SystemTime>, + + /// A lookup table to lookup the state of an action by its worker operation id. + operation_id_to_awaited_action: BTreeMap>, + + /// A lookup table to lookup the state of an action by its unique qualifier. + action_info_hash_key_to_awaited_action: HashMap, + + /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting + /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. + /// + /// See [`AwaitedActionSortKey`] for more information on the ordering. + sorted_action_info_hash_keys: SortedAwaitedActions, + + /// The number of connected clients for each operation id. + connected_clients_for_operation_id: HashMap, + + /// Where to send notifications about important events related to actions. + action_event_tx: mpsc::UnboundedSender, +} + +impl AwaitedActionDbImpl { + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + let maybe_client_awaited_action = self + .client_operation_to_awaited_action + .get(client_operation_id) + .await; + let client_awaited_action = match maybe_client_awaited_action { + Some(client_awaited_action) => client_awaited_action, + None => return Ok(None), + }; + + self.operation_id_to_awaited_action + .get(client_awaited_action.operation_id()) + .map(|tx| Some(MemoryAwaitedActionSubscriber::new(tx.subscribe()))) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get client operation id {client_operation_id:?}" + ) + }) + } + + /// Processes action events that need to be handled by the database. + async fn handle_action_events( + &mut self, + action_events: impl IntoIterator, + ) -> NoEarlyReturn { + for drop_action in action_events.into_iter() { + match drop_action { + ActionEvent::ClientDroppedOperation(operation_id) => { + // Cleanup operation_id_to_awaited_action. + let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else { + event!( + Level::ERROR, + ?operation_id, + "operation_id_to_awaited_action does not have operation_id" + ); + continue; + }; + let connected_clients = match self + .connected_clients_for_operation_id + .remove(&operation_id) + { + Some(connected_clients) => connected_clients - 1, + None => { + event!( + Level::ERROR, + ?operation_id, + "connected_clients_for_operation_id does not have operation_id" + ); + 0 + } + }; + // Note: It is rare to have more than one client listening + // to the same action, so we assume that we are the last + // client and insert it back into the map if we detect that + // there are still clients listening (ie: the happy path + // is operation.connected_clients == 0). + if connected_clients != 0 { + self.operation_id_to_awaited_action + .insert(operation_id.clone(), tx); + self.connected_clients_for_operation_id + .insert(operation_id, connected_clients); + continue; + } + let awaited_action = tx.borrow().clone(); + // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = self + .action_info_hash_key_to_awaited_action + .remove(action_key); + if !awaited_action.state().stage.is_finished() + && maybe_awaited_action.is_none() + { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // This Operation should not be in the hash_key map. + } + } + + // Cleanup sorted_awaited_action. + let sort_key = awaited_action.sort_key(); + let sort_btree_for_state = self + .sorted_action_info_hash_keys + .btree_for_state(&awaited_action.state().stage); + + let maybe_sorted_awaited_action = + sort_btree_for_state.take(&SortedAwaitedAction { + sort_key, + operation_id: operation_id.clone(), + }); + if maybe_sorted_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?sort_key, + "Expected maybe_sorted_awaited_action to have {sort_key:?}", + ); + } + } + ActionEvent::ClientKeepAlive(client_id) => { + let maybe_size = self + .client_operation_to_awaited_action + .size_for_key(&client_id) + .await; + if maybe_size.is_none() { + event!( + Level::ERROR, + ?client_id, + "client_operation_to_awaited_action does not have client_id", + ); + } + } + } + } + NoEarlyReturn + } + + fn get_awaited_actions_range( + &self, + start: Bound<&OperationId>, + end: Bound<&OperationId>, + ) -> impl Iterator { + self.operation_id_to_awaited_action + .range((start, end)) + .map(|(operation_id, tx)| { + ( + operation_id, + MemoryAwaitedActionSubscriber::new(tx.subscribe()), + ) + }) + } + + fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> Option { + self.operation_id_to_awaited_action + .get(operation_id) + .map(|tx| MemoryAwaitedActionSubscriber::new(tx.subscribe())) + } + + fn get_range_of_actions<'a, 'b>( + &'a self, + state: SortedAwaitedActionState, + range: impl RangeBounds + 'b, + ) -> impl DoubleEndedIterator< + Item = Result<(&'a SortedAwaitedAction, MemoryAwaitedActionSubscriber), Error>, + > + 'a { + let btree = match state { + SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check, + SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued, + SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing, + SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed, + }; + btree.range(range).map(|sorted_awaited_action| { + let operation_id = &sorted_awaited_action.operation_id; + self.get_by_operation_id(operation_id) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get operation id {}", + operation_id + ) + }) + .map(|subscriber| (sorted_awaited_action, subscriber)) + }) + } + + fn process_state_changes_for_hash_key_map( + action_info_hash_key_to_awaited_action: &mut HashMap, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + // Do not allow future subscribes if the action is already completed, + // this is the responsibility of the CacheLookupScheduler. + // TODO(allad) Once we land the new scheduler onto main, we can remove this check. + // It makes sense to allow users to subscribe to already completed items. + // This can be changed to `.is_error()` later. + if !new_awaited_action.state().stage.is_finished() { + return Ok(()); + } + match &new_awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = + action_info_hash_key_to_awaited_action.remove(action_key); + match maybe_awaited_action { + Some(removed_operation_id) => { + if &removed_operation_id != new_awaited_action.operation_id() { + event!( + Level::ERROR, + ?removed_operation_id, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + None => { + event!( + Level::ERROR, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key", + ); + } + } + Ok(()) + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // If we are not cachable, the action should not be in the + // hash_key map, so we don't need to process anything in + // action_info_hash_key_to_awaited_action. + Ok(()) + } + } + } + + fn update_awaited_action(&mut self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + let tx = self + .operation_id_to_awaited_action + .get(new_awaited_action.operation_id()) + .ok_or_else(|| { + make_err!( + Code::Internal, + "OperationId does not exist in map in AwaitedActionDb::update_awaited_action" + ) + })?; + { + // Note: It's important to drop old_awaited_action before we call + // send_replace or we will have a deadlock. + let old_awaited_action = tx.borrow(); + + // Do not process changes if the action version is not in sync with + // what the sender based the update on. + if old_awaited_action.version() + 1 != new_awaited_action.version() { + return Err(make_err!( + // From: https://grpc.github.io/grpc/core/md_doc_statuscodes.html + // Use ABORTED if the client should retry at a higher level + // (e.g., when a client-specified test-and-set fails, + // indicating the client should restart a read-modify-write + // sequence) + Code::Aborted, + "{} Expected {:?} but got {:?} for operation_id {:?} - {:?}", + "Tried to update an awaited action with an incorrect version.", + old_awaited_action.version() + 1, + new_awaited_action.version(), + old_awaited_action, + new_awaited_action, + )); + } + + error_if!( + old_awaited_action.action_info().unique_qualifier + != new_awaited_action.action_info().unique_qualifier, + "Unique key changed for operation_id {:?} - {:?} - {:?}", + new_awaited_action.operation_id(), + old_awaited_action.action_info(), + new_awaited_action.action_info(), + ); + let is_same_stage = old_awaited_action + .state() + .stage + .is_same_stage(&new_awaited_action.state().stage); + + if !is_same_stage { + self.sorted_action_info_hash_keys + .process_state_changes(&old_awaited_action, &new_awaited_action)?; + Self::process_state_changes_for_hash_key_map( + &mut self.action_info_hash_key_to_awaited_action, + &new_awaited_action, + )?; + } + } + + // Notify all listeners of the new state and ignore if no one is listening. + // Note: Do not use `.send()` as it will not update the state if all listeners + // are dropped. + let _ = tx.send_replace(new_awaited_action); + + Ok(()) + } + + /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to + /// listen for changes. We don't do this in-line because it is important + /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into + /// the map. Failing to do so may result in memory leaks. This is because + /// [`ClientAwaitedAction`] implements a drop function that will trigger + /// cleanup of the other maps on drop. + fn make_client_awaited_action( + &mut self, + operation_id: OperationId, + awaited_action: AwaitedAction, + ) -> (Arc, watch::Receiver) { + let (tx, rx) = watch::channel(awaited_action); + let client_awaited_action = Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )); + self.operation_id_to_awaited_action + .insert(operation_id.clone(), tx); + self.connected_clients_for_operation_id + .insert(operation_id.clone(), 1); + (client_awaited_action, rx) + } + + async fn add_action( + &mut self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + // Check to see if the action is already known and subscribe if it is. + let subscription_result = self + .try_subscribe( + &client_operation_id, + &action_info.unique_qualifier, + action_info.priority, + ) + .await + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action"); + match subscription_result { + Err(err) => return Err(err), + Ok(Some(subscription)) => return Ok(subscription), + Ok(None) => { /* Add item to queue. */ } + } + + let maybe_unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), + ActionUniqueQualifier::Uncachable(_unique_key) => None, + }; + let operation_id = OperationId::new(action_info.unique_qualifier.clone()); + let awaited_action = AwaitedAction::new(operation_id.clone(), action_info); + debug_assert!( + ActionStage::Queued == awaited_action.state().stage, + "Expected action to be queued" + ); + let sort_key = awaited_action.sort_key(); + + let (client_awaited_action, rx) = + self.make_client_awaited_action(operation_id.clone(), awaited_action); + + self.client_operation_to_awaited_action + .insert(client_operation_id.clone(), client_awaited_action) + .await; + + // Note: We only put items in the map that are cachable. + if let Some(unique_key) = maybe_unique_key { + let old_value = self + .action_info_hash_key_to_awaited_action + .insert(unique_key, operation_id.clone()); + if let Some(old_value) = old_value { + event!( + Level::ERROR, + ?operation_id, + ?old_value, + "action_info_hash_key_to_awaited_action already has unique_key" + ); + } + } + + self.sorted_action_info_hash_keys + .insert_sort_map_for_stage( + &ActionStage::Queued, + SortedAwaitedAction { + sort_key, + operation_id, + }, + ) + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action")?; + + Ok(MemoryAwaitedActionSubscriber::new_with_client( + rx, + client_operation_id, + self.action_event_tx.clone(), + )) + } + + async fn try_subscribe( + &mut self, + client_operation_id: &ClientOperationId, + unique_qualifier: &ActionUniqueQualifier, + // TODO(allada) To simplify the scheduler 2024 refactor, we + // removed the ability to upgrade priorities of actions. + // we should add priority upgrades back in. + _priority: i32, + ) -> Result, Error> { + let unique_key = match unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None), + }; + + let Some(operation_id) = self.action_info_hash_key_to_awaited_action.get(unique_key) else { + return Ok(None); // Not currently running. + }; + + let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else { + return Err(make_err!( + Code::Internal, + "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}" + )); + }; + + error_if!( + tx.borrow().state().stage.is_finished(), + "Tried to subscribe to a completed action but it already finished. This should never happen. {:?}", + tx.borrow() + ); + + let maybe_connected_clients = self + .connected_clients_for_operation_id + .get_mut(operation_id); + let Some(connected_clients) = maybe_connected_clients else { + return Err(make_err!( + Code::Internal, + "connected_clients_for_operation_id and operation_id_to_awaited_action are out of sync for {unique_key:?} - {operation_id}" + )); + }; + *connected_clients += 1; + + let subscription = tx.subscribe(); + + self.client_operation_to_awaited_action + .insert( + client_operation_id.clone(), + Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )), + ) + .await; + + Ok(Some(MemoryAwaitedActionSubscriber::new(subscription))) + } +} + +pub struct MemoryAwaitedActionDb { + inner: Arc>, + _handle_awaited_action_events: JoinHandleDropGuard<()>, +} + +impl MemoryAwaitedActionDb { + pub fn new(eviction_config: &EvictionPolicy) -> Self { + let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel(); + let inner = Arc::new(Mutex::new(AwaitedActionDbImpl { + client_operation_to_awaited_action: EvictingMap::new( + eviction_config, + SystemTime::now(), + ), + operation_id_to_awaited_action: BTreeMap::new(), + action_info_hash_key_to_awaited_action: HashMap::new(), + sorted_action_info_hash_keys: SortedAwaitedActions::default(), + connected_clients_for_operation_id: HashMap::new(), + action_event_tx, + })); + let weak_inner = Arc::downgrade(&inner); + Self { + inner, + _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move { + let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE); + loop { + dropped_operation_ids.clear(); + action_event_rx + .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE) + .await; + let Some(inner) = weak_inner.upgrade() else { + return; // Nothing to cleanup, our struct is dropped. + }; + let mut inner = inner.lock().await; + inner + .handle_action_events(dropped_operation_ids.drain(..)) + .await; + } + }), + } + } +} + +impl AwaitedActionDb for MemoryAwaitedActionDb { + type Subscriber = MemoryAwaitedActionSubscriber; + + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + self.inner + .lock() + .await + .get_awaited_action_by_id(client_operation_id) + .await + } + + async fn get_all_awaited_actions(&self) -> impl Stream> { + ChunkedStream::new( + Bound::Unbounded, + Bound::Unbounded, + move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut maybe_new_start = None; + + for (operation_id, item) in + inner.get_awaited_actions_range(start.as_ref(), end.as_ref()) + { + output.push_back(item); + maybe_new_start = Some(operation_id); + } + + Ok(maybe_new_start + .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output))) + }, + ) + } + + async fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> Result, Error> { + Ok(self.inner.lock().await.get_by_operation_id(operation_id)) + } + + async fn get_range_of_actions( + &self, + state: SortedAwaitedActionState, + start: Bound, + end: Bound, + desc: bool, + ) -> impl Stream> + Send + Sync { + ChunkedStream::new(start, end, move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut done = true; + let mut new_start = start.as_ref(); + let mut new_end = end.as_ref(); + + let iterator = inner.get_range_of_actions(state, (start.as_ref(), end.as_ref())); + // TODO(allada) This should probably use the `.left()/right()` pattern, + // but that doesn't exist in the std or any libraries we use. + if desc { + for result in iterator.rev() { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_end = Bound::Excluded(sorted_awaited_action); + done = false; + } + } else { + for result in iterator { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_start = Bound::Excluded(sorted_awaited_action); + done = false; + } + } + if done { + return Ok(None); + } + Ok(Some(((new_start.cloned(), new_end.cloned()), output))) + }) + } + + async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + self.inner + .lock() + .await + .update_awaited_action(new_awaited_action) + } + + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + self.inner + .lock() + .await + .add_action(client_operation_id, action_info) + .await + } +} + +impl MetricsComponent for MemoryAwaitedActionDb { + fn gather_metrics(&self, c: &mut CollectorState) { + let inner = self.inner.lock_blocking(); + c.publish( + "action_state_unknown_total", + &inner.sorted_action_info_hash_keys.unknown.len(), + "Number of actions wih the current state of unknown.", + ); + c.publish( + "action_state_cache_check_total", + &inner.sorted_action_info_hash_keys.cache_check.len(), + "Number of actions wih the current state of cache_check.", + ); + c.publish( + "action_state_queued_total", + &inner.sorted_action_info_hash_keys.queued.len(), + "Number of actions wih the current state of queued.", + ); + c.publish( + "action_state_executing_total", + &inner.sorted_action_info_hash_keys.executing.len(), + "Number of actions wih the current state of executing.", + ); + c.publish( + "action_state_completed_total", + &inner.sorted_action_info_hash_keys.completed.len(), + "Number of actions wih the current state of completed.", + ); + // TODO(allada) This is legacy and should be removed in the future. + c.publish( + "active_actions_total", + &inner.sorted_action_info_hash_keys.executing.len(), + "(LEGACY) The number of running actions.", + ); + // TODO(allada) This is legacy and should be removed in the future. + c.publish( + "queued_actions_total", + &inner.sorted_action_info_hash_keys.queued.len(), + "(LEGACY) The number actions in the queue.", + ); + } +} diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs index 8b289fff2..77f3b897d 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -14,17 +14,17 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_config::schedulers::{PropertyModification, PropertyType}; use nativelink_error::{Error, ResultExt}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; use nativelink_util::metrics_utils::Registry; use parking_lot::Mutex; -use tokio::sync::watch; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::platform_property_manager::PlatformPropertyManager; pub struct PropertyModifierScheduler { @@ -90,10 +90,11 @@ impl ActionScheduler for PropertyModifierScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, mut action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { let platform_property_manager = self - .get_platform_property_manager(&action_info.unique_qualifier.instance_name) + .get_platform_property_manager(action_info.unique_qualifier.instance_name()) .await .err_tip(|| "In PropertyModifierScheduler::add_action")?; for modification in &self.modifications { @@ -111,18 +112,18 @@ impl ActionScheduler for PropertyModifierScheduler { } }; } - self.scheduler.add_action(action_info).await + self.scheduler + .add_action(client_operation_id, action_info) + .await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - self.scheduler.find_existing_action(unique_qualifier).await - } - - async fn clean_recently_completed_actions(&self) { - self.scheduler.clean_recently_completed_actions().await + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + self.scheduler + .find_by_client_operation_id(client_operation_id) + .await } // Register metrics for the underlying ActionScheduler. diff --git a/nativelink-scheduler/src/redis_action_stage.rs b/nativelink-scheduler/src/redis_action_stage.rs deleted file mode 100644 index 3176c7324..000000000 --- a/nativelink-scheduler/src/redis_action_stage.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use nativelink_error::{make_input_err, Error, ResultExt}; -use nativelink_util::action_messages::{ActionResult, ActionStage}; -use serde::{Deserialize, Serialize}; - -use crate::operation_state_manager::OperationStageFlags; - -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] -pub enum RedisOperationStage { - CacheCheck, - Queued, - Executing, - Completed(ActionResult), - CompletedFromCache(ActionResult), -} - -impl RedisOperationStage { - pub fn as_state_flag(&self) -> OperationStageFlags { - match self { - Self::CacheCheck => OperationStageFlags::CacheCheck, - Self::Executing => OperationStageFlags::Executing, - Self::Queued => OperationStageFlags::Queued, - Self::Completed(_) => OperationStageFlags::Completed, - Self::CompletedFromCache(_) => OperationStageFlags::Completed, - } - } -} - -impl TryFrom for RedisOperationStage { - type Error = Error; - fn try_from(stage: ActionStage) -> Result { - match stage { - ActionStage::CacheCheck => Ok(RedisOperationStage::CacheCheck), - ActionStage::Queued => Ok(RedisOperationStage::Queued), - ActionStage::Executing => Ok(RedisOperationStage::Executing), - ActionStage::Completed(result) => Ok(RedisOperationStage::Completed(result)), - ActionStage::CompletedFromCache(proto_result) => { - let decoded = ActionResult::try_from(proto_result) - .err_tip(|| "In RedisOperationStage::try_from::")?; - Ok(RedisOperationStage::Completed(decoded)) - } - ActionStage::Unknown => Err(make_input_err!("ActionStage conversion to RedisOperationStage failed with Error - Unknown is not a valid OperationStage")), - } - } -} - -impl From for ActionStage { - fn from(stage: RedisOperationStage) -> ActionStage { - match stage { - RedisOperationStage::CacheCheck => ActionStage::CacheCheck, - RedisOperationStage::Queued => ActionStage::Queued, - RedisOperationStage::Executing => ActionStage::Executing, - RedisOperationStage::Completed(result) => ActionStage::Completed(result), - RedisOperationStage::CompletedFromCache(result) => { - ActionStage::CompletedFromCache(result.into()) - } - } - } -} - -impl From<&RedisOperationStage> for ActionStage { - fn from(stage: &RedisOperationStage) -> Self { - stage.clone().into() - } -} diff --git a/nativelink-scheduler/src/redis_operation_state.rs b/nativelink-scheduler/src/redis_operation_state.rs deleted file mode 100644 index 5dd4c13d0..000000000 --- a/nativelink-scheduler/src/redis_operation_state.rs +++ /dev/null @@ -1,465 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::str::FromStr; -use std::sync::Arc; -use std::time::SystemTime; - -use futures::{join, StreamExt}; -use nativelink_error::{make_input_err, Error, ResultExt}; -use nativelink_store::redis_store::RedisStore; -use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionStage, ActionState, OperationId, WorkerId, -}; -use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::spawn; -use nativelink_util::store_trait::{StoreDriver, StoreLike, StoreSubscription}; -use nativelink_util::task::JoinHandleDropGuard; -use redis::aio::{ConnectionLike, ConnectionManager}; -use redis::{AsyncCommands, Pipeline}; -use redis_macros::{FromRedisValue, ToRedisArgs}; -use serde::{Deserialize, Serialize}; -use tokio::sync::watch; -use tonic::async_trait; -use tracing::{event, Level}; - -use crate::operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, - OperationFilter, WorkerStateManager, -}; -use crate::redis_action_stage::RedisOperationStage; - -#[inline] -fn build_action_key(unique_qualifier: &ActionInfoHashKey) -> String { - format!("actions:{}", unique_qualifier.action_name()) -} - -#[inline] -fn build_operations_key(operation_id: &OperationId) -> String { - format!("operations:{operation_id}") -} - -pub struct RedisOperationState { - rx: watch::Receiver>, - inner: Arc, - _join_handle: JoinHandleDropGuard<()>, -} - -impl RedisOperationState { - fn new( - inner: Arc, - mut operation_subscription: Box, - ) -> Self { - let (tx, rx) = watch::channel(inner.as_state()); - - let _join_handle = spawn!("redis_subscription_watcher", async move { - loop { - let Ok(item) = operation_subscription.changed().await else { - // This might occur if the store subscription is dropped - // or if there is an error fetching the data. - return; - }; - let (mut data_tx, mut data_rx) = make_buf_channel_pair(); - let (get_res, data_res) = join!( - // We use async move because we want to transfer ownership of data_tx into the closure. - // That way if join! selects data_rx.consume(None) because get fails, - // data_tx goes out of scope and will be dropped. - async move { item.get(&mut data_tx).await }, - data_rx.consume(None) - ); - - let res = get_res - .merge(data_res) - .and_then(|data| { - RedisOperation::from_slice(&data[..]) - .err_tip(|| "Error while Publishing RedisSubscription") - }) - .map(|redis_operation| { - tx.send_modify(move |cur_state| *cur_state = redis_operation.as_state()) - }); - if let Err(e) = res { - // TODO: Refactor API to allow error to be propogated to client. - event!( - Level::ERROR, - ?e, - "Error During Redis Operation Subscription", - ); - return; - } - } - }); - Self { - rx, - _join_handle, - inner, - } - } -} - -#[async_trait] -impl ActionStateResult for RedisOperationState { - async fn as_state(&self) -> Result, Error> { - Ok(Arc::new(ActionState::from(self.inner.as_ref()))) - } - - async fn as_receiver(&self) -> Result<&'_ watch::Receiver>, Error> { - Ok(&self.rx) - } - - async fn as_action_info(&self) -> Result, Error> { - Ok(Arc::new(self.inner.info.clone())) - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, ToRedisArgs, FromRedisValue)] -pub struct RedisOperation { - operation_id: OperationId, - info: ActionInfo, - worker_id: Option, - stage: RedisOperationStage, - last_worker_update: Option, - last_client_update: Option, - last_error: Option, - completed_at: Option, -} - -impl RedisOperation { - pub fn as_json(&self) -> String { - serde_json::json!(&self).to_string() - } - - pub fn from_slice(s: &[u8]) -> Result { - serde_json::from_slice(s).map_err(|e| { - make_input_err!("Create RedisOperation from slice failed with Error - {e:?}") - }) - } - - pub fn new(info: ActionInfo, operation_id: OperationId) -> Self { - Self { - operation_id, - info, - worker_id: None, - stage: RedisOperationStage::CacheCheck, - last_worker_update: None, - last_client_update: None, - last_error: None, - completed_at: None, - } - } - - pub fn from_existing(existing: RedisOperation, operation_id: OperationId) -> Self { - Self { - operation_id, - info: existing.info, - worker_id: existing.worker_id, - stage: existing.stage, - last_worker_update: existing.last_worker_update, - last_client_update: existing.last_client_update, - last_error: existing.last_error, - completed_at: existing.completed_at, - } - } - - pub fn as_state(&self) -> Arc { - let action_state = ActionState { - stage: self.stage.clone().into(), - id: self.operation_id.clone(), - }; - Arc::new(action_state) - } - - pub fn unique_qualifier(&self) -> &ActionInfoHashKey { - &self.operation_id.unique_qualifier - } - - pub fn matches_filter(&self, filter: &OperationFilter) -> bool { - // If the filter value is None, we can match anything and return true. - // If the filter value is Some and the value is None, it can't be a match so we return false. - // If both values are Some, we compare to determine if there is a match. - let matches_stage_filter = filter.stages.contains(self.stage.as_state_flag()); - if !matches_stage_filter { - return false; - } - - let matches_operation_filter = filter - .operation_id - .as_ref() - .map_or(true, |id| &self.operation_id == id); - if !matches_operation_filter { - return false; - } - - let matches_worker_filter = self.worker_id == filter.worker_id; - if !matches_worker_filter { - return false; - }; - - let matches_digest_filter = filter - .action_digest - .map_or(true, |digest| self.unique_qualifier().digest == digest); - if !matches_digest_filter { - return false; - }; - - let matches_completed_before = filter.completed_before.map_or(true, |before| { - self.completed_at - .map_or(false, |completed_at| completed_at < before) - }); - if !matches_completed_before { - return false; - }; - - let matches_last_update = filter.last_client_update_before.map_or(true, |before| { - self.last_client_update - .map_or(false, |last_update| last_update < before) - }); - if !matches_last_update { - return false; - }; - - true - } -} - -impl FromStr for RedisOperation { - type Err = Error; - fn from_str(s: &str) -> Result { - serde_json::from_str(s).map_err(|e| { - make_input_err!( - "Decode string {s} to RedisOperation failed with error: {}", - e.to_string() - ) - }) - } -} - -impl From<&RedisOperation> for ActionState { - fn from(value: &RedisOperation) -> Self { - ActionState { - id: value.operation_id.clone(), - stage: value.stage.clone().into(), - } - } -} - -pub struct RedisStateManager< - T: ConnectionLike + Unpin + Clone + Send + Sync + 'static = ConnectionManager, -> { - store: Arc>, -} - -impl RedisStateManager { - pub fn new(store: Arc>) -> Self { - Self { store } - } - - pub async fn get_conn(&self) -> Result { - self.store.get_conn().await - } - - async fn list<'a, V>( - &self, - prefix: &str, - handler: impl Fn(String, String) -> Result, - ) -> Result, Error> - where - V: Send + Sync, - { - let mut con = self - .get_conn() - .await - .err_tip(|| "In RedisStateManager::list")?; - let ids_iter = con - .scan_match::<&str, String>(prefix) - .await - .err_tip(|| "In RedisStateManager::list")?; - let keys = ids_iter.collect::>().await; - let raw_values: Vec = con - .get(&keys) - .await - .err_tip(|| "In RedisStateManager::list")?; - keys.into_iter() - .zip(raw_values.into_iter()) - .map(|(k, v)| handler(k, v)) - .collect() - } - - async fn inner_add_action( - &self, - action_info: ActionInfo, - ) -> Result, Error> { - let operation_id = OperationId::new(action_info.unique_qualifier.clone()); - let mut con = self - .get_conn() - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - let action_key = build_action_key(&operation_id.unique_qualifier); - // TODO: List API call to find existing actions. - let mut existing_operations: Vec = Vec::new(); - let operation = match existing_operations.pop() { - Some(existing_operation) => { - let operations_key = build_operations_key(&existing_operation); - let operation: RedisOperation = con - .get(operations_key) - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - RedisOperation::from_existing(operation.clone(), operation_id.clone()) - } - None => RedisOperation::new(action_info, operation_id.clone()), - }; - - let operation_key = build_operations_key(&operation_id); - - // The values being stored in redis are pretty small so we can do our uploads as oneshots. - // We do not parallelize these uploads since we should always upload an operation followed by the action, - let store = self.store.as_store_driver_pin(); - store - .update_oneshot(operation_key.clone().into(), operation.as_json().into()) - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - store - .update_oneshot(action_key.into(), operation_id.to_string().into()) - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - - let store_subscription = self.store.clone().subscribe(operation_key.into()).await; - let state = RedisOperationState::new(Arc::new(operation), store_subscription); - Ok(Arc::new(state)) - } - - async fn inner_filter_operations( - &self, - filter: OperationFilter, - ) -> Result { - let handler = &|k: String, v: String| -> Result<(String, Arc), Error> { - let operation = Arc::new( - RedisOperation::from_str(&v) - .err_tip(|| "In RedisStateManager::inner_filter_operations")?, - ); - Ok((k, operation)) - }; - let existing_operations: Vec<(String, Arc)> = self - .list("operations:*", &handler) - .await - .err_tip(|| "In RedisStateManager::inner_filter_operations")?; - let mut v: Vec> = Vec::new(); - for (key, operation) in existing_operations.into_iter() { - if operation.matches_filter(&filter) { - let store_subscription = self.store.clone().subscribe(key.into()).await; - v.push(Arc::new(RedisOperationState::new( - operation, - store_subscription, - ))); - } - } - Ok(Box::pin(futures::stream::iter(v))) - } - - async fn inner_update_operation( - &self, - operation_id: OperationId, - worker_id: Option, - action_stage: Result, - ) -> Result<(), Error> { - let store = self.store.as_store_driver_pin(); - let key = format!("operations:{operation_id}"); - let operation_bytes_res = &store.get_part_unchunked(key.clone().into(), 0, None).await; - let Ok(operation_bytes) = operation_bytes_res else { - return Err(make_input_err!("Received request to update operation {operation_id}, but operation does not exist.")); - }; - - let mut operation = RedisOperation::from_slice(&operation_bytes[..]) - .err_tip(|| "In RedisStateManager::inner_update_operation")?; - match action_stage { - Ok(stage) => { - operation.stage = stage - .try_into() - .err_tip(|| "In RedisStateManager::inner_update_operation")?; - } - Err(e) => operation.last_error = Some(e), - } - - operation.worker_id = worker_id; - store - .update_oneshot(key.into(), operation.as_json().into()) - .await - } - - // TODO: This should be done through store but API endpoint does not exist yet. - async fn inner_remove_operation(&self, operation_id: OperationId) -> Result<(), Error> { - let mut con = self - .get_conn() - .await - .err_tip(|| "In RedisStateManager::inner_remove_operation")?; - let mut pipe = Pipeline::new(); - Ok(pipe - .del(format!("operations:{operation_id}")) - .query_async(&mut con) - .await?) - } -} - -#[async_trait] -impl ClientStateManager for RedisStateManager { - async fn add_action( - &mut self, - action_info: ActionInfo, - ) -> Result, Error> { - self.inner_add_action(action_info).await - } - - async fn filter_operations( - &self, - filter: OperationFilter, - ) -> Result { - self.inner_filter_operations(filter).await - } -} - -#[async_trait] -impl WorkerStateManager for RedisStateManager { - async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: WorkerId, - action_stage: Result, - ) -> Result<(), Error> { - self.inner_update_operation(operation_id, Some(worker_id), action_stage) - .await - } -} - -#[async_trait] -impl MatchingEngineStateManager for RedisStateManager { - async fn filter_operations( - &self, - filter: OperationFilter, - ) -> Result { - self.inner_filter_operations(filter).await - } - - async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: Option, - action_stage: Result, - ) -> Result<(), Error> { - self.inner_update_operation(operation_id, worker_id, action_stage) - .await - } - - async fn remove_operation(&self, operation_id: OperationId) -> Result<(), Error> { - self.inner_remove_operation(operation_id).await - } -} diff --git a/nativelink-scheduler/src/scheduler_state/awaited_action.rs b/nativelink-scheduler/src/scheduler_state/awaited_action.rs deleted file mode 100644 index bca2ef489..000000000 --- a/nativelink-scheduler/src/scheduler_state/awaited_action.rs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionState, WorkerId}; -use nativelink_util::metrics_utils::{CollectorState, MetricsComponent}; -use tokio::sync::watch; - -/// An action that is being awaited on and last known state. -pub struct AwaitedAction { - /// The action that is being awaited on. - pub(crate) action_info: Arc, - - /// The current state of the action. - pub(crate) current_state: Arc, - - /// The channel to notify subscribers of state changes when updated, completed or retrying. - pub(crate) notify_channel: watch::Sender>, - - /// Number of attempts the job has been tried. - pub(crate) attempts: usize, - - /// Possible last error set by the worker. If empty and attempts is set, it may be due to - /// something like a worker timeout. - pub(crate) last_error: Option, - - /// Worker that is currently running this action, None if unassigned. - pub(crate) worker_id: Option, -} - -impl MetricsComponent for AwaitedAction { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish( - "action_digest", - &self.action_info.unique_qualifier.action_name(), - "The digest of the action.", - ); - c.publish( - "current_state", - self.current_state.as_ref(), - "The current stage of the action.", - ); - c.publish( - "attempts", - &self.attempts, - "The number of attempts this action has tried.", - ); - c.publish( - "last_error", - &format!("{:?}", self.last_error), - "The last error this action caused from a retry (if any).", - ); - } -} diff --git a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs b/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs deleted file mode 100644 index d1044b0a7..000000000 --- a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use async_trait::async_trait; -use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionState}; -use tokio::sync::watch::Receiver; - -use crate::operation_state_manager::ActionStateResult; - -pub(crate) struct ClientActionStateResult { - rx: Receiver>, -} - -impl ClientActionStateResult { - pub(crate) fn new(mut rx: Receiver>) -> Self { - // Marking the initial value as changed for new or existing actions regardless if - // underlying state has changed. This allows for triggering notification after subscription - // without having to use an explicit notification. - rx.mark_changed(); - Self { rx } - } -} - -#[async_trait] -impl ActionStateResult for ClientActionStateResult { - async fn as_state(&self) -> Result, Error> { - Ok(self.rx.borrow().clone()) - } - - async fn as_receiver(&self) -> Result<&'_ Receiver>, Error> { - Ok(&self.rx) - } - - async fn as_action_info(&self) -> Result, Error> { - unimplemented!() - } -} diff --git a/nativelink-scheduler/src/scheduler_state/completed_action.rs b/nativelink-scheduler/src/scheduler_state/completed_action.rs deleted file mode 100644 index f69f10d1a..000000000 --- a/nativelink-scheduler/src/scheduler_state/completed_action.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::borrow::Borrow; -use std::hash::{Hash, Hasher}; -use std::sync::Arc; -use std::time::SystemTime; - -use nativelink_util::action_messages::{ActionInfoHashKey, ActionState, OperationId}; -use nativelink_util::metrics_utils::{CollectorState, MetricsComponent}; - -/// A completed action that has no listeners. -pub struct CompletedAction { - /// The time the action was completed. - pub(crate) completed_time: SystemTime, - /// The current state of the action when it was completed. - pub(crate) state: Arc, -} - -impl Hash for CompletedAction { - fn hash(&self, state: &mut H) { - OperationId::hash(&self.state.id, state); - } -} - -impl PartialEq for CompletedAction { - fn eq(&self, other: &Self) -> bool { - OperationId::eq(&self.state.id, &other.state.id) - } -} - -impl Eq for CompletedAction {} - -impl Borrow for CompletedAction { - #[inline] - fn borrow(&self) -> &OperationId { - &self.state.id - } -} - -impl Borrow for CompletedAction { - #[inline] - fn borrow(&self) -> &ActionInfoHashKey { - &self.state.id.unique_qualifier - } -} - -impl MetricsComponent for CompletedAction { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish( - "completed_timestamp", - &self.completed_time, - "The timestamp this action was completed", - ); - c.publish( - "current_state", - self.state.as_ref(), - "The current stage of the action.", - ); - } -} diff --git a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs b/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs deleted file mode 100644 index 0c6a4c74c..000000000 --- a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use async_trait::async_trait; -use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionState}; -use tokio::sync::watch; - -use crate::operation_state_manager::ActionStateResult; - -pub struct MatchingEngineActionStateResult { - action_info: Arc, - action_state: watch::Receiver>, -} -impl MatchingEngineActionStateResult { - pub(crate) fn new( - action_info: Arc, - action_state: watch::Receiver>, - ) -> Self { - Self { - action_info, - action_state, - } - } -} - -#[async_trait] -impl ActionStateResult for MatchingEngineActionStateResult { - async fn as_state(&self) -> Result, Error> { - Ok(self.action_state.borrow().clone()) - } - - async fn as_receiver(&self) -> Result<&'_ watch::Receiver>, Error> { - Ok(&self.action_state) - } - - async fn as_action_info(&self) -> Result, Error> { - Ok(self.action_info.clone()) - } -} diff --git a/nativelink-scheduler/src/scheduler_state/metrics.rs b/nativelink-scheduler/src/scheduler_state/metrics.rs deleted file mode 100644 index e9cfe60c5..000000000 --- a/nativelink-scheduler/src/scheduler_state/metrics.rs +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use nativelink_util::metrics_utils::{CollectorState, CounterWithTime}; - -#[derive(Default)] -pub(crate) struct Metrics { - pub(crate) add_action_joined_running_action: CounterWithTime, - pub(crate) add_action_joined_queued_action: CounterWithTime, - pub(crate) add_action_new_action_created: CounterWithTime, - pub(crate) update_action_missing_action_result: CounterWithTime, - pub(crate) update_action_from_wrong_worker: CounterWithTime, - pub(crate) update_action_no_more_listeners: CounterWithTime, - pub(crate) update_action_with_internal_error: CounterWithTime, - pub(crate) update_action_with_internal_error_no_action: CounterWithTime, - pub(crate) update_action_with_internal_error_backpressure: CounterWithTime, - pub(crate) update_action_with_internal_error_from_wrong_worker: CounterWithTime, - pub(crate) workers_evicted: CounterWithTime, - pub(crate) workers_evicted_with_running_action: CounterWithTime, - pub(crate) retry_action: CounterWithTime, - pub(crate) retry_action_max_attempts_reached: CounterWithTime, - pub(crate) retry_action_no_more_listeners: CounterWithTime, - pub(crate) retry_action_but_action_missing: CounterWithTime, -} - -impl Metrics { - pub fn gather_metrics(&self, c: &mut CollectorState) { - { - c.publish_with_labels( - "add_action", - &self.add_action_joined_running_action, - "Stats about add_action().", - vec![("result".into(), "joined_running_action".into())], - ); - c.publish_with_labels( - "add_action", - &self.add_action_joined_queued_action, - "Stats about add_action().", - vec![("result".into(), "joined_queued_action".into())], - ); - c.publish_with_labels( - "add_action", - &self.add_action_new_action_created, - "Stats about add_action().", - vec![("result".into(), "new_action_created".into())], - ); - } - { - c.publish_with_labels( - "update_action_errors", - &self.update_action_missing_action_result, - "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.", - vec![("result".into(), "missing_action_result".into())], - ); - c.publish_with_labels( - "update_action_errors", - &self.update_action_from_wrong_worker, - "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.", - vec![("result".into(), "from_wrong_worker".into())], - ); - c.publish_with_labels( - "update_action_errors", - &self.update_action_no_more_listeners, - "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.", - vec![("result".into(), "no_more_listeners".into())], - ); - } - { - c.publish( - "update_action_with_internal_error", - &self.update_action_with_internal_error, - "The number of times update_action_with_internal_error was triggered.", - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_no_action, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "no_action".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_backpressure, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "backpressure".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_from_wrong_worker, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "from_wrong_worker".into())], - ); - } - { - c.publish( - "workers_evicted_total", - &self.workers_evicted, - "The number of workers evicted from scheduler.", - ); - c.publish( - "workers_evicted_with_running_action", - &self.workers_evicted_with_running_action, - "The number of jobs cancelled because worker was evicted from scheduler.", - ); - } - { - c.publish_with_labels( - "retry_action", - &self.retry_action, - "Stats about retry_action().", - vec![("result".into(), "success".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_max_attempts_reached, - "Stats about retry_action().", - vec![("result".into(), "max_attempts_reached".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_no_more_listeners, - "Stats about retry_action().", - vec![("result".into(), "no_more_listeners".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_but_action_missing, - "Stats about retry_action().", - vec![("result".into(), "action_missing".into())], - ); - } - } -} diff --git a/nativelink-scheduler/src/scheduler_state/mod.rs b/nativelink-scheduler/src/scheduler_state/mod.rs deleted file mode 100644 index 359f4f063..000000000 --- a/nativelink-scheduler/src/scheduler_state/mod.rs +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -pub(crate) mod awaited_action; -pub(crate) mod client_action_state_result; -pub(crate) mod completed_action; -pub(crate) mod matching_engine_action_state_result; -pub(crate) mod metrics; -pub(crate) mod state_manager; -pub(crate) mod workers; diff --git a/nativelink-scheduler/src/scheduler_state/state_manager.rs b/nativelink-scheduler/src/scheduler_state/state_manager.rs deleted file mode 100644 index 8dd0def9c..000000000 --- a/nativelink-scheduler/src/scheduler_state/state_manager.rs +++ /dev/null @@ -1,742 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp; -use std::collections::BTreeMap; -use std::sync::Arc; -use std::time::SystemTime; - -use async_trait::async_trait; -use futures::stream; -use hashbrown::{HashMap, HashSet}; -use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; -use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata, - OperationId, WorkerId, -}; -use tokio::sync::watch::error::SendError; -use tokio::sync::{watch, Notify}; -use tracing::{event, Level}; - -use crate::operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, - OperationFilter, WorkerStateManager, -}; -use crate::scheduler_state::awaited_action::AwaitedAction; -use crate::scheduler_state::client_action_state_result::ClientActionStateResult; -use crate::scheduler_state::completed_action::CompletedAction; -use crate::scheduler_state::matching_engine_action_state_result::MatchingEngineActionStateResult; -use crate::scheduler_state::metrics::Metrics; -use crate::scheduler_state::workers::Workers; -use crate::worker::WorkerUpdate; - -#[repr(transparent)] -pub(crate) struct StateManager { - pub inner: StateManagerImpl, -} - -impl StateManager { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - queued_actions_set: HashSet>, - queued_actions: BTreeMap, AwaitedAction>, - workers: Workers, - active_actions: HashMap, AwaitedAction>, - recently_completed_actions: HashSet, - metrics: Arc, - max_job_retries: usize, - tasks_or_workers_change_notify: Arc, - ) -> Self { - Self { - inner: StateManagerImpl { - queued_actions_set, - queued_actions, - workers, - active_actions, - recently_completed_actions, - metrics, - max_job_retries, - tasks_or_workers_change_notify, - }, - } - } - - fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) { - if let Some(mut worker) = self.inner.workers.remove_worker(worker_id) { - self.inner.metrics.workers_evicted.inc(); - // We don't care if we fail to send message to worker, this is only a best attempt. - let _ = worker.notify_update(WorkerUpdate::Disconnect); - // We create a temporary Vec to avoid doubt about a possible code - // path touching the worker.running_action_infos elsewhere. - for action_info in worker.running_action_infos.drain() { - self.inner.metrics.workers_evicted_with_running_action.inc(); - self.retry_action(&action_info, worker_id, err.clone()); - } - // Note: Calling this multiple times is very cheap, it'll only trigger `do_try_match` once. - self.inner.tasks_or_workers_change_notify.notify_one(); - } - } - - fn retry_action(&mut self, action_info: &Arc, worker_id: &WorkerId, err: Error) { - match self.inner.active_actions.remove(action_info) { - Some(running_action) => { - let mut awaited_action = running_action; - let send_result = if awaited_action.attempts >= self.inner.max_job_retries { - self.inner.metrics.retry_action_max_attempts_reached.inc(); - Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Completed(ActionResult { - execution_metadata: ExecutionMetadata { - worker: format!("{worker_id}"), - ..ExecutionMetadata::default() - }, - error: Some(err.merge(make_err!( - Code::Internal, - "Job cancelled because it attempted to execute too many times and failed" - ))), - ..ActionResult::default() - }); - awaited_action - .notify_channel - .send(awaited_action.current_state.clone()) - // Do not put the action back in the queue here, as this action attempted to run too many - // times. - } else { - self.inner.metrics.retry_action.inc(); - Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Queued; - let send_result = awaited_action - .notify_channel - .send(awaited_action.current_state.clone()); - self.inner.queued_actions_set.insert(action_info.clone()); - self.inner - .queued_actions - .insert(action_info.clone(), awaited_action); - send_result - }; - - if send_result.is_err() { - self.inner.metrics.retry_action_no_more_listeners.inc(); - // Don't remove this task, instead we keep them around for a bit just in case - // the client disconnected and will reconnect and ask for same job to be executed - // again. - event!( - Level::WARN, - ?action_info, - ?worker_id, - "Action has no more listeners during evict_worker()" - ); - } - } - None => { - self.inner.metrics.retry_action_but_action_missing.inc(); - event!( - Level::ERROR, - ?action_info, - ?worker_id, - "Worker stated it was running an action, but it was not in the active_actions" - ); - } - } - } -} - -/// StateManager is responsible for maintaining the state of the scheduler. Scheduler state -/// includes the actions that are queued, active, and recently completed. It also includes the -/// workers that are available to execute actions based on allocation strategy. -pub(crate) struct StateManagerImpl { - // TODO(adams): Move `queued_actions_set` and `queued_actions` into a single struct that - // provides a unified interface for interacting with the two containers. - - // Important: `queued_actions_set` and `queued_actions` are two containers that provide - // different search and sort capabilities. We are using the two different containers to - // optimize different use cases. `HashSet` is used to look up actions in O(1) time. The - // `BTreeMap` is used to sort actions in O(log n) time based on priority and timestamp. - // These two fields must be kept in-sync, so if you modify one, you likely need to modify the - // other. - /// A `HashSet` of all actions that are queued. A hashset is used to find actions that are queued - /// in O(1) time. This set allows us to find and join on new actions onto already existing - /// (or queued) actions where insert timestamp of queued actions is not known. Using an - /// additional `HashSet` will prevent us from having to iterate the `BTreeMap` to find actions. - /// - /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. - pub(crate) queued_actions_set: HashSet>, - - /// A BTreeMap of sorted actions that are primarily based on priority and insert timestamp. - /// `ActionInfo` implements `Ord` that defines the `cmp` function for order. Using a BTreeMap - /// gives us to sorted actions that are queued in O(log n) time. - /// - /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. - pub(crate) queued_actions: BTreeMap, AwaitedAction>, - - /// A `Workers` pool that contains all workers that are available to execute actions in a priority - /// order based on the allocation strategy. - pub(crate) workers: Workers, - - /// A map of all actions that are active. A hashmap is used to find actions that are active in - /// O(1) time. The key is the `ActionInfo` struct. The value is the `AwaitedAction` struct. - pub(crate) active_actions: HashMap, AwaitedAction>, - - /// These actions completed recently but had no listener, they might have - /// completed while the caller was thinking about calling wait_execution, so - /// keep their completion state around for a while to send back. - /// TODO(#192) Revisit if this is the best way to handle recently completed actions. - pub(crate) recently_completed_actions: HashSet, - - pub(crate) metrics: Arc, - - /// Default times a job can retry before failing. - pub(crate) max_job_retries: usize, - - /// Notify task<->worker matching engine that work needs to be done. - pub(crate) tasks_or_workers_change_notify: Arc, -} - -impl StateManager { - /// Modifies the `stage` of `current_state` within `AwaitedAction`. Sends notification channel - /// the new state. - /// - /// - /// # Discussion - /// - /// The use of `Arc::make_mut` is potentially dangerous because it clones the data and - /// invalidates all weak references to it. However, in this context, it is considered - /// safe because the data is going to be re-sent back out. The primary reason for using - /// `Arc` is to reduce the number of copies, not to enforce read-only access. This approach - /// ensures that all downstream components receive the same pointer. If an update occurs - /// while another thread is operating on the data, it is acceptable, since the other thread - /// will receive another update with the new version. - /// - pub(crate) fn mutate_stage( - awaited_action: &mut AwaitedAction, - action_stage: ActionStage, - ) -> Result<(), SendError>> { - Arc::make_mut(&mut awaited_action.current_state).stage = action_stage; - awaited_action - .notify_channel - .send(awaited_action.current_state.clone()) - } - - /// Modifies the `priority` of `action_info` within `ActionInfo`. - /// - fn mutate_priority(action_info: &mut Arc, priority: i32) { - Arc::make_mut(action_info).priority = priority; - } - - /// Updates the `last_error` field of the provided `AwaitedAction` and sends the current state - /// to the notify channel. - /// - fn mutate_last_error( - awaited_action: &mut AwaitedAction, - last_error: Error, - ) -> Result<(), SendError>> { - awaited_action.last_error = Some(last_error); - awaited_action - .notify_channel - .send(awaited_action.current_state.clone()) - } - - /// Notifies the specified worker to run the given action and handles errors by evicting - /// the worker if the notification fails. - /// - /// # Note - /// - /// Intended utility function for matching engine. - /// - /// # Errors - /// - /// This function will return an error if the notification to the worker fails, and in that case, - /// the worker will be immediately evicted from the system. - /// - async fn worker_notify_run_action( - &mut self, - worker_id: WorkerId, - action_info: Arc, - ) -> Result<(), Error> { - if let Some(worker) = self.inner.workers.workers.get_mut(&worker_id) { - let notify_worker_result = - worker.notify_update(WorkerUpdate::RunAction(action_info.clone())); - - if notify_worker_result.is_err() { - event!( - Level::WARN, - ?worker_id, - ?action_info, - ?notify_worker_result, - "Worker command failed, removing worker", - ); - - let err = make_err!( - Code::Internal, - "Worker command failed, removing worker {worker_id} -- {notify_worker_result:?}", - ); - - self.immediate_evict_worker(&worker_id, err.clone()); - return Err(err); - } - } - Ok(()) - } - - /// Sets the action stage for the given `AwaitedAction` based on the result of the provided - /// `action_stage`. If the `action_stage` is an error, it updates the `last_error` field - /// and logs a warning. - /// - /// # Note - /// - /// Intended utility function for matching engine. - /// - /// # Errors - /// - /// This function will return an error if updating the state of the `awaited_action` fails. - /// - async fn worker_set_action_stage( - awaited_action: &mut AwaitedAction, - action_stage: Result, - worker_id: WorkerId, - ) -> Result<(), SendError>> { - match action_stage { - Ok(action_stage) => StateManager::mutate_stage(awaited_action, action_stage), - Err(e) => { - event!( - Level::WARN, - ?worker_id, - "Action stage setting error during do_try_match()" - ); - StateManager::mutate_last_error(awaited_action, e) - } - } - } - - /// Marks the specified action as active, assigns it to the given worker, and updates the - /// action stage. This function removes the action from the queue, updates the action's state - /// or error, and inserts it into the set of active actions. - /// - /// # Note - /// - /// Intended utility function for matching engine. - /// - /// # Errors - /// - /// This function will return an error if it fails to update the action's state or if any other - /// error occurs during the process. - /// - async fn worker_set_as_active( - &mut self, - action_info: Arc, - worker_id: WorkerId, - action_stage: Result, - ) -> Result<(), Error> { - if let Some((action_info, mut awaited_action)) = - self.inner.queued_actions.remove_entry(action_info.as_ref()) - { - assert!( - self.inner.queued_actions_set.remove(&action_info), - "queued_actions_set should always have same keys as queued_actions" - ); - - awaited_action.worker_id = Some(worker_id); - - let send_result = - StateManager::worker_set_action_stage(&mut awaited_action, action_stage, worker_id) - .await; - - if send_result.is_err() { - event!( - Level::WARN, - ?action_info, - ?worker_id, - "Action has no more listeners during do_try_match()" - ); - } - - awaited_action.attempts += 1; - self.inner - .active_actions - .insert(action_info, awaited_action); - Ok(()) - } else { - Err(make_err!( - Code::Internal, - "Action not found in queued_actions_set or queued_actions" - )) - } - } - - fn update_action_with_internal_error( - &mut self, - worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, - err: Error, - ) { - self.inner.metrics.update_action_with_internal_error.inc(); - let Some((action_info, mut running_action)) = self - .inner - .active_actions - .remove_entry(&action_info_hash_key) - else { - self.inner - .metrics - .update_action_with_internal_error_no_action - .inc(); - event!( - Level::ERROR, - ?action_info_hash_key, - ?worker_id, - "Could not find action info in active actions" - ); - return; - }; - - let due_to_backpressure = err.code == Code::ResourceExhausted; - // Don't count a backpressure failure as an attempt for an action. - if due_to_backpressure { - self.inner - .metrics - .update_action_with_internal_error_backpressure - .inc(); - running_action.attempts -= 1; - } - let Some(running_action_worker_id) = running_action.worker_id else { - event!( - Level::ERROR, - ?action_info_hash_key, - ?worker_id, - "Got a result from a worker that should not be running the action, Removing worker. Expected action to be unassigned got worker", - ); - return; - }; - if running_action_worker_id == *worker_id { - // Don't set the error on an action that's running somewhere else. - event!( - Level::WARN, - ?action_info_hash_key, - ?worker_id, - ?running_action_worker_id, - ?err, - "Internal worker error", - ); - running_action.last_error = Some(err.clone()); - } else { - self.inner - .metrics - .update_action_with_internal_error_from_wrong_worker - .inc(); - } - - // Now put it back. retry_action() needs it to be there to send errors properly. - self.inner - .active_actions - .insert(action_info.clone(), running_action); - - // Clear this action from the current worker. - if let Some(worker) = self.inner.workers.workers.get_mut(worker_id) { - let was_paused = !worker.can_accept_work(); - // This unpauses, but since we're completing with an error, don't - // unpause unless all actions have completed. - worker.complete_action(&action_info); - // Only pause if there's an action still waiting that will unpause. - if (was_paused || due_to_backpressure) && worker.has_actions() { - worker.is_paused = true; - } - } - - // Re-queue the action or fail on max attempts. - self.retry_action(&action_info, worker_id, err); - self.inner.tasks_or_workers_change_notify.notify_one(); - } -} - -#[async_trait] -impl ClientStateManager for StateManager { - async fn add_action( - &mut self, - action_info: ActionInfo, - ) -> Result, Error> { - // Check to see if the action is running, if it is and cacheable, merge the actions. - if let Some(running_action) = self.inner.active_actions.get_mut(&action_info) { - self.inner.metrics.add_action_joined_running_action.inc(); - self.inner.tasks_or_workers_change_notify.notify_one(); - return Ok(Arc::new(ClientActionStateResult::new( - running_action.notify_channel.subscribe(), - ))); - } - - // Check to see if the action is queued, if it is and cacheable, merge the actions. - if let Some(mut arc_action_info) = self.inner.queued_actions_set.take(&action_info) { - let (original_action_info, queued_action) = self - .inner - .queued_actions - .remove_entry(&arc_action_info) - .err_tip(|| "Internal error queued_actions and queued_actions_set should match")?; - self.inner.metrics.add_action_joined_queued_action.inc(); - - let new_priority = cmp::max(original_action_info.priority, action_info.priority); - drop(original_action_info); // This increases the chance Arc::make_mut won't copy. - - // In the event our task is higher priority than the one already scheduled, increase - // the priority of the scheduled one. - StateManager::mutate_priority(&mut arc_action_info, new_priority); - - let result = Arc::new(ClientActionStateResult::new( - queued_action.notify_channel.subscribe(), - )); - - // Even if we fail to send our action to the client, we need to add this action back to the - // queue because it was remove earlier. - self.inner - .queued_actions - .insert(arc_action_info.clone(), queued_action); - self.inner.queued_actions_set.insert(arc_action_info); - self.inner.tasks_or_workers_change_notify.notify_one(); - return Ok(result); - } - - self.inner.metrics.add_action_new_action_created.inc(); - // Action needs to be added to queue or is not cacheable. - let action_info = Arc::new(action_info); - - let operation_id = OperationId::new(action_info.unique_qualifier.clone()); - - let current_state = Arc::new(ActionState { - stage: ActionStage::Queued, - id: operation_id, - }); - - let (tx, rx) = watch::channel(current_state.clone()); - - self.inner.queued_actions_set.insert(action_info.clone()); - self.inner.queued_actions.insert( - action_info.clone(), - AwaitedAction { - action_info, - current_state, - notify_channel: tx, - attempts: 0, - last_error: None, - worker_id: None, - }, - ); - self.inner.tasks_or_workers_change_notify.notify_one(); - return Ok(Arc::new(ClientActionStateResult::new(rx))); - } - - async fn filter_operations( - &self, - filter: OperationFilter, - ) -> Result { - // TODO(adams): Build out a proper filter for other fields for state, at the moment - // this only supports the unique qualifier. - let unique_qualifier = &filter - .unique_qualifier - .err_tip(|| "No unique qualifier provided")?; - let maybe_awaited_action = self - .inner - .queued_actions_set - .get(unique_qualifier) - .and_then(|action_info| self.inner.queued_actions.get(action_info)) - .or_else(|| self.inner.active_actions.get(unique_qualifier)); - - let Some(awaited_action) = maybe_awaited_action else { - return Ok(Box::pin(stream::empty())); - }; - - let rx = awaited_action.notify_channel.subscribe(); - let action_result: [Arc; 1] = - [Arc::new(ClientActionStateResult::new(rx))]; - Ok(Box::pin(stream::iter(action_result))) - } -} - -#[async_trait] -impl WorkerStateManager for StateManager { - async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: WorkerId, - action_stage: Result, - ) -> Result<(), Error> { - match action_stage { - Ok(action_stage) => { - let action_info_hash_key = operation_id.unique_qualifier; - if !action_stage.has_action_result() { - self.inner.metrics.update_action_missing_action_result.inc(); - event!( - Level::ERROR, - ?action_info_hash_key, - ?worker_id, - ?action_stage, - "Worker sent error while updating action. Removing worker" - ); - let err = make_err!( - Code::Internal, - "Worker '{worker_id}' set the action_stage of running action {action_info_hash_key:?} to {action_stage:?}. Removing worker.", - ); - self.immediate_evict_worker(&worker_id, err.clone()); - return Err(err); - } - - let (action_info, mut running_action) = self - .inner - .active_actions - .remove_entry(&action_info_hash_key) - .err_tip(|| { - format!("Could not find action info in active actions : {action_info_hash_key:?}") - })?; - - if running_action.worker_id != Some(worker_id) { - self.inner.metrics.update_action_from_wrong_worker.inc(); - let err = match running_action.worker_id { - - Some(running_action_worker_id) => make_err!( - Code::Internal, - "Got a result from a worker that should not be running the action, Removing worker. Expected worker {running_action_worker_id} got worker {worker_id}", - ), - None => make_err!( - Code::Internal, - "Got a result from a worker that should not be running the action, Removing worker. Expected action to be unassigned got worker {worker_id}", - ), - }; - event!( - Level::ERROR, - ?action_info, - ?worker_id, - ?running_action.worker_id, - ?err, - "Got a result from a worker that should not be running the action, Removing worker" - ); - // First put it back in our active_actions or we will drop the task. - self.inner - .active_actions - .insert(action_info, running_action); - self.immediate_evict_worker(&worker_id, err.clone()); - return Err(err); - } - - let send_result = StateManager::mutate_stage(&mut running_action, action_stage); - - if !running_action.current_state.stage.is_finished() { - if send_result.is_err() { - self.inner.metrics.update_action_no_more_listeners.inc(); - event!( - Level::WARN, - ?action_info, - ?worker_id, - "Action has no more listeners during update_action()" - ); - } - // If the operation is not finished it means the worker is still working on it, so put it - // back or else we will lose track of the task. - self.inner - .active_actions - .insert(action_info, running_action); - - self.inner.tasks_or_workers_change_notify.notify_one(); - return Ok(()); - } - - // Keep in case this is asked for soon. - self.inner - .recently_completed_actions - .insert(CompletedAction { - completed_time: SystemTime::now(), - state: running_action.current_state, - }); - - let worker = self - .inner - .workers - .workers - .get_mut(&worker_id) - .ok_or_else(|| { - make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) - })?; - worker.complete_action(&action_info); - self.inner.tasks_or_workers_change_notify.notify_one(); - Ok(()) - } - Err(e) => { - self.update_action_with_internal_error( - &worker_id, - operation_id.unique_qualifier, - e.clone(), - ); - return Err(e); - } - } - } -} - -#[async_trait] -impl MatchingEngineStateManager for StateManager { - async fn filter_operations( - &self, - _filter: OperationFilter, // TODO(adam): reference filter - ) -> Result { - // TODO(adams): use OperationFilter vs directly encoding it. - let action_infos = - self.inner - .queued_actions - .iter() - .rev() - .map(|(action_info, awaited_action)| { - let cloned_action_info = action_info.clone(); - Arc::new(MatchingEngineActionStateResult::new( - cloned_action_info, - awaited_action.notify_channel.subscribe(), - )) as Arc - }); - - let action_infos: Vec> = action_infos.collect(); - Ok(Box::pin(stream::iter(action_infos))) - } - - async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: Option, - action_stage: Result, - ) -> Result<(), Error> { - if let Some(action_info) = self - .inner - .queued_actions_set - .get(&operation_id.unique_qualifier) - { - if let Some(worker_id) = worker_id { - let action_info = action_info.clone(); - self.worker_notify_run_action(worker_id, action_info.clone()) - .await?; - self.worker_set_as_active(action_info, worker_id, action_stage) - .await?; - } else { - event!( - Level::WARN, - ?operation_id, - ?worker_id, - "No worker found in do_try_match()" - ); - } - } else { - event!( - Level::WARN, - ?operation_id, - ?worker_id, - "No action info found in do_try_match()" - ); - } - - Ok(()) - } - - async fn remove_operation(&self, _operation_id: OperationId) -> Result<(), Error> { - todo!() - } -} diff --git a/nativelink-scheduler/src/scheduler_state/workers.rs b/nativelink-scheduler/src/scheduler_state/workers.rs deleted file mode 100644 index 25e78e2bb..000000000 --- a/nativelink-scheduler/src/scheduler_state/workers.rs +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use lru::LruCache; -use nativelink_config::schedulers::WorkerAllocationStrategy; -use nativelink_error::{error_if, make_input_err, Error, ResultExt}; -use nativelink_util::action_messages::WorkerId; -use nativelink_util::platform_properties::PlatformProperties; -use tracing::{event, Level}; - -use crate::worker::{Worker, WorkerTimestamp}; - -/// A collection of workers that are available to run tasks. -pub struct Workers { - /// A `LruCache` of workers availabled based on `allocation_strategy`. - pub(crate) workers: LruCache, - /// The allocation strategy for workers. - pub(crate) allocation_strategy: WorkerAllocationStrategy, -} - -impl Workers { - pub(crate) fn new(allocation_strategy: WorkerAllocationStrategy) -> Self { - Self { - workers: LruCache::unbounded(), - allocation_strategy, - } - } - - /// Refreshes the lifetime of the worker with the given timestamp. - pub(crate) fn refresh_lifetime( - &mut self, - worker_id: &WorkerId, - timestamp: WorkerTimestamp, - ) -> Result<(), Error> { - let worker = self.workers.get_mut(worker_id).ok_or_else(|| { - make_input_err!( - "Worker not found in worker map in refresh_lifetime() {}", - worker_id - ) - })?; - error_if!( - worker.last_update_timestamp > timestamp, - "Worker already had a timestamp of {}, but tried to update it with {}", - worker.last_update_timestamp, - timestamp - ); - worker.last_update_timestamp = timestamp; - Ok(()) - } - - /// Adds a worker to the pool. - /// Note: This function will not do any task matching. - pub(crate) fn add_worker(&mut self, worker: Worker) -> Result<(), Error> { - let worker_id = worker.id; - self.workers.put(worker_id, worker); - - // Worker is not cloneable, and we do not want to send the initial connection results until - // we have added it to the map, or we might get some strange race conditions due to the way - // the multi-threaded runtime works. - let worker = self.workers.peek_mut(&worker_id).unwrap(); - let res = worker - .send_initial_connection_result() - .err_tip(|| "Failed to send initial connection result to worker"); - if let Err(err) = &res { - event!( - Level::ERROR, - ?worker_id, - ?err, - "Worker connection appears to have been closed while adding to pool" - ); - } - res - } - - /// Removes worker from pool. - /// Note: The caller is responsible for any rescheduling of any tasks that might be - /// running. - pub(crate) fn remove_worker(&mut self, worker_id: &WorkerId) -> Option { - self.workers.pop(worker_id) - } - - // Attempts to find a worker that is capable of running this action. - // TODO(blaise.bruer) This algorithm is not very efficient. Simple testing using a tree-like - // structure showed worse performance on a 10_000 worker * 7 properties * 1000 queued tasks - // simulation of worst cases in a single threaded environment. - pub(crate) fn find_worker_for_action( - &self, - platform_properties: &PlatformProperties, - ) -> Option { - let mut workers_iter = self.workers.iter(); - let workers_iter = match self.allocation_strategy { - // Use rfind to get the least recently used that satisfies the properties. - WorkerAllocationStrategy::least_recently_used => workers_iter.rfind(|(_, w)| { - w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties) - }), - // Use find to get the most recently used that satisfies the properties. - WorkerAllocationStrategy::most_recently_used => workers_iter.find(|(_, w)| { - w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties) - }), - }; - workers_iter.map(|(_, w)| &w.id).copied() - } -} diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 20ced95cf..a0a27cf18 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -12,43 +12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; use std::pin::Pin; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; -use std::time::{Instant, SystemTime}; -use async_lock::{Mutex, MutexGuard}; use async_trait::async_trait; -use futures::{Future, Stream}; -use hashbrown::{HashMap, HashSet}; -use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; +use futures::Future; +use nativelink_config::stores::EvictionPolicy; +use nativelink_error::{Error, ResultExt}; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata, - OperationId, WorkerId, + ActionInfo, ActionStage, ActionState, ClientOperationId, OperationId, WorkerId, }; -use nativelink_util::metrics_utils::{ - AsyncCounterWrapper, Collector, CollectorState, CounterWithTime, FuncCounterWrapper, - MetricsComponent, Registry, +use nativelink_util::metrics_utils::Registry; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, + OperationFilter, OperationStageFlags, OrderDirection, }; -use nativelink_util::platform_properties::PlatformPropertyValue; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::{watch, Notify}; +use tokio::sync::Notify; use tokio::time::Duration; use tokio_stream::StreamExt; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; -use crate::operation_state_manager::{ - ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter, - OperationStageFlags, WorkerStateManager, -}; +use crate::action_scheduler::{ActionListener, ActionScheduler}; +use crate::api_worker_scheduler::ApiWorkerScheduler; +use crate::memory_awaited_action_db::MemoryAwaitedActionDb; use crate::platform_property_manager::PlatformPropertyManager; -use crate::scheduler_state::metrics::Metrics as SchedulerMetrics; -use crate::scheduler_state::state_manager::StateManager; -use crate::scheduler_state::workers::Workers; -use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate}; +use crate::simple_scheduler_state_manager::SimpleSchedulerStateManager; +use crate::worker::{Worker, WorkerTimestamp}; use crate::worker_scheduler::WorkerScheduler; /// Default timeout for workers in seconds. @@ -57,326 +48,212 @@ const DEFAULT_WORKER_TIMEOUT_S: u64 = 5; /// Default timeout for recently completed actions in seconds. /// If this changes, remember to change the documentation in the config. -const DEFAULT_RETAIN_COMPLETED_FOR_S: u64 = 60; +const DEFAULT_RETAIN_COMPLETED_FOR_S: u32 = 60; /// Default times a job can retry before failing. /// If this changes, remember to change the documentation in the config. const DEFAULT_MAX_JOB_RETRIES: usize = 3; -struct SimpleSchedulerImpl { - /// The manager responsible for holding the state of actions and workers. - state_manager: StateManager, - /// The duration that actions are kept in recently_completed_actions for. - retain_completed_for: Duration, - /// Timeout of how long to evict workers if no response in this given amount of time in seconds. - worker_timeout_s: u64, - /// Default times a job can retry before failing. - max_job_retries: usize, - metrics: Arc, +struct SimpleSchedulerActionListener { + client_operation_id: ClientOperationId, + action_state_result: Box, } -impl SimpleSchedulerImpl { - /// Attempts to find a worker to execute an action and begins executing it. - /// If an action is already running that is cacheable it may merge this action - /// with the results and state changes of the already running action. - /// If the task cannot be executed immediately it will be queued for execution - /// based on priority and other metrics. - /// All further updates to the action will be provided through `listener`. - async fn add_action( - &mut self, - action_info: ActionInfo, - ) -> Result>, Error> { - let add_action_result = self.state_manager.add_action(action_info).await?; - add_action_result.as_receiver().await.cloned() +impl SimpleSchedulerActionListener { + fn new( + client_operation_id: ClientOperationId, + action_state_result: Box, + ) -> Self { + Self { + client_operation_id, + action_state_result, + } } +} - fn clean_recently_completed_actions(&mut self) { - let expiry_time = SystemTime::now() - .checked_sub(self.retain_completed_for) - .unwrap(); - self.state_manager - .inner - .recently_completed_actions - .retain(|action| action.completed_time > expiry_time); +impl ActionListener for SimpleSchedulerActionListener { + fn client_operation_id(&self) -> &ClientOperationId { + &self.client_operation_id } - fn find_recently_completed_action( - &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - self.state_manager - .inner - .recently_completed_actions - .get(unique_qualifier) - .map(|action| watch::channel(action.state.clone()).1) + fn changed( + &mut self, + ) -> Pin, Error>> + Send + '_>> { + Box::pin(async move { + let action_state = self + .action_state_result + .changed() + .await + .err_tip(|| "In SimpleSchedulerActionListener::changed getting receiver")?; + Ok(action_state) + }) } +} - async fn find_existing_action( - &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - let filter_result = ::filter_operations( - &self.state_manager, - OperationFilter { - stages: OperationStageFlags::Any, - operation_id: None, - worker_id: None, - action_digest: None, - worker_update_before: None, - completed_before: None, - last_client_update_before: None, - unique_qualifier: Some(unique_qualifier.clone()), - order_by: None, - }, - ) - .await; - - let mut stream = filter_result.ok()?; - if let Some(result) = stream.next().await { - result.as_receiver().await.ok().cloned() - } else { - None - } - } +/// Engine used to manage the queued/running tasks and relationship with +/// the worker nodes. All state on how the workers and actions are interacting +/// should be held in this struct. +pub struct SimpleScheduler { + /// Manager for matching engine side of the state manager. + matching_engine_state_manager: Arc, - fn retry_action(&mut self, action_info: &Arc, worker_id: &WorkerId, err: Error) { - match self.state_manager.inner.active_actions.remove(action_info) { - Some(running_action) => { - let mut awaited_action = running_action; - let send_result = if awaited_action.attempts >= self.max_job_retries { - self.metrics.retry_action_max_attempts_reached.inc(); - - StateManager::mutate_stage(&mut awaited_action, ActionStage::Completed(ActionResult { - execution_metadata: ExecutionMetadata { - worker: format!("{worker_id}"), - ..ExecutionMetadata::default() - }, - error: Some(err.merge(make_err!( - Code::Internal, - "Job cancelled because it attempted to execute too many times and failed" - ))), - ..ActionResult::default() - })) - // Do not put the action back in the queue here, as this action attempted to run too many - // times. - } else { - self.metrics.retry_action.inc(); - let send_result = - StateManager::mutate_stage(&mut awaited_action, ActionStage::Queued); - self.state_manager - .inner - .queued_actions_set - .insert(action_info.clone()); - self.state_manager - .inner - .queued_actions - .insert(action_info.clone(), awaited_action); - send_result - }; - - if send_result.is_err() { - self.metrics.retry_action_no_more_listeners.inc(); - // Don't remove this task, instead we keep them around for a bit just in case - // the client disconnected and will reconnect and ask for same job to be executed - // again. - event!( - Level::WARN, - ?action_info, - ?worker_id, - "Action has no more listeners during evict_worker()" - ); - } - } - None => { - self.metrics.retry_action_but_action_missing.inc(); - event!( - Level::ERROR, - ?action_info, - ?worker_id, - "Worker stated it was running an action, but it was not in the active_actions" - ); - } - } - } + /// Manager for client state of this scheduler. + client_state_manager: Arc, - /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it. - fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) { - if let Some(mut worker) = self.state_manager.inner.workers.remove_worker(worker_id) { - self.metrics.workers_evicted.inc(); - // We don't care if we fail to send message to worker, this is only a best attempt. - let _ = worker.notify_update(WorkerUpdate::Disconnect); - // We create a temporary Vec to avoid doubt about a possible code - // path touching the worker.running_action_infos elsewhere. - for action_info in worker.running_action_infos.drain() { - self.metrics.workers_evicted_with_running_action.inc(); - self.retry_action(&action_info, worker_id, err.clone()); - } - } - // Note: Calling this many time is very cheap, it'll only trigger `do_try_match` once. - self.state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); - } + /// Manager for platform of this scheduler. + platform_property_manager: Arc, + + /// A `Workers` pool that contains all workers that are available to execute actions in a priority + /// order based on the allocation strategy. + worker_scheduler: Arc, + + /// Background task that tries to match actions to workers. If this struct + /// is dropped the spawn will be cancelled as well. + _task_worker_matching_spawn: JoinHandleDropGuard<()>, +} - /// Sets if the worker is draining or not. - fn set_drain_worker(&mut self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> { - let worker = self - .state_manager - .inner - .workers - .workers - .get_mut(&worker_id) - .err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?; - self.metrics.workers_drained.inc(); - worker.is_draining = is_draining; - self.state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); - Ok(()) +impl SimpleScheduler { + /// Attempts to find a worker to execute an action and begins executing it. + /// If an action is already running that is cacheable it may merge this + /// action with the results and state changes of the already running + /// action. If the task cannot be executed immediately it will be queued + /// for execution based on priority and other metrics. + /// All further updates to the action will be provided through the returned + /// value. + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result>, Error> { + let add_action_result = self + .client_state_manager + .add_action(client_operation_id.clone(), action_info) + .await?; + + Ok(Box::pin(SimpleSchedulerActionListener::new( + client_operation_id, + add_action_result, + ))) } - async fn get_queued_operations( + async fn find_by_client_operation_id( &self, - ) -> Result> + Send>>, Error> - { - ::filter_operations( - &self.state_manager, - OperationFilter { - stages: OperationStageFlags::Queued, - operation_id: None, - worker_id: None, - action_digest: None, - worker_update_before: None, - completed_before: None, - last_client_update_before: None, - unique_qualifier: None, - order_by: None, - }, - ) - .await + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + let filter = OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }; + let filter_result = self.client_state_manager.filter_operations(filter).await; + + let mut stream = filter_result + .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting filter result")?; + let Some(action_state_result) = stream.next().await else { + return Ok(None); + }; + Ok(Some(Box::pin(SimpleSchedulerActionListener::new( + client_operation_id.clone(), + action_state_result, + )))) } - // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map - // of capabilities of each worker and then try and match the actions to the worker using - // the map lookup (ie. map reduce). - async fn do_try_match(&mut self) { - // TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in - // the way. We need to conditionally remove items from the `queued_action`. Rust is working - // to add `drain_filter`, which would in theory solve this problem, but because we need - // to iterate the items in reverse it becomes more difficult (and it is currently an - // unstable feature [see: https://github.com/rust-lang/rust/issues/70530]). - - let action_state_results = self.get_queued_operations().await; - - match action_state_results { - Ok(mut stream) => { - while let Some(action_state_result) = stream.next().await { - let as_state_result = action_state_result.as_state().await; - let Ok(state) = as_state_result else { - let _ = as_state_result.inspect_err(|err| { - event!( - Level::ERROR, - ?err, - "Failed to get action_info from as_state_result stream" - ); - }); - continue; - }; - let action_state_result = action_state_result.as_action_info().await; - let Ok(action_info) = action_state_result else { - let _ = action_state_result.inspect_err(|err| { - event!( - Level::ERROR, - ?err, - "Failed to get action_info from action_state_results stream" - ); - }); - continue; - }; - - let maybe_worker_id: Option = { - self.state_manager - .inner - .workers - .find_worker_for_action(&action_info.platform_properties) - }; - - let operation_id = state.id.clone(); - let ret = ::update_operation( - &mut self.state_manager, - operation_id.clone(), - maybe_worker_id, - Ok(ActionStage::Executing), - ) - .await; - - if let Err(e) = ret { - event!( - Level::ERROR, - ?e, - "update operation failed for {}", - operation_id - ); - } + async fn get_queued_operations(&self) -> Result { + let filter = OperationFilter { + stages: OperationStageFlags::Queued, + order_by_priority_direction: Some(OrderDirection::Desc), + ..Default::default() + }; + self.matching_engine_state_manager + .filter_operations(filter) + .await + .err_tip(|| "In SimpleScheduler::get_queued_operations getting filter result") + } + + // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we + // can create a map of capabilities of each worker and then try and match + // the actions to the worker using the map lookup (ie. map reduce). + async fn do_try_match(&self) -> Result<(), Error> { + async fn match_action_to_worker( + action_state_result: &dyn ActionStateResult, + workers: &ApiWorkerScheduler, + matching_engine_state_manager: &dyn MatchingEngineStateManager, + ) -> Result<(), Error> { + let action_info = action_state_result + .as_action_info() + .await + .err_tip(|| "Failed to get action_info from as_action_info_result stream")?; + + // Try to find a worker for the action. + let worker_id = { + let platform_properties = &action_info.platform_properties; + match workers.find_worker_for_action(platform_properties).await { + Some(worker_id) => worker_id, + // If we could not find a worker for the action, + // we have nothing to do. + None => return Ok(()), } - } - Err(e) => { - event!(Level::ERROR, ?e, "stream error in do_try_match"); + }; + + // Extract the operation_id from the action_state. + let operation_id = { + let action_state = action_state_result + .as_state() + .await + .err_tip(|| "Failed to get action_info from as_state_result stream")?; + action_state.id.clone() + }; + + // Tell the matching engine that the operation is being assigned to a worker. + matching_engine_state_manager + .assign_operation(&operation_id, Ok(&worker_id)) + .await + .err_tip(|| "Failed to assign operation in do_try_match")?; + + // Notify the worker to run the action. + { + workers + .worker_notify_run_action(worker_id, operation_id, action_info) + .await + .err_tip(|| { + "Failed to run worker_notify_run_action in SimpleScheduler::do_try_match" + }) } } - } - async fn update_action( - &mut self, - worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, - action_stage: Result, - ) -> Result<(), Error> { - let update_operation_result = ::update_operation( - &mut self.state_manager, - OperationId::new(action_info_hash_key.clone()), - *worker_id, - action_stage, - ) - .await; - if let Err(e) = &update_operation_result { - event!( - Level::ERROR, - ?action_info_hash_key, - ?worker_id, - ?e, - "Failed to update_operation on update_action" + let mut result = Ok(()); + + let mut stream = self + .get_queued_operations() + .await + .err_tip(|| "Failed to get queued operations in do_try_match")?; + + while let Some(action_state_result) = stream.next().await { + result = result.merge( + match_action_to_worker( + action_state_result.as_ref(), + self.worker_scheduler.as_ref(), + self.matching_engine_state_manager.as_ref(), + ) + .await, ); } - update_operation_result + result } } -/// Engine used to manage the queued/running tasks and relationship with -/// the worker nodes. All state on how the workers and actions are interacting -/// should be held in this struct. -pub struct SimpleScheduler { - inner: Arc>, - platform_property_manager: Arc, - metrics: Arc, - // Triggers `drop()`` call if scheduler is dropped. - _task_worker_matching_future: JoinHandleDropGuard<()>, -} - impl SimpleScheduler { - #[inline] - #[must_use] - pub fn new(scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler) -> Self { + pub fn new( + scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler, + ) -> (Arc, Arc) { Self::new_with_callback(scheduler_cfg, || { // The cost of running `do_try_match()` is very high, but constant - // in relation to the number of changes that have happened. This means - // that grabbing this lock to process `do_try_match()` should always - // yield to any other tasks that might want the lock. The easiest and - // most fair way to do this is to sleep for a small amount of time. - // Using something like tokio::task::yield_now() does not yield as - // aggresively as we'd like if new futures are scheduled within a future. + // in relation to the number of changes that have happened. This + // means that grabbing this lock to process `do_try_match()` should + // always yield to any other tasks that might want the lock. The + // easiest and most fair way to do this is to sleep for a small + // amount of time. Using something like tokio::task::yield_now() + // does not yield as aggresively as we'd like if new futures are + // scheduled within a future. tokio::time::sleep(Duration::from_millis(1)) }) } @@ -387,7 +264,7 @@ impl SimpleScheduler { >( scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler, on_matching_engine_run: F, - ) -> Self { + ) -> (Arc, Arc) { let platform_property_manager = Arc::new(PlatformPropertyManager::new( scheduler_cfg .supported_platform_properties @@ -410,102 +287,56 @@ impl SimpleScheduler { max_job_retries = DEFAULT_MAX_JOB_RETRIES; } - let tasks_or_workers_change_notify = Arc::new(Notify::new()); - let state_manager = StateManager::new( - HashSet::new(), - BTreeMap::new(), - Workers::new(scheduler_cfg.allocation_strategy), - HashMap::new(), - HashSet::new(), - Arc::new(SchedulerMetrics::default()), + let tasks_or_worker_change_notify = Arc::new(Notify::new()); + let state_manager = SimpleSchedulerStateManager::new( + tasks_or_worker_change_notify.clone(), max_job_retries, - tasks_or_workers_change_notify.clone(), + MemoryAwaitedActionDb::new(&EvictionPolicy { + max_seconds: retain_completed_for_s, + ..Default::default() + }), ); - let metrics = Arc::new(Metrics::default()); - let metrics_for_do_try_match = metrics.clone(); - let inner = Arc::new(Mutex::new(SimpleSchedulerImpl { - state_manager, - retain_completed_for: Duration::new(retain_completed_for_s, 0), + + let worker_scheduler = ApiWorkerScheduler::new( + state_manager.clone(), + platform_property_manager.clone(), + scheduler_cfg.allocation_strategy, + tasks_or_worker_change_notify.clone(), worker_timeout_s, - max_job_retries, - metrics: metrics.clone(), - })); - let weak_inner = Arc::downgrade(&inner); - Self { - inner, - platform_property_manager, - _task_worker_matching_future: spawn!( - "simple_scheduler_task_worker_matching", - async move { + ); + + let worker_scheduler_clone = worker_scheduler.clone(); + + let action_scheduler = Arc::new_cyclic(move |weak_self| -> Self { + let weak_inner = weak_self.clone(); + let task_worker_matching_spawn = + spawn!("simple_scheduler_task_worker_matching", async move { // Break out of the loop only when the inner is dropped. loop { - tasks_or_workers_change_notify.notified().await; - match weak_inner.upgrade() { - // Note: According to `parking_lot` documentation, the default - // `Mutex` implementation is eventual fairness, so we don't - // really need to worry about this thread taking the lock - // starving other threads too much. - Some(inner_mux) => { - let mut inner = inner_mux.lock().await; - let timer = metrics_for_do_try_match.do_try_match.begin_timer(); - inner.do_try_match().await; - timer.measure(); - } + tasks_or_worker_change_notify.notified().await; + let result = match weak_inner.upgrade() { + Some(scheduler) => scheduler.do_try_match().await, // If the inner went away it means the scheduler is shutting // down, so we need to resolve our future. None => return, }; + if let Err(err) = result { + event!(Level::ERROR, ?err, "Error while running do_try_match"); + } + on_matching_engine_run().await; } // Unreachable. - } - ), - metrics, - } - } - - /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests. - #[must_use] - pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool { - let inner = self.get_inner_lock().await; - inner - .state_manager - .inner - .workers - .workers - .contains(worker_id) - } - - /// A unit test function used to send the keep alive message to the worker from the server. - pub async fn send_keep_alive_to_worker_for_test( - &self, - worker_id: &WorkerId, - ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - let worker = inner - .state_manager - .inner - .workers - .workers - .get_mut(worker_id) - .ok_or_else(|| { - make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) - })?; - worker.keep_alive() - } - - async fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> { - // We don't use one of the wrappers because we only want to capture the time spent, - // nothing else beacuse this is a hot path. - let start = Instant::now(); - let lock: MutexGuard = self.inner.lock().await; - self.metrics - .lock_stall_time - .fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed); - self.metrics - .lock_stall_time_counter - .fetch_add(1, Ordering::Relaxed); - lock + }); + SimpleScheduler { + matching_engine_state_manager: state_manager.clone(), + client_state_manager: state_manager.clone(), + worker_scheduler, + platform_property_manager, + _task_worker_matching_spawn: task_worker_matching_spawn, + } + }); + (action_scheduler, worker_scheduler_clone) } } @@ -520,82 +351,52 @@ impl ActionScheduler for SimpleScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { - let mut inner = self.get_inner_lock().await; - self.metrics - .add_action - .wrap(inner.add_action(action_info)) + ) -> Result>, Error> { + self.add_action(client_operation_id, Arc::new(action_info)) .await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - let inner = self.get_inner_lock().await; - let result = inner - .find_existing_action(unique_qualifier) - .await - .or_else(|| inner.find_recently_completed_action(unique_qualifier)); - if result.is_some() { - self.metrics.existing_actions_found.inc(); - } else { - self.metrics.existing_actions_not_found.inc(); - } - result - } - - async fn clean_recently_completed_actions(&self) { - self.get_inner_lock() + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + let maybe_receiver = self + .find_by_client_operation_id(client_operation_id) .await - .clean_recently_completed_actions(); - self.metrics.clean_recently_completed_actions.inc() + .err_tip(|| { + format!("Error while finding action with client id: {client_operation_id:?}") + })?; + Ok(maybe_receiver) } fn register_metrics(self: Arc, registry: &mut Registry) { - registry.register_collector(Box::new(Collector::new(&self))); + self.client_state_manager.clone().register_metrics(registry); + self.matching_engine_state_manager + .clone() + .register_metrics(registry); } } #[async_trait] impl WorkerScheduler for SimpleScheduler { fn get_platform_property_manager(&self) -> &PlatformPropertyManager { - self.platform_property_manager.as_ref() + self.worker_scheduler.get_platform_property_manager() } async fn add_worker(&self, worker: Worker) -> Result<(), Error> { - let worker_id = worker.id; - let mut inner = self.get_inner_lock().await; - self.metrics.add_worker.wrap(move || { - let res = inner - .state_manager - .inner - .workers - .add_worker(worker) - .err_tip(|| "Error while adding worker, removing from pool"); - if let Err(err) = &res { - inner.immediate_evict_worker(&worker_id, err.clone()); - } - inner - .state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); - res - }) + self.worker_scheduler.add_worker(worker).await } async fn update_action( &self, worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, + operation_id: &OperationId, action_stage: Result, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - self.metrics - .update_action - .wrap(inner.update_action(worker_id, action_info_hash_key, action_stage)) + self.worker_scheduler + .update_action(worker_id, operation_id, action_stage) .await } @@ -604,322 +405,24 @@ impl WorkerScheduler for SimpleScheduler { worker_id: &WorkerId, timestamp: WorkerTimestamp, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - inner - .state_manager - .inner - .workers - .refresh_lifetime(worker_id, timestamp) - .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()") + self.worker_scheduler + .worker_keep_alive_received(worker_id, timestamp) + .await } - async fn remove_worker(&self, worker_id: WorkerId) { - let mut inner = self.get_inner_lock().await; - inner.immediate_evict_worker( - &worker_id, - make_err!(Code::Internal, "Received request to remove worker"), - ); + async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> { + self.worker_scheduler.remove_worker(worker_id).await } async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - self.metrics.remove_timedout_workers.wrap(move || { - // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire - // map most of the time. - let worker_ids_to_remove: Vec = inner - .state_manager - .inner - .workers - .workers - .iter() - .rev() - .map_while(|(worker_id, worker)| { - if worker.last_update_timestamp <= now_timestamp - inner.worker_timeout_s { - Some(*worker_id) - } else { - None - } - }) - .collect(); - for worker_id in &worker_ids_to_remove { - event!( - Level::WARN, - ?worker_id, - "Worker timed out, removing from pool" - ); - inner.immediate_evict_worker( - worker_id, - make_err!( - Code::Internal, - "Worker {worker_id} timed out, removing from pool" - ), - ); - } - - Ok(()) - }) - } - - async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - inner.set_drain_worker(worker_id, is_draining) - } - - fn register_metrics(self: Arc, _registry: &mut Registry) { - // We do not register anything here because we only want to register metrics - // once and we rely on the `ActionScheduler::register_metrics()` to do that. - } -} - -impl MetricsComponent for SimpleScheduler { - fn gather_metrics(&self, c: &mut CollectorState) { - self.metrics.gather_metrics(c); - { - // We use the raw lock because we dont gather stats about gathering stats. - let inner = self.inner.lock_blocking(); - inner.state_manager.inner.metrics.gather_metrics(c); - c.publish( - "queued_actions_total", - &inner.state_manager.inner.queued_actions.len(), - "The number actions in the queue.", - ); - c.publish( - "workers_total", - &inner.state_manager.inner.workers.workers.len(), - "The number workers active.", - ); - c.publish( - "active_actions_total", - &inner.state_manager.inner.active_actions.len(), - "The number of running actions.", - ); - c.publish( - "recently_completed_actions_total", - &inner.state_manager.inner.recently_completed_actions.len(), - "The number of recently completed actions in the buffer.", - ); - c.publish( - "retain_completed_for_seconds", - &inner.retain_completed_for, - "The duration completed actions are retained for.", - ); - c.publish( - "worker_timeout_seconds", - &inner.worker_timeout_s, - "The configured timeout if workers have not responded for a while.", - ); - c.publish( - "max_job_retries", - &inner.max_job_retries, - "The amount of times a job is allowed to retry from an internal error before it is dropped.", - ); - let mut props = HashMap::<&String, u64>::new(); - for (_worker_id, worker) in inner.state_manager.inner.workers.workers.iter() { - c.publish_with_labels( - "workers", - worker, - "", - vec![("worker_id".into(), worker.id.to_string().into())], - ); - for (property, prop_value) in &worker.platform_properties.properties { - let current_value = props.get(&property).unwrap_or(&0); - if let PlatformPropertyValue::Minimum(worker_value) = prop_value { - props.insert(property, *current_value + *worker_value); - } - } - } - for (property, prop_value) in props { - c.publish( - &format!("{property}_available_properties"), - &prop_value, - format!("Total sum of available properties for {property}"), - ); - } - for (_, active_action) in inner.state_manager.inner.active_actions.iter() { - let action_name = active_action - .action_info - .unique_qualifier - .action_name() - .into(); - let worker_id_str = match active_action.worker_id { - Some(id) => id.to_string(), - None => "Unassigned".to_string(), - }; - c.publish_with_labels( - "active_actions", - active_action, - "", - vec![ - ("worker_id".into(), worker_id_str.into()), - ("digest".into(), action_name), - ], - ); - } - // Note: We don't publish queued_actions because it can be very large. - // Note: We don't publish recently completed actions because it can be very large. - } + self.worker_scheduler + .remove_timedout_workers(now_timestamp) + .await } -} - -#[derive(Default)] -struct Metrics { - add_action: AsyncCounterWrapper, - existing_actions_found: CounterWithTime, - existing_actions_not_found: CounterWithTime, - clean_recently_completed_actions: CounterWithTime, - remove_timedout_workers: FuncCounterWrapper, - update_action: AsyncCounterWrapper, - update_action_with_internal_error: CounterWithTime, - update_action_with_internal_error_no_action: CounterWithTime, - update_action_with_internal_error_backpressure: CounterWithTime, - update_action_with_internal_error_from_wrong_worker: CounterWithTime, - workers_evicted: CounterWithTime, - workers_evicted_with_running_action: CounterWithTime, - workers_drained: CounterWithTime, - retry_action: CounterWithTime, - retry_action_max_attempts_reached: CounterWithTime, - retry_action_no_more_listeners: CounterWithTime, - retry_action_but_action_missing: CounterWithTime, - add_worker: FuncCounterWrapper, - timedout_workers: CounterWithTime, - lock_stall_time: AtomicU64, - lock_stall_time_counter: AtomicU64, - do_try_match: AsyncCounterWrapper, -} -impl Metrics { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish( - "add_action", - &self.add_action, - "The number of times add_action was called.", - ); - c.publish_with_labels( - "find_existing_action", - &self.existing_actions_found, - "The number of times existing_actions_found had an action found.", - vec![("result".into(), "found".into())], - ); - c.publish_with_labels( - "find_existing_action", - &self.existing_actions_not_found, - "The number of times existing_actions_found had an action not found.", - vec![("result".into(), "not_found".into())], - ); - c.publish( - "clean_recently_completed_actions", - &self.clean_recently_completed_actions, - "The number of times clean_recently_completed_actions was triggered.", - ); - c.publish( - "remove_timedout_workers", - &self.remove_timedout_workers, - "The number of times remove_timedout_workers was triggered.", - ); - { - c.publish_with_labels( - "update_action", - &self.update_action, - "Stats about errors when worker sends update_action() to scheduler.", - vec![("result".into(), "missing_action_result".into())], - ); - } - c.publish( - "update_action_with_internal_error", - &self.update_action_with_internal_error, - "The number of times update_action_with_internal_error was triggered.", - ); - { - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_no_action, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "no_action".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_backpressure, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "backpressure".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_from_wrong_worker, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "from_wrong_worker".into())], - ); - } - c.publish( - "workers_evicted_total", - &self.workers_evicted, - "The number of workers evicted from scheduler.", - ); - c.publish( - "workers_evicted_with_running_action", - &self.workers_evicted_with_running_action, - "The number of jobs cancelled because worker was evicted from scheduler.", - ); - c.publish( - "workers_drained_total", - &self.workers_drained, - "The number of workers drained from scheduler.", - ); - { - c.publish_with_labels( - "retry_action", - &self.retry_action, - "Stats about retry_action().", - vec![("result".into(), "success".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_max_attempts_reached, - "Stats about retry_action().", - vec![("result".into(), "max_attempts_reached".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_no_more_listeners, - "Stats about retry_action().", - vec![("result".into(), "no_more_listeners".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_but_action_missing, - "Stats about retry_action().", - vec![("result".into(), "action_missing".into())], - ); - } - c.publish( - "add_worker", - &self.add_worker, - "Stats about add_worker() being called on the scheduler.", - ); - c.publish( - "timedout_workers", - &self.timedout_workers, - "The number of workers that timed out.", - ); - c.publish( - "lock_stall_time_nanos_total", - &self.lock_stall_time, - "The total number of nanos spent waiting on the lock in the scheduler.", - ); - c.publish( - "lock_stall_time_total", - &self.lock_stall_time_counter, - "The number of times a lock request was made in the scheduler.", - ); - c.publish( - "lock_stall_time_avg_nanos", - &(self.lock_stall_time.load(Ordering::Relaxed) - / self.lock_stall_time_counter.load(Ordering::Relaxed)), - "The average time the scheduler stalled waiting on the lock to release in nanos.", - ); - c.publish( - "matching_engine", - &self.do_try_match, - "The job<->worker matching engine stats. This is a very expensive operation, so it is not run every time (often called do_try_match).", - ); + async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> { + self.worker_scheduler + .set_drain_worker(worker_id, is_draining) + .await } } diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs new file mode 100644 index 000000000..f8bd2fc4d --- /dev/null +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -0,0 +1,480 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::ops::Bound; +use std::sync::Arc; + +use async_trait::async_trait; +use futures::{future, stream, StreamExt, TryStreamExt}; +use nativelink_error::{make_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueQualifier, ClientOperationId, + ExecutionMetadata, OperationId, WorkerId, +}; +use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent, Registry}; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, + OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager, +}; +use tokio::sync::Notify; +use tracing::{event, Level}; + +use super::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedActionState, +}; +use crate::memory_awaited_action_db::{ClientActionStateResult, MatchingEngineActionStateResult}; + +/// Maximum number of times an update to the database +/// can fail before giving up. +const MAX_UPDATE_RETRIES: usize = 5; + +/// Simple struct that implements the ActionStateResult trait and always returns an error. +struct ErrorActionStateResult(Error); + +#[async_trait] +impl ActionStateResult for ErrorActionStateResult { + async fn as_state(&self) -> Result, Error> { + Err(self.0.clone()) + } + + async fn changed(&mut self) -> Result, Error> { + Err(self.0.clone()) + } + + async fn as_action_info(&self) -> Result, Error> { + Err(self.0.clone()) + } +} + +fn apply_filter_predicate(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { + // Note: The caller must filter `client_operation_id`. + + if let Some(operation_id) = &filter.operation_id { + if operation_id != awaited_action.operation_id() { + return false; + } + } + + if filter.worker_id.is_some() && filter.worker_id != awaited_action.worker_id() { + return false; + } + + { + if let Some(filter_unique_key) = &filter.unique_key { + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => { + if filter_unique_key != unique_key { + return false; + } + } + ActionUniqueQualifier::Uncachable(_) => { + return false; + } + } + } + if let Some(action_digest) = filter.action_digest { + if action_digest != awaited_action.action_info().digest() { + return false; + } + } + } + + { + let last_worker_update_timestamp = awaited_action.last_worker_updated_timestamp(); + if let Some(worker_update_before) = filter.worker_update_before { + if worker_update_before < last_worker_update_timestamp { + return false; + } + } + if let Some(completed_before) = filter.completed_before { + if awaited_action.state().stage.is_finished() + && completed_before < last_worker_update_timestamp + { + return false; + } + } + if filter.stages != OperationStageFlags::Any { + let stage_flag = match awaited_action.state().stage { + ActionStage::Unknown => OperationStageFlags::Any, + ActionStage::CacheCheck => OperationStageFlags::CacheCheck, + ActionStage::Queued => OperationStageFlags::Queued, + ActionStage::Executing => OperationStageFlags::Executing, + ActionStage::Completed(_) => OperationStageFlags::Completed, + ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed, + }; + if !filter.stages.intersects(stage_flag) { + return false; + } + } + } + + true +} + +/// MemorySchedulerStateManager is responsible for maintaining the state of the scheduler. +/// Scheduler state includes the actions that are queued, active, and recently completed. +/// It also includes the workers that are available to execute actions based on allocation +/// strategy. +pub struct SimpleSchedulerStateManager { + /// Database for storing the state of all actions. + action_db: T, + + /// Notify matching engine that work needs to be done. + tasks_change_notify: Arc, + + /// Maximum number of times a job can be retried. + // TODO(allada) This should be a scheduler decorator instead + // of always having it on every SimpleScheduler. + max_job_retries: usize, +} + +impl SimpleSchedulerStateManager { + pub fn new( + tasks_change_notify: Arc, + max_job_retries: usize, + action_db: T, + ) -> Arc { + Arc::new(Self { + action_db, + tasks_change_notify, + max_job_retries, + }) + } + + async fn inner_update_operation( + &self, + operation_id: &OperationId, + maybe_worker_id: Option<&WorkerId>, + action_stage_result: Result, + ) -> Result<(), Error> { + let mut last_err = None; + for _ in 0..MAX_UPDATE_RETRIES { + let maybe_awaited_action_subscriber = self + .action_db + .get_by_operation_id(operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::update_operation")?; + let awaited_action_subscriber = match maybe_awaited_action_subscriber { + Some(sub) => sub, + // No action found. It is ok if the action was not found. It probably + // means that the action was dropped, but worker was still processing + // it. + None => return Ok(()), + }; + + let mut awaited_action = awaited_action_subscriber.borrow(); + + // Make sure we don't update an action that is already completed. + if awaited_action.state().stage.is_finished() { + return Err(make_err!( + Code::Internal, + "Action {operation_id:?} is already completed with state {:?} - maybe_worker_id: {:?}", + awaited_action.state().stage, + maybe_worker_id, + )); + } + + // Make sure the worker id matches the awaited action worker id. + // This might happen if the worker sending the update is not the + // worker that was assigned. + if awaited_action.worker_id().is_some() + && maybe_worker_id.is_some() + && maybe_worker_id != awaited_action.worker_id().as_ref() + { + let err = make_err!( + Code::Internal, + "Worker ids do not match - {:?} != {:?} for {:?}", + maybe_worker_id, + awaited_action.worker_id(), + awaited_action, + ); + event!( + Level::ERROR, + ?operation_id, + ?maybe_worker_id, + ?awaited_action, + "{}", + err.to_string(), + ); + return Err(err); + } + + let stage = match &action_stage_result { + Ok(stage) => stage.clone(), + Err(err) => { + // Don't count a backpressure failure as an attempt for an action. + let due_to_backpressure = err.code == Code::ResourceExhausted; + if !due_to_backpressure { + awaited_action.attempts += 1; + } + + if awaited_action.attempts > self.max_job_retries { + ActionStage::Completed(ActionResult { + execution_metadata: ExecutionMetadata { + worker: maybe_worker_id.map_or_else(String::default, |v| v.to_string()), + ..ExecutionMetadata::default() + }, + error: Some(err.clone().merge(make_err!( + Code::Internal, + "Job cancelled because it attempted to execute too many times and failed {}", + format!("for operation_id: {operation_id}, maybe_worker_id: {maybe_worker_id:?}"), + ))), + ..ActionResult::default() + }) + } else { + ActionStage::Queued + } + } + }; + if matches!(stage, ActionStage::Queued) { + // If the action is queued, we need to unset the worker id regardless of + // which worker sent the update. + awaited_action.set_worker_id(None); + } else { + awaited_action.set_worker_id(maybe_worker_id.copied()); + } + awaited_action.set_state(Arc::new(ActionState { + stage, + id: operation_id.clone(), + })); + awaited_action.increment_version(); + + let update_action_result = self + .action_db + .update_awaited_action(awaited_action) + .await + .err_tip(|| "In MemorySchedulerStateManager::update_operation"); + if let Err(err) = update_action_result { + // We use Aborted to signal that the action was not + // updated due to the data being set was not the latest + // but can be retried. + if err.code == Code::Aborted { + last_err = Some(err); + continue; + } else { + return Err(err); + } + } + + self.tasks_change_notify.notify_one(); + return Ok(()); + } + match last_err { + Some(err) => Err(err), + None => Err(make_err!( + Code::Internal, + "Failed to update action after {} retries with no error set", + MAX_UPDATE_RETRIES, + )), + } + } + + async fn inner_add_operation( + &self, + new_client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + let rx = self + .action_db + .add_action(new_client_operation_id, action_info) + .await + .err_tip(|| "In MemorySchedulerStateManager::add_operation")?; + self.tasks_change_notify.notify_one(); + Ok(rx) + } + + async fn inner_filter_operations<'a, F>( + &'a self, + filter: OperationFilter, + to_action_state_result: F, + ) -> Result, Error> + where + F: Fn(T::Subscriber) -> Box + Send + Sync + 'a, + { + fn sorted_awaited_action_state_for_flags( + stage: OperationStageFlags, + ) -> Option { + match stage { + OperationStageFlags::CacheCheck => Some(SortedAwaitedActionState::CacheCheck), + OperationStageFlags::Queued => Some(SortedAwaitedActionState::Queued), + OperationStageFlags::Executing => Some(SortedAwaitedActionState::Executing), + OperationStageFlags::Completed => Some(SortedAwaitedActionState::Completed), + _ => None, + } + } + + if let Some(operation_id) = &filter.operation_id { + return Ok(self + .action_db + .get_by_operation_id(operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? + .filter(|awaited_action_rx| { + let awaited_action = awaited_action_rx.borrow(); + apply_filter_predicate(&awaited_action, &filter) + }) + .map(|awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(awaited_action) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + if let Some(client_operation_id) = &filter.client_operation_id { + return Ok(self + .action_db + .get_awaited_action_by_id(client_operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? + .filter(|awaited_action_rx| { + let awaited_action = awaited_action_rx.borrow(); + apply_filter_predicate(&awaited_action, &filter) + }) + .map(|awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(awaited_action) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + + let Some(sorted_awaited_action_state) = + sorted_awaited_action_state_for_flags(filter.stages) + else { + let mut all_items: Vec = self + .action_db + .get_all_awaited_actions() + .await + .try_filter(|awaited_action_subscriber| { + future::ready(apply_filter_predicate( + &awaited_action_subscriber.borrow(), + &filter, + )) + }) + .try_collect() + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?; + match filter.order_by_priority_direction { + Some(OrderDirection::Asc) => { + all_items.sort_unstable_by_key(|a| a.borrow().sort_key()) + } + Some(OrderDirection::Desc) => { + all_items.sort_unstable_by_key(|a| std::cmp::Reverse(a.borrow().sort_key())) + } + None => {} + } + return Ok(Box::pin(stream::iter( + all_items.into_iter().map(to_action_state_result), + ))); + }; + + let desc = matches!( + filter.order_by_priority_direction, + Some(OrderDirection::Desc) + ); + let filter = filter.clone(); + let stream = self + .action_db + .get_range_of_actions( + sorted_awaited_action_state, + Bound::Unbounded, + Bound::Unbounded, + desc, + ) + .await + .try_filter(move |sub| future::ready(apply_filter_predicate(&sub.borrow(), &filter))) + .map(move |result| -> Box { + result.map_or_else( + |e| -> Box { Box::new(ErrorActionStateResult(e)) }, + |v| -> Box { to_action_state_result(v) }, + ) + }); + Ok(Box::pin(stream)) + } +} + +#[async_trait] +impl ClientStateManager for SimpleSchedulerStateManager { + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result, Error> { + let sub = self + .inner_add_operation(client_operation_id.clone(), action_info.clone()) + .await?; + + Ok(Box::new(ClientActionStateResult::new(sub))) + } + + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter, move |rx| Box::new(ClientActionStateResult::new(rx))) + .await + } +} + +#[async_trait] +impl WorkerStateManager for SimpleSchedulerStateManager { + async fn update_operation( + &self, + operation_id: &OperationId, + worker_id: &WorkerId, + action_stage_result: Result, + ) -> Result<(), Error> { + self.inner_update_operation(operation_id, Some(worker_id), action_stage_result) + .await + } +} + +#[async_trait] +impl MatchingEngineStateManager for SimpleSchedulerStateManager { + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter, |rx| { + Box::new(MatchingEngineActionStateResult::new(rx)) + }) + .await + } + + async fn assign_operation( + &self, + operation_id: &OperationId, + worker_id_or_reason_for_unsassign: Result<&WorkerId, Error>, + ) -> Result<(), Error> { + let (maybe_worker_id, stage_result) = match worker_id_or_reason_for_unsassign { + Ok(worker_id) => (Some(worker_id), Ok(ActionStage::Executing)), + Err(err) => (None, Err(err)), + }; + self.inner_update_operation(operation_id, maybe_worker_id, stage_result) + .await + } + + /// Register metrics with the registry. + fn register_metrics(self: Arc, registry: &mut Registry) { + // TODO(allada) We only register the metrics in one of the components instead of + // all three because it's a bit tricky to separate the metrics for each component. + registry.register_collector(Box::new(Collector::new(&self))); + } +} + +impl MetricsComponent for SimpleSchedulerStateManager { + fn gather_metrics(&self, c: &mut CollectorState) { + c.publish("", &self.action_db, ""); + } +} diff --git a/nativelink-scheduler/src/worker.rs b/nativelink-scheduler/src/worker.rs index 3dde40874..883c7c1a9 100644 --- a/nativelink-scheduler/src/worker.rs +++ b/nativelink-scheduler/src/worker.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; use std::time::{SystemTime, UNIX_EPOCH}; @@ -21,7 +21,7 @@ use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_util::action_messages::{ActionInfo, WorkerId}; +use nativelink_util::action_messages::{ActionInfo, OperationId, WorkerId}; use nativelink_util::metrics_utils::{ CollectorState, CounterWithTime, FuncCounterWrapper, MetricsComponent, }; @@ -33,7 +33,7 @@ pub type WorkerTimestamp = u64; /// Notifications to send worker about a requested state change. pub enum WorkerUpdate { /// Requests that the worker begin executing this action. - RunAction(Arc), + RunAction((OperationId, Arc)), /// Request that the worker is no longer in the pool and may discard any jobs. Disconnect, @@ -52,7 +52,7 @@ pub struct Worker { pub tx: UnboundedSender, /// The action info of the running actions on the worker - pub running_action_infos: HashSet>, + pub running_action_infos: HashMap>, /// Timestamp of last time this worker had been communicated with. // Warning: Do not update this timestamp without updating the placement of the worker in @@ -108,7 +108,7 @@ impl Worker { id, platform_properties, tx, - running_action_infos: HashSet::new(), + running_action_infos: HashMap::new(), last_update_timestamp: timestamp, is_paused: false, is_draining: false, @@ -140,7 +140,9 @@ impl Worker { /// Notifies the worker of a requested state change. pub fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> { match worker_update { - WorkerUpdate::RunAction(action_info) => self.run_action(action_info), + WorkerUpdate::RunAction((operation_id, action_info)) => { + self.run_action(operation_id, action_info) + } WorkerUpdate::Disconnect => { self.metrics.notify_disconnect.inc(); send_msg_to_worker(&mut self.tx, update_for_worker::Update::Disconnect(())) @@ -157,13 +159,18 @@ impl Worker { }) } - fn run_action(&mut self, action_info: Arc) -> Result<(), Error> { + fn run_action( + &mut self, + operation_id: OperationId, + action_info: Arc, + ) -> Result<(), Error> { let tx = &mut self.tx; let worker_platform_properties = &mut self.platform_properties; let running_action_infos = &mut self.running_action_infos; self.metrics.run_action.wrap(move || { let action_info_clone = action_info.as_ref().clone(); - running_action_infos.insert(action_info.clone()); + let operation_id_string = operation_id.to_string(); + running_action_infos.insert(operation_id, action_info.clone()); reduce_platform_properties( worker_platform_properties, &action_info.platform_properties, @@ -172,18 +179,24 @@ impl Worker { tx, update_for_worker::Update::StartAction(StartExecute { execute_request: Some(action_info_clone.into()), - salt: *action_info.salt(), + operation_id: operation_id_string, queued_timestamp: Some(action_info.insert_timestamp.into()), }), ) }) } - pub fn complete_action(&mut self, action_info: &Arc) { - self.running_action_infos.remove(action_info); + pub(crate) fn complete_action(&mut self, operation_id: &OperationId) -> Result<(), Error> { + let action_info = self.running_action_infos.remove(operation_id).err_tip(|| { + format!( + "Worker {} tried to complete operation {} that was not running", + self.id, operation_id + ) + })?; self.restore_platform_properties(&action_info.platform_properties); self.is_paused = false; self.metrics.actions_completed.inc(); + Ok(()) } pub fn has_actions(&self) -> bool { @@ -277,8 +290,8 @@ impl MetricsComponent for Worker { "If this worker is draining.", vec![("worker_id".into(), format!("{}", self.id).into())], ); - for action_info in self.running_action_infos.iter() { - let action_name = action_info.unique_qualifier.action_name().to_string(); + for action_info in self.running_action_infos.values() { + let action_name = action_info.unique_qualifier.to_string(); c.publish_with_labels( "timeout", &action_info.timeout, @@ -303,12 +316,6 @@ impl MetricsComponent for Worker { "When this action was created.", vec![("digest".into(), action_name.clone().into())], ); - c.publish_with_labels( - "skip_cache_lookup", - &action_info.skip_cache_lookup, - "Weather this action should skip cache lookup.", - vec![("digest".into(), action_name.clone().into())], - ); } for (prop_name, prop_type_and_value) in &self.platform_properties.properties { match prop_type_and_value { diff --git a/nativelink-scheduler/src/worker_scheduler.rs b/nativelink-scheduler/src/worker_scheduler.rs index f17c79f6c..10298c69c 100644 --- a/nativelink-scheduler/src/worker_scheduler.rs +++ b/nativelink-scheduler/src/worker_scheduler.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfoHashKey, ActionStage, WorkerId}; +use nativelink_util::action_messages::{ActionStage, OperationId, WorkerId}; use nativelink_util::metrics_utils::Registry; use crate::platform_property_manager::PlatformPropertyManager; @@ -36,7 +36,7 @@ pub trait WorkerScheduler: Sync + Send + Unpin { async fn update_action( &self, worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, + operation_id: &OperationId, action_stage: Result, ) -> Result<(), Error>; @@ -48,14 +48,14 @@ pub trait WorkerScheduler: Sync + Send + Unpin { ) -> Result<(), Error>; /// Removes worker from pool and reschedule any tasks that might be running on it. - async fn remove_worker(&self, worker_id: WorkerId); + async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error>; /// Removes timed out workers from the pool. This is called periodically by an /// external source. async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error>; /// Sets if the worker is draining or not. - async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error>; + async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error>; /// Register the metrics for the worker scheduler. fn register_metrics(self: Arc, _registry: &mut Registry) {} diff --git a/nativelink-scheduler/tests/action_messages_test.rs b/nativelink-scheduler/tests/action_messages_test.rs index a65af5a42..e1ae1e444 100644 --- a/nativelink-scheduler/tests/action_messages_test.rs +++ b/nativelink-scheduler/tests/action_messages_test.rs @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::collections::HashMap; +use std::time::SystemTime; use nativelink_error::Error; use nativelink_macro::nativelink_test; @@ -22,37 +21,28 @@ use nativelink_proto::build::bazel::remote::execution::v2::ExecuteResponse; use nativelink_proto::google::longrunning::{operation, Operation}; use nativelink_proto::google::rpc::Status; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata, - OperationId, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, ExecutionMetadata, OperationId, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; -use nativelink_util::platform_properties::PlatformProperties; use pretty_assertions::assert_eq; -const NOW_TIME: u64 = 10000; - -fn make_system_time(add_time: u64) -> SystemTime { - SystemTime::UNIX_EPOCH - .checked_add(Duration::from_secs(NOW_TIME + add_time)) - .unwrap() -} - #[nativelink_test] async fn action_state_any_url_test() -> Result<(), Error> { - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "foo_instance".to_string(), digest_function: DigestHasherFunc::Sha256, digest: DigestInfo::new([1u8; 32], 5), - salt: 0, - }; - let id = OperationId::new(unique_qualifier); + }); + let client_id = ClientOperationId::new(unique_qualifier.clone()); + let operation_id = OperationId::new(unique_qualifier); let action_state = ActionState { - id, + id: operation_id.clone(), // Result is only populated if has_action_result. stage: ActionStage::Completed(ActionResult::default()), }; - let operation: Operation = action_state.clone().into(); + let operation: Operation = action_state.as_operation(client_id); match &operation.result { Some(operation::Result::Response(any)) => assert_eq!( @@ -62,7 +52,7 @@ async fn action_state_any_url_test() -> Result<(), Error> { other => panic!("Expected Some(Result(Any)), got: {other:?}"), } - let action_state_round_trip: ActionState = operation.try_into()?; + let action_state_round_trip = ActionState::try_from_operation(operation, operation_id)?; assert_eq!(action_state, action_state_round_trip); Ok(()) @@ -101,115 +91,3 @@ async fn execute_response_status_message_is_some_on_success_test() -> Result<(), Ok(()) } - -#[nativelink_test] -async fn highest_priority_action_first() -> Result<(), Error> { - const INSTANCE_NAME: &str = "foobar_instance_name"; - - let high_priority_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 1000, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let lowest_priority_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([1u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let mut action_set = BTreeSet::>::new(); - action_set.insert(lowest_priority_action.clone()); - action_set.insert(high_priority_action.clone()); - - assert_eq!( - vec![high_priority_action, lowest_priority_action], - action_set - .iter() - .rev() - .cloned() - .collect::>>() - ); - - Ok(()) -} - -#[nativelink_test] -async fn equal_priority_earliest_first() -> Result<(), Error> { - const INSTANCE_NAME: &str = "foobar_instance_name"; - - let first_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let current_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: make_system_time(0), - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([1u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let mut action_set = BTreeSet::>::new(); - action_set.insert(current_action.clone()); - action_set.insert(first_action.clone()); - - assert_eq!( - vec![first_action, current_action], - action_set - .iter() - .rev() - .cloned() - .collect::>>() - ); - - Ok(()) -} diff --git a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs index 80d766d3e..ab3a79c5f 100644 --- a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs +++ b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs @@ -27,10 +27,12 @@ use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::ActionResult as ProtoActionResult; use nativelink_scheduler::action_scheduler::ActionScheduler; use nativelink_scheduler::cache_lookup_scheduler::CacheLookupScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_store::memory_store::MemoryStore; use nativelink_util::action_messages::{ - ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; @@ -85,42 +87,55 @@ async fn platform_property_manager_call_passed() -> Result<(), Error> { #[nativelink_test] async fn add_action_handles_skip_cache() -> Result<(), Error> { let context = make_cache_scheduler()?; - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let action_result = ProtoActionResult::from(ActionResult::default()); context .ac_store - .update_oneshot(*action_info.digest(), action_result.encode_to_vec().into()) + .update_oneshot(action_info.digest(), action_result.encode_to_vec().into()) .await?; let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), stage: ActionStage::Queued, })); + let ActionUniqueQualifier::Cachable(action_key) = action_info.unique_qualifier.clone() else { + panic!("This test should be testing when item was cached first"); + }; let mut skip_cache_action = action_info.clone(); - skip_cache_action.skip_cache_lookup = true; + skip_cache_action.unique_qualifier = ActionUniqueQualifier::Uncachable(action_key); + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); let _ = join!( - context.cache_scheduler.add_action(skip_cache_action), + context + .cache_scheduler + .add_action(client_operation_id.clone(), skip_cache_action), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)) + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + forward_watch_channel_rx + )))) ); Ok(()) } #[nativelink_test] -async fn find_existing_action_call_passed() -> Result<(), Error> { +async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_cache_scheduler()?; - let action_name = ActionInfoHashKey { - instance_name: "instance".to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([8; 32], 1), - salt: 1000, - }; - let (actual_result, actual_action_name) = join!( - context.cache_scheduler.find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + let client_operation_id = + ClientOperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: "instance".to_string(), + digest_function: DigestHasherFunc::Sha256, + digest: DigestInfo::new([8; 32], 1), + })); + let (actual_result, actual_client_id) = join!( + context + .cache_scheduler + .find_by_client_operation_id(&client_operation_id), + context + .mock_scheduler + .expect_find_by_client_operation_id(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); - assert_eq!(action_name, actual_action_name); + assert_eq!(true, actual_result.unwrap().is_none()); + assert_eq!(client_operation_id, actual_client_id); Ok(()) } diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index 38b28ec67..8265e39ea 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -26,9 +26,13 @@ use nativelink_config::schedulers::{PlatformPropertyAddition, PropertyModificati use nativelink_error::Error; use nativelink_macro::nativelink_test; use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::property_modifier_scheduler::PropertyModifierScheduler; -use nativelink_util::action_messages::{ActionInfoHashKey, ActionStage, ActionState, OperationId}; +use nativelink_util::action_messages::{ + ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, ClientOperationId, + OperationId, +}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::PlatformPropertyValue; @@ -66,7 +70,7 @@ async fn add_action_adds_property() -> Result<(), Error> { name: name.clone(), value: value.clone(), })]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -76,15 +80,22 @@ async fn add_action_adds_property() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(value))]), action_info.platform_properties.properties @@ -102,7 +113,7 @@ async fn add_action_overwrites_property() -> Result<(), Error> { name: name.clone(), value: replaced_value.clone(), })]); - let mut action_info = make_base_action_info(UNIX_EPOCH); + let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); action_info .platform_properties .properties @@ -116,15 +127,22 @@ async fn add_action_overwrites_property() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(replaced_value))]), action_info.platform_properties.properties @@ -143,7 +161,7 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { value: value.clone(), }), ]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -153,15 +171,22 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(value))]), action_info.platform_properties.properties @@ -180,7 +205,7 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { }), PropertyModification::remove(name.clone()), ]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -190,15 +215,22 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { name, PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([]), action_info.platform_properties.properties @@ -211,7 +243,7 @@ async fn add_action_property_remove() -> Result<(), Error> { let name = "name".to_string(); let value = "value".to_string(); let context = make_modifier_scheduler(vec![PropertyModification::remove(name.clone())]); - let mut action_info = make_base_action_info(UNIX_EPOCH); + let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); action_info .platform_properties .properties @@ -222,15 +254,22 @@ async fn add_action_property_remove() -> Result<(), Error> { stage: ActionStage::Queued, })); let platform_property_manager = Arc::new(PlatformPropertyManager::new(HashMap::new())); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([]), action_info.platform_properties.properties @@ -239,22 +278,23 @@ async fn add_action_property_remove() -> Result<(), Error> { } #[nativelink_test] -async fn find_existing_action_call_passed() -> Result<(), Error> { +async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_modifier_scheduler(vec![]); - let action_name = ActionInfoHashKey { + let operation_id = ClientOperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { instance_name: "instance".to_string(), digest_function: DigestHasherFunc::Sha256, digest: DigestInfo::new([8; 32], 1), - salt: 1000, - }; - let (actual_result, actual_action_name) = join!( + })); + let (actual_result, actual_operation_id) = join!( context .modifier_scheduler - .find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + .find_by_client_operation_id(&operation_id), + context + .mock_scheduler + .expect_find_by_client_operation_id(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); - assert_eq!(action_name, actual_action_name); + assert_eq!(true, actual_result.unwrap().is_none()); + assert_eq!(operation_id, actual_operation_id); Ok(()) } diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 3c722220a..4a4314140 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -13,29 +13,33 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use futures::poll; +use futures::task::Poll; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::simple_scheduler::SimpleScheduler; use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_util::action_messages::{ - ActionInfoHashKey, ActionResult, ActionStage, ActionState, DirectoryInfo, ExecutionMetadata, - FileInfo, NameOrPath, OperationId, SymlinkInfo, WorkerId, INTERNAL_ERROR_EXIT_CODE, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, DirectoryInfo, ExecutionMetadata, FileInfo, NameOrPath, OperationId, + SymlinkInfo, WorkerId, INTERNAL_ERROR_EXIT_CODE, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; use pretty_assertions::assert_eq; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; use utils::scheduler_utils::{make_base_action_info, INSTANCE_NAME}; use uuid::Uuid; @@ -43,6 +47,45 @@ mod utils { pub(crate) mod scheduler_utils; } +fn update_eq(expected: UpdateForWorker, actual: UpdateForWorker, ignore_id: bool) -> bool { + let Some(expected_update) = expected.update else { + return actual.update.is_none(); + }; + let Some(actual_update) = actual.update else { + return false; + }; + match actual_update { + update_for_worker::Update::Disconnect(()) => { + matches!(expected_update, update_for_worker::Update::Disconnect(())) + } + update_for_worker::Update::KeepAlive(()) => { + matches!(expected_update, update_for_worker::Update::KeepAlive(())) + } + update_for_worker::Update::StartAction(actual_update) => match expected_update { + update_for_worker::Update::StartAction(mut expected_update) => { + if ignore_id { + expected_update + .operation_id + .clone_from(&actual_update.operation_id); + } + expected_update == actual_update + } + _ => false, + }, + update_for_worker::Update::KillOperationRequest(actual_update) => match expected_update { + update_for_worker::Update::KillOperationRequest(expected_update) => { + expected_update == actual_update + } + _ => false, + }, + update_for_worker::Update::ConnectionResult(actual_update) => match expected_update { + update_for_worker::Update::ConnectionResult(expected_update) => { + expected_update == actual_update + } + _ => false, + }, + } +} async fn verify_initial_connection_message( worker_id: WorkerId, rx: &mut mpsc::UnboundedReceiver, @@ -88,11 +131,11 @@ async fn setup_action( action_digest: DigestInfo, platform_properties: PlatformProperties, insert_timestamp: SystemTime, -) -> Result>, Error> { - let mut action_info = make_base_action_info(insert_timestamp); +) -> Result>, Error> { + let mut action_info = make_base_action_info(insert_timestamp, action_digest); action_info.platform_properties = platform_properties; - action_info.unique_qualifier.digest = action_digest; - let result = scheduler.add_action(action_info).await; + let client_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let result = scheduler.add_action(client_id, action_info).await; tokio::task::yield_now().await; // Allow task<->worker matcher to run. result } @@ -103,7 +146,7 @@ const WORKER_TIMEOUT_S: u64 = 100; async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -112,13 +155,14 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); { // Worker should have been sent an execute command. @@ -126,21 +170,21 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -156,7 +200,7 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { async fn find_executing_action() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -165,21 +209,23 @@ async fn find_executing_action() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let client_rx = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); + let client_operation_id = action_listener.client_operation_id().clone(); // Drop our receiver and look up a new one. - let unique_qualifier = client_rx.borrow().id.unique_qualifier.clone(); - drop(client_rx); - let mut client_rx = scheduler - .find_existing_action(&unique_qualifier) + drop(action_listener); + let mut action_listener = scheduler + .find_by_client_operation_id(&client_operation_id) .await - .err_tip(|| "Action not found")?; + .expect("Action not found") + .unwrap(); { // Worker should have been sent an execute command. @@ -187,21 +233,21 @@ async fn find_executing_action() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -217,7 +263,7 @@ async fn find_executing_action() -> Result<(), Error> { async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Error> { let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() @@ -230,7 +276,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp1 = make_system_time(1); - let mut client_rx1 = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, PlatformProperties::default(), @@ -238,7 +284,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ) .await?; let insert_timestamp2 = make_system_time(2); - let mut client_rx2 = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, PlatformProperties::default(), @@ -246,83 +292,87 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ) .await?; - let unique_qualifier = ActionInfoHashKey { - instance_name: "".to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::zero_digest(), - salt: 0, - }; - - let id = OperationId::new(unique_qualifier); - let mut expected_action_state1 = ActionState { - // Name is a random string, so we ignore it and just make it the same. - id: id.clone(), - stage: ActionStage::Executing, - }; - let mut expected_action_state2 = ActionState { - // Name is a random string, so we ignore it and just make it the same. - id, - stage: ActionStage::Executing, + let mut expected_start_execute_for_worker1 = StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest1.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "WILL BE SET BELOW".to_string(), + queued_timestamp: Some(insert_timestamp1.into()), }; - let execution_request_for_worker1 = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest1.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp1.into()), - })), + let mut expected_start_execute_for_worker2 = StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest2.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "WILL BE SET BELOW".to_string(), + queued_timestamp: Some(insert_timestamp2.into()), }; - { - // Worker1 should now see execution request. - let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker1); - } - let execution_request_for_worker2 = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest2.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp2.into()), - })), + let operation_id1 = { + // Worker1 should now see first execution request. + let update_for_worker = rx_from_worker1 + .recv() + .await + .expect("Worker terminated stream") + .update + .expect("`update` should be set on UpdateForWorker"); + let (operation_id, rx_start_execute) = match update_for_worker { + update_for_worker::Update::StartAction(start_execute) => ( + OperationId::try_from(start_execute.operation_id.as_str()).unwrap(), + start_execute, + ), + v => panic!("Expected StartAction, got : {v:?}"), + }; + expected_start_execute_for_worker1.operation_id = operation_id.to_string(); + assert_eq!(expected_start_execute_for_worker1, rx_start_execute); + operation_id }; - { + let operation_id2 = { // Worker1 should now see second execution request. - let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker2); - } + let update_for_worker = rx_from_worker1 + .recv() + .await + .expect("Worker terminated stream") + .update + .expect("`update` should be set on UpdateForWorker"); + let (operation_id, rx_start_execute) = match update_for_worker { + update_for_worker::Update::StartAction(start_execute) => ( + OperationId::try_from(start_execute.operation_id.as_str()).unwrap(), + start_execute, + ), + v => panic!("Expected StartAction, got : {v:?}"), + }; + expected_start_execute_for_worker2.operation_id = operation_id.to_string(); + assert_eq!(expected_start_execute_for_worker2, rx_start_execute); + operation_id + }; // Add a second worker that can take jobs if the first dies. let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. - expected_action_state1.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state1); + assert_eq!(&action_state.stage, &expected_action_stage); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. - expected_action_state2.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state2); + assert_eq!(&action_state.stage, &expected_action_stage); } // Now remove worker. - scheduler.remove_worker(worker_id1).await; + let _ = scheduler.remove_worker(&worker_id1).await; tokio::task::yield_now().await; // Allow task<->worker matcher to run. { @@ -336,26 +386,44 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); - expected_action_state1.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state1); + let action_state = client1_action_listener.changed().await.unwrap(); + // We now know the name of the action so populate it. + assert_eq!(&action_state.stage, &expected_action_stage); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); - expected_action_state2.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state2); + let action_state = client2_action_listener.changed().await.unwrap(); + // We now know the name of the action so populate it. + assert_eq!(&action_state.stage, &expected_action_stage); } { // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker1); + expected_start_execute_for_worker1.operation_id = operation_id1.to_string(); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + expected_start_execute_for_worker1 + )), + } + ); } { // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker2); + expected_start_execute_for_worker2.operation_id = operation_id2.to_string(); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + expected_start_execute_for_worker2 + )), + } + ); } Ok(()) @@ -365,7 +433,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -374,7 +442,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -382,23 +450,29 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> ) .await?; - { + let _operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(start_execute)) => { + OperationId::try_from(start_execute.operation_id.as_str()).unwrap() + } v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + operation_id + }; // Set the worker draining. - scheduler.set_drain_worker(worker_id, true).await?; + scheduler.set_drain_worker(&worker_id, true).await?; tokio::task::yield_now().await; let action_digest = DigestInfo::new([88u8; 32], 512); let insert_timestamp = make_system_time(14); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -408,7 +482,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -418,12 +492,12 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> } // Set the worker not draining. - scheduler.set_drain_worker(worker_id, false).await?; + scheduler.set_drain_worker(&worker_id, false).await?; tokio::task::yield_now().await; { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -440,7 +514,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -459,7 +533,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, platform_properties.clone()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, worker_properties.clone(), @@ -469,7 +543,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -485,21 +559,20 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -521,18 +594,17 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "".to_string(), digest: DigestInfo::zero_digest(), digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; + }); let id = OperationId::new(unique_qualifier); let mut expected_action_state = ActionState { id, @@ -541,14 +613,14 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let insert_timestamp1 = make_system_time(1); let insert_timestamp2 = make_system_time(2); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp1, ) .await?; - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -558,8 +630,8 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Clients should get notification saying it's been queued. - let action_state1 = client1_rx.borrow_and_update(); - let action_state2 = client2_rx.borrow_and_update(); + let action_state1 = client1_action_listener.changed().await.unwrap(); + let action_state2 = client2_action_listener.changed().await.unwrap(); // Name is random so we set force it to be the same. expected_action_state.id = action_state1.id.clone(); assert_eq!(action_state1.as_ref(), &expected_action_state); @@ -575,17 +647,17 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp1.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } // Action should now be executing. @@ -594,11 +666,11 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { // Both client1 and client2 should be receiving the same updates. // Most importantly the `name` (which is random) will be the same. assert_eq!( - client1_rx.borrow_and_update().as_ref(), + client1_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); assert_eq!( - client2_rx.borrow_and_update().as_ref(), + client2_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -606,7 +678,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Now if another action is requested it should also join with executing action. let insert_timestamp3 = make_system_time(2); - let mut client3_rx = setup_action( + let mut client3_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -614,7 +686,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { ) .await?; assert_eq!( - client3_rx.borrow_and_update().as_ref(), + client3_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -625,7 +697,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { #[nativelink_test] async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -638,7 +710,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), drop(rx_from_worker); let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -647,7 +719,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), .await?; { // Client should get notification saying it's being queued not executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -663,7 +735,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() @@ -676,7 +748,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -688,44 +760,49 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; - let unique_qualifier = ActionInfoHashKey { - instance_name: "".to_string(), - digest: DigestInfo::zero_digest(), - digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; - let id = OperationId::new(unique_qualifier); - let mut expected_action_state = ActionState { - id, - stage: ActionStage::Executing, - }; - - let execution_request_for_worker = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp.into()), - })), + let mut start_execute = StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "UNKNOWN HERE, WE WILL SET IT LATER".to_string(), + queued_timestamp: Some(insert_timestamp.into()), }; - { + let operation_id = { // Worker1 should now see execution request. let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker); - } + let operation_id = if let update_for_worker::Update::StartAction(start_execute) = + msg_for_worker.update.as_ref().unwrap() + { + start_execute.operation_id.clone() + } else { + panic!("Expected StartAction, got : {msg_for_worker:?}"); + }; + start_execute.operation_id.clone_from(&operation_id); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + start_execute.clone() + )), + } + ); + OperationId::try_from(operation_id.as_str()).unwrap() + }; { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); - // We now know the name of the action so populate it. - expected_action_state.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state); + let action_state = action_listener.changed().await.unwrap(); + assert_eq!( + action_state.as_ref(), + &ActionState { + id: operation_id.clone(), + stage: ActionStage::Executing, + } + ); } // Keep worker 2 alive. @@ -750,14 +827,26 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); - expected_action_state.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state); + let action_state = action_listener.changed().await.unwrap(); + assert_eq!( + action_state.as_ref(), + &ActionState { + id: operation_id.clone(), + stage: ActionStage::Executing, + } + ); } { // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + start_execute.clone() + )), + } + ); } Ok(()) @@ -767,7 +856,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { async fn update_action_sends_completed_result_to_client_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -776,7 +865,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -784,22 +873,21 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err ) .await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + Some(update_for_worker::Update::StartAction(start_execute)) => { + // Other tests check full data. We only care if client thinks we are Executing. + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + start_execute.operation_id + } v => panic!("Expected StartAction, got : {v:?}"), } - // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, }; + let action_result = ActionResult { output_files: vec![FileInfo { name_or_path: NameOrPath::Name("hello".to_string()), @@ -840,14 +928,14 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err scheduler .update_action( &worker_id, - action_info_hash_key, + &OperationId::try_from(operation_id.as_str())?, Ok(ActionStage::Completed(action_result.clone())), ) .await?; { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -855,14 +943,6 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err }; assert_eq!(action_state.as_ref(), &expected_action_state); } - { - // Update info for the action should now be closed (notification happens through Err). - let result = client_rx.changed().await; - assert!( - result.is_err(), - "Expected result to be an error : {result:?}" - ); - } Ok(()) } @@ -871,7 +951,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err async fn update_action_sends_completed_result_after_disconnect() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -880,7 +960,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let client_rx = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -888,24 +968,21 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E ) .await?; + let client_id = action_listener.client_operation_id().clone(); + // Drop our receiver and don't reconnect until completed. - let unique_qualifier = client_rx.borrow().id.unique_qualifier.clone(); - drop(client_rx); + drop(action_listener); - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, + }; + // Other tests check full data. We only care if client thinks we are Executing. + OperationId::try_from(operation_id.as_str())? }; + let action_result = ActionResult { output_files: vec![FileInfo { name_or_path: NameOrPath::Name("hello".to_string()), @@ -946,19 +1023,20 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E scheduler .update_action( &worker_id, - action_info_hash_key, + &operation_id, Ok(ActionStage::Completed(action_result.clone())), ) .await?; // Now look up a channel after the action has completed. - let mut client_rx = scheduler - .find_existing_action(&unique_qualifier) + let mut action_listener = scheduler + .find_by_client_operation_id(&client_id) .await - .err_tip(|| "Action not found")?; + .unwrap() + .expect("Action not found"); { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -975,7 +1053,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let good_worker_id: WorkerId = WorkerId(Uuid::new_v4()); let rogue_worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -984,7 +1062,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, good_worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -999,15 +1077,18 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } + let _ = setup_new_worker(&scheduler, rogue_worker_id, PlatformProperties::default()).await?; - let action_info_hash_key = ActionInfoHashKey { + let action_info_hash_key = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: INSTANCE_NAME.to_string(), digest_function: DigestHasherFunc::Sha256, digest: action_digest, - salt: 0, - }; + }); let action_result = ActionResult { output_files: Vec::default(), output_folders: Vec::default(), @@ -1035,14 +1116,13 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let update_action_result = scheduler .update_action( &rogue_worker_id, - action_info_hash_key, + &OperationId::new(action_info_hash_key), Ok(ActionStage::Completed(action_result.clone())), ) .await; { - const EXPECTED_ERR: &str = - "Got a result from a worker that should not be running the action"; + const EXPECTED_ERR: &str = "should not be running on worker"; // Our request should have sent an error back. assert!( update_action_result.is_err(), @@ -1058,8 +1138,8 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { { // Ensure client did not get notified. assert_eq!( - client_rx.has_changed().unwrap(), - false, + poll!(action_listener.changed()), + Poll::Pending, "Client should not have been notified of event" ); } @@ -1071,18 +1151,17 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "".to_string(), digest: DigestInfo::zero_digest(), digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; + }); let id = OperationId::new(unique_qualifier); let mut expected_action_state = ActionState { id, @@ -1090,7 +1169,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro }; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1106,26 +1185,27 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } - { + let operation_id = { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); - } + action_state.id.clone() + }; let action_result = ActionResult { output_files: Vec::default(), @@ -1155,12 +1235,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, - }, + &operation_id, Ok(ActionStage::Completed(action_result.clone())), ) .await?; @@ -1169,7 +1244,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro // Action should now be executing. expected_action_state.stage = ActionStage::Completed(action_result.clone()); assert_eq!( - client_rx.borrow_and_update().as_ref(), + action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -1179,7 +1254,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro { let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1188,7 +1263,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro .await?; // We didn't disconnect our worker, so it will have scheduled it to the worker. expected_action_state.stage = ActionStage::Executing; - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // The name of the action changed (since it's a new action), so update it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1203,7 +1278,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1216,7 +1291,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, platform_properties.clone()).await?; let insert_timestamp1 = make_system_time(1); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1224,7 +1299,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> ) .await?; let insert_timestamp2 = make_system_time(1); - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties, @@ -1236,12 +1311,15 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } v => panic!("Expected StartAction, got : {v:?}"), } - { + let (operation_id1, operation_id2) = { + let state_1 = client1_action_listener.changed().await.unwrap(); + let state_2 = client2_action_listener.changed().await.unwrap(); // First client should be in an Executing state. - assert_eq!(client1_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!(state_1.stage, ActionStage::Executing); // Second client should be in a queued state. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Queued); - } + assert_eq!(state_2.stage, ActionStage::Queued); + (state_1.id.clone(), state_2.id.clone()) + }; let action_result = ActionResult { output_files: Vec::default(), @@ -1272,25 +1350,14 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest1, - salt: 0, - }, + &operation_id1, Ok(ActionStage::Completed(action_result.clone())), ) .await?; - // Ensure client did not get notified. - assert!( - client1_rx.changed().await.is_ok(), - "Client should have been notified of event" - ); - { // First action should now be completed. - let action_state = client1_rx.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1311,26 +1378,24 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } // Tell scheduler our second task is completed. scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest2, - salt: 0, - }, + &operation_id2, Ok(ActionStage::Completed(action_result.clone())), ) .await?; { // Our second client should be notified it completed. - let action_state = client2_rx.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1349,7 +1414,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1363,7 +1428,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { // This is queued after the next one (even though it's placed in the map // first), so it should execute second. let insert_timestamp2 = make_system_time(2); - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties.clone(), @@ -1371,7 +1436,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { ) .await?; let insert_timestamp1 = make_system_time(1); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1388,9 +1453,15 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { } { // First client should be in an Executing state. - assert_eq!(client1_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client1_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); // Second client should be in a queued state. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Queued); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Queued + ); } Ok(()) @@ -1400,9 +1471,9 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { - max_job_retries: 2, + max_job_retries: 1, ..Default::default() }, || async move {}, @@ -1412,7 +1483,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1420,33 +1491,31 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> ) .await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + OperationId::try_from(operation_id.as_str())? }; + let _ = scheduler .update_action( &worker_id, - action_info_hash_key.clone(), + &operation_id, Err(make_err!(Code::Internal, "Some error")), ) .await; { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1465,18 +1534,21 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } let err = make_err!(Code::Internal, "Some error"); // Send internal error from worker again. let _ = scheduler - .update_action(&worker_id, action_info_hash_key, Err(err.clone())) + .update_action(&worker_id, &operation_id, Err(err.clone())) .await; { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1501,14 +1573,23 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> output_upload_completed_timestamp: SystemTime::UNIX_EPOCH, }, server_logs: HashMap::default(), - error: Some(err.merge(make_err!( - Code::Internal, - "Job cancelled because it attempted to execute too many times and failed" - ))), + error: Some(err.clone()), message: String::new(), }), }; - assert_eq!(action_state.as_ref(), &expected_action_state); + let mut received_state = action_state.as_ref().clone(); + if let ActionStage::Completed(stage) = &mut received_state.stage { + if let Some(real_err) = &mut stage.error { + assert!( + real_err.to_string().contains("Job cancelled because it attempted to execute too many times and failed"), + "{real_err} did not contain 'Job cancelled because it attempted to execute too many times and failed'", + ); + *real_err = err; + } + } else { + panic!("Expected Completed, got : {:?}", action_state.stage); + }; + assert_eq!(received_state, expected_action_state); } Ok(()) @@ -1533,7 +1614,7 @@ async fn ensure_scheduler_drops_inner_spawn() -> Result<(), Error> { // Since the inner spawn owns this callback, we can use the callback to know if the // inner spawn was dropped because our callback would be dropped, which dropps our // DropChecker. - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), move || { // This will ensure dropping happens if this function is ever dropped. @@ -1558,7 +1639,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1566,7 +1647,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1577,25 +1658,24 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker1.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker1.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + OperationId::try_from(operation_id.as_str())? + }; let _ = scheduler .update_action( &worker_id1, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, - }, + &operation_id, Err(make_err!(Code::NotFound, "Some error")), ) .await; @@ -1610,7 +1690,10 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), .await .err_tip(|| "worker went away")?; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } Ok(()) diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index 0803be2ab..bf4362cc5 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -12,26 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_error::{make_input_err, Error}; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; -use tokio::sync::{mpsc, watch, Mutex}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; +use tokio::sync::{mpsc, Mutex}; #[allow(clippy::large_enum_variant)] enum ActionSchedulerCalls { GetPlatformPropertyManager(String), - AddAction(ActionInfo), - FindExistingAction(ActionInfoHashKey), + AddAction((ClientOperationId, ActionInfo)), + FindExistingAction(ClientOperationId), } enum ActionSchedulerReturns { GetPlatformPropertyManager(Result, Error>), - AddAction(Result>, Error>), - FindExistingAction(Option>>), + AddAction(Result>, Error>), + FindExistingAction(Result>>, Error>), } pub struct MockActionScheduler { @@ -81,8 +82,8 @@ impl MockActionScheduler { pub async fn expect_add_action( &self, - result: Result>, Error>, - ) -> ActionInfo { + result: Result>, Error>, + ) -> (ClientOperationId, ActionInfo) { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::AddAction(req) = rx_call_lock .recv() @@ -98,17 +99,17 @@ impl MockActionScheduler { req } - pub async fn expect_find_existing_action( + pub async fn expect_find_by_client_operation_id( &self, - result: Option>>, - ) -> ActionInfoHashKey { + result: Result>>, Error>, + ) -> ClientOperationId { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::FindExistingAction(req) = rx_call_lock .recv() .await .expect("Could not receive msg in mpsc") else { - panic!("Got incorrect call waiting for find_existing_action") + panic!("Got incorrect call waiting for find_by_client_operation_id") }; self.tx_resp .send(ActionSchedulerReturns::FindExistingAction(result)) @@ -142,10 +143,14 @@ impl ActionScheduler for MockActionScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { self.tx_call - .send(ActionSchedulerCalls::AddAction(action_info)) + .send(ActionSchedulerCalls::AddAction(( + client_operation_id, + action_info, + ))) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; match rx_resp_lock @@ -158,13 +163,13 @@ impl ActionScheduler for MockActionScheduler { } } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { self.tx_call .send(ActionSchedulerCalls::FindExistingAction( - unique_qualifier.clone(), + client_operation_id.clone(), )) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; @@ -174,9 +179,7 @@ impl ActionScheduler for MockActionScheduler { .expect("Could not receive msg in mpsc") { ActionSchedulerReturns::FindExistingAction(result) => result, - _ => panic!("Expected find_existing_action return value"), + _ => panic!("Expected find_by_client_operation_id return value"), } } - - async fn clean_recently_completed_actions(&self) {} } diff --git a/nativelink-scheduler/tests/utils/scheduler_utils.rs b/nativelink-scheduler/tests/utils/scheduler_utils.rs index 7ee119c5f..06bb2020d 100644 --- a/nativelink-scheduler/tests/utils/scheduler_utils.rs +++ b/nativelink-scheduler/tests/utils/scheduler_utils.rs @@ -15,14 +15,17 @@ use std::collections::HashMap; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey}; +use nativelink_util::action_messages::{ActionInfo, ActionUniqueKey, ActionUniqueQualifier}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::PlatformProperties; pub const INSTANCE_NAME: &str = "foobar_instance_name"; -pub fn make_base_action_info(insert_timestamp: SystemTime) -> ActionInfo { +pub fn make_base_action_info( + insert_timestamp: SystemTime, + action_digest: DigestInfo, +) -> ActionInfo { ActionInfo { command_digest: DigestInfo::new([0u8; 32], 0), input_root_digest: DigestInfo::new([0u8; 32], 0), @@ -33,12 +36,10 @@ pub fn make_base_action_info(insert_timestamp: SystemTime) -> ActionInfo { priority: 0, load_timestamp: UNIX_EPOCH, insert_timestamp, - unique_qualifier: ActionInfoHashKey { + unique_qualifier: ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: INSTANCE_NAME.to_string(), digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: false, + digest: action_digest, + }), } } diff --git a/nativelink-service/BUILD.bazel b/nativelink-service/BUILD.bazel index 57e53aa6f..f8f47072f 100644 --- a/nativelink-service/BUILD.bazel +++ b/nativelink-service/BUILD.bazel @@ -55,6 +55,7 @@ rust_test_suite( ], proc_macro_deps = [ "//nativelink-macro", + "@crates//:async-trait", ], deps = [ "//nativelink-config", @@ -64,6 +65,7 @@ rust_test_suite( "//nativelink-service", "//nativelink-store", "//nativelink-util", + "@crates//:async-lock", "@crates//:bytes", "@crates//:futures", "@crates//:hyper", diff --git a/nativelink-service/Cargo.toml b/nativelink-service/Cargo.toml index 18d889eeb..983d0f5d7 100644 --- a/nativelink-service/Cargo.toml +++ b/nativelink-service/Cargo.toml @@ -28,6 +28,8 @@ uuid = { version = "1.8.0", features = ["v4"] } [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } +async-trait = "0.1.80" +async-lock = "3.3.0" hyper = "0.14.28" maplit = "1.0.2" pretty_assertions = "1.4.0" diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index 35db08dbb..623025cf2 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -17,7 +17,8 @@ use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use futures::{Stream, StreamExt}; +use futures::stream::unfold; +use futures::Stream; use nativelink_config::cas_server::{ExecutionConfig, InstanceName}; use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::execution_server::{ @@ -27,22 +28,47 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ Action, Command, ExecuteRequest, WaitExecutionRequest, }; use nativelink_proto::google::longrunning::Operation; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_store::ac_utils::get_and_decode_digest; use nativelink_store::store_manager::StoreManager; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionState, OperationId, DEFAULT_EXECUTION_PRIORITY, + ActionInfo, ActionUniqueKey, ActionUniqueQualifier, ClientOperationId, + DEFAULT_EXECUTION_PRIORITY, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::{make_ctx_for_hash_func, DigestHasherFunc}; use nativelink_util::platform_properties::PlatformProperties; use nativelink_util::store_trait::Store; -use rand::{thread_rng, Rng}; -use tokio::sync::watch; -use tokio_stream::wrappers::WatchStream; use tonic::{Request, Response, Status}; use tracing::{error_span, event, instrument, Level}; +type InstanceInfoName = String; + +struct NativelinkClientOperationId { + instance_name: InstanceInfoName, + client_operation_id: ClientOperationId, +} + +impl NativelinkClientOperationId { + fn from_name(name: &str) -> Result { + let (instance_name, name) = name + .split_once('/') + .err_tip(|| "Expected instance_name and name to be separated by '/'")?; + Ok(Self { + instance_name: instance_name.to_string(), + client_operation_id: ClientOperationId::from_raw_string(name.to_string()), + }) + } + + fn into_string(self) -> String { + format!( + "{}/{}", + self.instance_name, + self.client_operation_id.into_string() + ) + } +} + struct InstanceInfo { scheduler: Arc, cas_store: Store, @@ -112,6 +138,17 @@ impl InstanceInfo { } } + let action_key = ActionUniqueKey { + instance_name, + digest_function, + digest: action_digest, + }; + let unique_qualifier = if skip_cache_lookup { + ActionUniqueQualifier::Uncachable(action_key) + } else { + ActionUniqueQualifier::Cachable(action_key) + }; + Ok(ActionInfo { command_digest, input_root_digest, @@ -120,17 +157,7 @@ impl InstanceInfo { priority, load_timestamp: UNIX_EPOCH, insert_timestamp: SystemTime::now(), - unique_qualifier: ActionInfoHashKey { - instance_name, - digest_function, - digest: action_digest, - salt: if action.do_not_cache { - thread_rng().gen::() - } else { - 0 - }, - }, - skip_cache_lookup, + unique_qualifier, }) } } @@ -139,7 +166,7 @@ pub struct ExecutionServer { instance_infos: HashMap, } -type ExecuteStream = Pin> + Send + Sync + 'static>>; +type ExecuteStream = Pin> + Send + 'static>>; impl ExecutionServer { pub fn new( @@ -179,11 +206,42 @@ impl ExecutionServer { Server::new(self) } - fn to_execute_stream(receiver: watch::Receiver>) -> Response { - let receiver_stream = Box::pin(WatchStream::new(receiver).map(|action_update| { - event!(Level::INFO, ?action_update, "Execute Resp Stream",); - Ok(Into::::into(action_update.as_ref().clone())) - })); + fn to_execute_stream( + nl_client_operation_id: NativelinkClientOperationId, + action_listener: Pin>, + ) -> Response { + let client_operation_id_string = nl_client_operation_id.into_string(); + let receiver_stream = Box::pin(unfold( + Some(action_listener), + move |maybe_action_listener| { + let client_operation_id_string = client_operation_id_string.clone(); + async move { + let mut action_listener = maybe_action_listener?; + match action_listener.changed().await { + Ok(action_update) => { + event!(Level::INFO, ?action_update, "Execute Resp Stream"); + let client_operation_id = ClientOperationId::from_raw_string( + client_operation_id_string.clone(), + ); + // If the action is finished we won't be sending any more updates. + let maybe_action_listener = if action_update.stage.is_finished() { + None + } else { + Some(action_listener) + }; + Some(( + Ok(action_update.as_operation(client_operation_id)), + maybe_action_listener, + )) + } + Err(err) => { + event!(Level::ERROR, ?err, "Error in action_listener stream"); + Some((Err(err.into()), None)) + } + } + } + }, + )); tonic::Response::new(receiver_stream) } @@ -213,7 +271,7 @@ impl ExecutionServer { get_and_decode_digest::(&instance_info.cas_store, digest.into()).await?; let action_info = instance_info .build_action_info( - instance_name, + instance_name.clone(), digest, &action, priority, @@ -225,38 +283,52 @@ impl ExecutionServer { ) .await?; - let rx = instance_info + let action_listener = instance_info .scheduler - .add_action(action_info) + .add_action( + ClientOperationId::new(action_info.unique_qualifier.clone()), + action_info, + ) .await .err_tip(|| "Failed to schedule task")?; - Ok(Self::to_execute_stream(rx)) + Ok(Self::to_execute_stream( + NativelinkClientOperationId { + instance_name, + client_operation_id: action_listener.client_operation_id().clone(), + }, + action_listener, + )) } async fn inner_wait_execution( &self, request: Request, ) -> Result, Status> { - let operation_id = OperationId::try_from(request.into_inner().name.as_str()) - .err_tip(|| "Decoding operation name into OperationId")?; - let Some(instance_info) = self - .instance_infos - .get(&operation_id.unique_qualifier.instance_name) - else { + let (instance_name, client_operation_id) = + NativelinkClientOperationId::from_name(&request.into_inner().name) + .map(|v| (v.instance_name, v.client_operation_id)) + .err_tip(|| "Failed to parse operation_id in ExecutionServer::wait_execution")?; + let Some(instance_info) = self.instance_infos.get(&instance_name) else { return Err(Status::not_found(format!( - "No scheduler with the instance name {}", - operation_id.unique_qualifier.instance_name + "No scheduler with the instance name {instance_name}" ))); }; let Some(rx) = instance_info .scheduler - .find_existing_action(&operation_id.unique_qualifier) + .find_by_client_operation_id(&client_operation_id) .await + .err_tip(|| "Error running find_existing_action in ExecutionServer::wait_execution")? else { return Err(Status::not_found("Failed to find existing task")); }; - Ok(Self::to_execute_stream(rx)) + Ok(Self::to_execute_stream( + NativelinkClientOperationId { + instance_name, + client_operation_id, + }, + rx, + )) } } diff --git a/nativelink-service/src/worker_api_server.rs b/nativelink-service/src/worker_api_server.rs index 14df2f82d..c4e377b99 100644 --- a/nativelink-service/src/worker_api_server.rs +++ b/nativelink-service/src/worker_api_server.rs @@ -27,11 +27,10 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker, }; -use nativelink_scheduler::worker::{Worker}; +use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_util::background_spawn; -use nativelink_util::action_messages::{ActionInfoHashKey, WorkerId}; -use nativelink_util::common::DigestInfo; +use nativelink_util::action_messages::{OperationId, WorkerId}; use nativelink_util::platform_properties::PlatformProperties; use tokio::sync::mpsc; use tokio::time::interval; @@ -188,7 +187,10 @@ impl WorkerApiServer { going_away_request: GoingAwayRequest, ) -> Result, Error> { let worker_id: WorkerId = going_away_request.worker_id.try_into()?; - self.scheduler.remove_worker(worker_id).await; + self.scheduler + .remove_worker(&worker_id) + .await + .err_tip(|| "While calling WorkerApiServer::inner_going_away")?; Ok(Response::new(())) } @@ -196,21 +198,11 @@ impl WorkerApiServer { &self, execute_result: ExecuteResult, ) -> Result, Error> { - let digest_function = execute_result - .digest_function() - .try_into() - .err_tip(|| "In inner_execution_response")?; let worker_id: WorkerId = execute_result.worker_id.try_into()?; - let action_digest: DigestInfo = execute_result - .action_digest - .err_tip(|| "Expected action_digest to exist")? - .try_into()?; - let action_info_hash_key = ActionInfoHashKey { - instance_name: execute_result.instance_name, - digest_function, - digest: action_digest, - salt: execute_result.salt, - }; + let operation_id = + OperationId::try_from(execute_result.operation_id.as_str()).err_tip(|| { + "Failed to convert operation_id in WorkerApiServer::inner_execution_response" + })?; match execute_result .result @@ -221,15 +213,15 @@ impl WorkerApiServer { .try_into() .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")?; self.scheduler - .update_action(&worker_id, action_info_hash_key, Ok(action_stage)) + .update_action(&worker_id, &operation_id, Ok(action_stage)) .await - .err_tip(|| format!("Failed to update_action {action_digest:?}"))?; + .err_tip(|| format!("Failed to operation {operation_id:?}"))?; } execute_result::Result::InternalError(e) => { self.scheduler - .update_action(&worker_id, action_info_hash_key, Err(e.into())) + .update_action(&worker_id, &operation_id, Err(e.into())) .await - .err_tip(|| format!("Failed to update_action {action_digest:?}"))?; + .err_tip(|| format!("Failed to operation {operation_id:?}"))?; } } Ok(Response::new(())) diff --git a/nativelink-service/tests/worker_api_server_test.rs b/nativelink-service/tests/worker_api_server_test.rs index 212d14dac..31df1d287 100644 --- a/nativelink-service/tests/worker_api_server_test.rs +++ b/nativelink-service/tests/worker_api_server_test.rs @@ -16,7 +16,10 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use async_lock::Mutex as AsyncMutex; +use async_trait::async_trait; use nativelink_config::cas_server::WorkerApiConfig; +use nativelink_config::schedulers::WorkerAllocationStrategy; use nativelink_error::{Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::{ @@ -28,23 +31,99 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: execute_result, update_for_worker, ExecuteResult, KeepAliveRequest, SupportedProperties, }; use nativelink_proto::google::rpc::Status as ProtoStatus; -use nativelink_scheduler::action_scheduler::ActionScheduler; -use nativelink_scheduler::simple_scheduler::SimpleScheduler; +use nativelink_scheduler::api_worker_scheduler::ApiWorkerScheduler; +use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_service::worker_api_server::{ConnectWorkerStream, NowFn, WorkerApiServer}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionStage, WorkerId}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, OperationId, WorkerId, +}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::operation_state_manager::WorkerStateManager; use nativelink_util::platform_properties::PlatformProperties; use pretty_assertions::assert_eq; +use tokio::join; +use tokio::sync::{mpsc, Notify}; use tokio_stream::StreamExt; use tonic::Request; const BASE_NOW_S: u64 = 10; const BASE_WORKER_TIMEOUT_S: u64 = 100; +#[derive(Debug)] +enum WorkerStateManagerCalls { + UpdateOperation((OperationId, WorkerId, Result)), +} + +#[derive(Debug)] +enum WorkerStateManagerReturns { + UpdateOperation(Result<(), Error>), +} + +struct MockWorkerStateManager { + rx_call: Arc>>, + tx_call: mpsc::UnboundedSender, + rx_resp: Arc>>, + tx_resp: mpsc::UnboundedSender, +} + +impl MockWorkerStateManager { + pub fn new() -> Self { + let (tx_call, rx_call) = mpsc::unbounded_channel(); + let (tx_resp, rx_resp) = mpsc::unbounded_channel(); + Self { + rx_call: Arc::new(AsyncMutex::new(rx_call)), + tx_call, + rx_resp: Arc::new(AsyncMutex::new(rx_resp)), + tx_resp, + } + } + + pub async fn expect_update_operation( + &self, + result: Result<(), Error>, + ) -> (OperationId, WorkerId, Result) { + let mut rx_call_lock = self.rx_call.lock().await; + let recv = rx_call_lock.recv(); + let WorkerStateManagerCalls::UpdateOperation(req) = + recv.await.expect("Could not receive msg in mpsc"); + self.tx_resp + .send(WorkerStateManagerReturns::UpdateOperation(result)) + .expect("Could not send request to mpsc"); + req + } +} + +#[async_trait] +impl WorkerStateManager for MockWorkerStateManager { + async fn update_operation( + &self, + operation_id: &OperationId, + worker_id: &WorkerId, + action_stage: Result, + ) -> Result<(), Error> { + self.tx_call + .send(WorkerStateManagerCalls::UpdateOperation(( + operation_id.clone(), + *worker_id, + action_stage, + ))) + .expect("Could not send request to mpsc"); + let mut rx_resp_lock = self.rx_resp.lock().await; + match rx_resp_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + { + WorkerStateManagerReturns::UpdateOperation(result) => result, + } + } +} + struct TestContext { - scheduler: Arc, + scheduler: Arc, + state_manager: Arc, worker_api_server: WorkerApiServer, connection_worker_stream: ConnectWorkerStream, worker_id: WorkerId, @@ -57,12 +136,16 @@ fn static_now_fn() -> Result { async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result { const SCHEDULER_NAME: &str = "DUMMY_SCHEDULE_NAME"; - let scheduler = Arc::new(SimpleScheduler::new( - &nativelink_config::schedulers::SimpleScheduler { - worker_timeout_s: worker_timeout, - ..Default::default() - }, - )); + let platform_property_manager = Arc::new(PlatformPropertyManager::new(HashMap::new())); + let tasks_or_worker_change_notify = Arc::new(Notify::new()); + let state_manager = Arc::new(MockWorkerStateManager::new()); + let scheduler = ApiWorkerScheduler::new( + state_manager.clone(), + platform_property_manager, + WorkerAllocationStrategy::default(), + tasks_or_worker_change_notify, + worker_timeout, + ); let mut schedulers: HashMap> = HashMap::new(); schedulers.insert(SCHEDULER_NAME.to_string(), scheduler.clone()); @@ -107,6 +190,7 @@ async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result Result<(), Box SystemTime { #[nativelink_test] pub async fn execution_response_success_test() -> Result<(), Box> { - let test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; + let mut test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; - const SALT: u64 = 5; let action_digest = DigestInfo::new([7u8; 32], 123); let instance_name = "instance_name".to_string(); - let action_info = ActionInfo { + let unique_qualifier = ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: instance_name.clone(), + digest_function: DigestHasherFunc::Sha256, + digest: action_digest, + }); + let action_info = Arc::new(ActionInfo { command_digest: DigestInfo::new([0u8; 32], 0), input_root_digest: DigestInfo::new([0u8; 32], 0), timeout: Duration::MAX, @@ -300,15 +389,18 @@ pub async fn execution_response_success_test() -> Result<(), Box Result<(), Box { - drop(action_state); - // Note: `.changed()` might be triggered twice, since the first trigger - // might be Queued and the second will always be Executing, but there's no - // guarantee that the first trigger will be Queued. - client_action_state_receiver.changed().await?; - client_action_state_receiver.borrow() - } - _ => client_action_state_receiver.borrow(), - }; - assert_eq!(action_state.stage, ActionStage::Executing); - } - // Now send the result of our execution to the scheduler. - test_context - .worker_api_server - .execution_response(Request::new(result.clone())) - .await?; + let update_for_worker = test_context + .connection_worker_stream + .next() + .await + .expect("Worker stream ended early")? + .update + .expect("Expected update field to be populated"); + let update_for_worker::Update::StartAction(start_execute) = update_for_worker else { + panic!("Expected StartAction message"); + }; + assert_eq!(result.operation_id, start_execute.operation_id); { - // Check the result that the client would have received. - client_action_state_receiver.changed().await?; - let client_given_state = client_action_state_receiver.borrow(); - let execute_response = - if let execute_result::Result::ExecuteResponse(v) = result.result.unwrap() { - v - } else { - panic!("Expected type to be ExecuteResponse"); - }; - - assert_eq!( - client_given_state.stage, - execute_response.clone().try_into()? + // Ensure our state manager got the same result as the server. + let (execution_response_result, (operation_id, worker_id, client_given_state)) = join!( + test_context + .worker_api_server + .execution_response(Request::new(result.clone())), + test_context.state_manager.expect_update_operation(Ok(())), ); + execution_response_result.unwrap(); - // We just checked if conversion from ExecuteResponse into ActionStage was an exact mach. - // Now check if we cast the ActionStage into an ExecuteResponse we get the exact same struct. - assert_eq!(execute_response, client_given_state.stage.clone().into()); + assert_eq!(operation_id, expected_operation_id); + assert_eq!(worker_id, test_context.worker_id); + assert_eq!(client_given_state, Ok(execute_response.clone().try_into()?)); + assert_eq!(execute_response, client_given_state.unwrap().into()); } Ok(()) } diff --git a/nativelink-store/tests/cas_utils_test.rs b/nativelink-store/tests/cas_utils_test.rs index 8192e10a9..f352de961 100644 --- a/nativelink-store/tests/cas_utils_test.rs +++ b/nativelink-store/tests/cas_utils_test.rs @@ -23,7 +23,7 @@ fn sha256_is_zero_digest() { packed_hash: Sha256::new().finalize().into(), size_bytes: 0, }; - assert!(is_zero_digest(&digest)); + assert!(is_zero_digest(digest)); } #[test] @@ -34,7 +34,7 @@ fn sha256_is_non_zero_digest() { packed_hash: hasher.finalize().into(), size_bytes: 1, }; - assert!(!is_zero_digest(&digest)); + assert!(!is_zero_digest(digest)); } #[test] @@ -43,7 +43,7 @@ fn blake_is_zero_digest() { packed_hash: Blake3::new().finalize().into(), size_bytes: 0, }; - assert!(is_zero_digest(&digest)); + assert!(is_zero_digest(digest)); } #[test] @@ -54,5 +54,5 @@ fn blake_is_non_zero_digest() { packed_hash: hasher.finalize().into(), size_bytes: 1, }; - assert!(!is_zero_digest(&digest)); + assert!(!is_zero_digest(digest)); } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index aae3c1cb5..7175719ca 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -11,6 +11,7 @@ rust_library( srcs = [ "src/action_messages.rs", "src/buf_channel.rs", + "src/chunked_stream.rs", "src/common.rs", "src/connection_manager.rs", "src/default_store_key_subscribe.rs", @@ -21,6 +22,7 @@ rust_library( "src/health_utils.rs", "src/lib.rs", "src/metrics_utils.rs", + "src/operation_state_manager.rs", "src/origin_context.rs", "src/platform_properties.rs", "src/proto_stream_utils.rs", @@ -40,6 +42,7 @@ rust_library( "//nativelink-error", "//nativelink-proto", "@crates//:async-lock", + "@crates//:bitflags", "@crates//:blake3", "@crates//:bytes", "@crates//:console-subscriber", @@ -49,6 +52,7 @@ rust_library( "@crates//:hyper-util", "@crates//:lru", "@crates//:parking_lot", + "@crates//:pin-project", "@crates//:pin-project-lite", "@crates//:prometheus-client", "@crates//:prost", diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index c15a014ef..a11c6a5e6 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -10,8 +10,10 @@ nativelink-proto = { path = "../nativelink-proto" } async-lock = "3.3.0" async-trait = "0.1.80" +bitflags = "2.5.0" blake3 = { version = "1.5.1", features = ["mmap"] } bytes = "1.6.0" +pin-project = "1.1.5" console-subscriber = { version = "0.3.0" } futures = "0.3.30" hex = "0.4.3" diff --git a/nativelink-util/src/action_messages.rs b/nativelink-util/src/action_messages.rs index 60cfe121c..6c5a5eca5 100644 --- a/nativelink-util/src/action_messages.rs +++ b/nativelink-util/src/action_messages.rs @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use std::sync::Arc; use std::time::{Duration, SystemTime}; -use blake3::Hasher as Blake3Hasher; use nativelink_error::{error_if, make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::{ execution_stage, Action, ActionResult as ProtoActionResult, ExecuteOperationMetadata, @@ -43,39 +40,74 @@ use crate::platform_properties::PlatformProperties; /// Default priority remote execution jobs will get when not provided. pub const DEFAULT_EXECUTION_PRIORITY: i32 = 0; -pub type WorkerTimestamp = u64; +/// Exit code sent if there is an internal error. +pub const INTERNAL_ERROR_EXIT_CODE: i32 = -178; + +/// Holds an id that is unique to the client for a requested operation. +/// Each client should be issued a unique id even if they are attached +/// to the same underlying operation. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ClientOperationId(String); -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +impl ClientOperationId { + pub fn new(unique_qualifier: ActionUniqueQualifier) -> Self { + Self(OperationId::new(unique_qualifier).to_string()) + } + + pub fn from_raw_string(name: String) -> Self { + Self(name) + } + + pub fn into_string(self) -> String { + self.0 + } +} + +impl std::fmt::Display for ClientOperationId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.0.clone())) + } +} + +#[derive(Clone, Serialize, Deserialize)] pub struct OperationId { - pub unique_qualifier: ActionInfoHashKey, + pub unique_qualifier: ActionUniqueQualifier, pub id: Uuid, } -// TODO: Eventually we should make this it's own hash rather than delegate to ActionInfoHashKey. +impl PartialEq for OperationId { + fn eq(&self, other: &Self) -> bool { + self.id.eq(&other.id) + } +} + +impl Eq for OperationId {} + +impl PartialOrd for OperationId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OperationId { + fn cmp(&self, other: &Self) -> Ordering { + self.id.cmp(&other.id) + } +} + impl Hash for OperationId { fn hash(&self, state: &mut H) { - ActionInfoHashKey::hash(&self.unique_qualifier, state) + self.id.hash(state) } } impl OperationId { - pub fn new(unique_qualifier: ActionInfoHashKey) -> Self { + pub fn new(unique_qualifier: ActionUniqueQualifier) -> Self { Self { - id: uuid::Uuid::new_v4(), + id: Uuid::new_v4(), unique_qualifier, } } - - /// Utility function used to make a unique hash of the digest including the salt. - pub fn get_hash(&self) -> [u8; 32] { - self.unique_qualifier.get_hash() - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub fn action_name(&self) -> String { - self.unique_qualifier.action_name() - } } impl TryFrom<&str> for OperationId { @@ -84,7 +116,7 @@ impl TryFrom<&str> for OperationId { /// Attempts to convert a string slice into an `OperationId`. /// /// The input string `value` is expected to be in the format: - /// `//-//`. + /// `//-//`. /// /// # Parameters /// @@ -105,7 +137,7 @@ impl TryFrom<&str> for OperationId { /// /// ```no_run /// use nativelink_util::action_messages::OperationId; - /// let operation_id_str = "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"; + /// let operation_id_str = "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"; /// let operation_id = OperationId::try_from(operation_id_str); /// ``` /// @@ -119,30 +151,41 @@ impl TryFrom<&str> for OperationId { .err_tip(|| format!("Invalid OperationId unique_qualifier / id fragment - {value}"))?; let (instance_name, rest) = unique_qualifier .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey instance name fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier instance name fragment - {value}"))?; let (digest_function, rest) = rest .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey digest function fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier digest function fragment - {value}"))?; let (digest_hash, rest) = rest .split_once('-') - .err_tip(|| format!("Invalid ActionInfoHashKey digest hash fragment - {value}"))?; - let (digest_size, salt) = rest + .err_tip(|| format!("Invalid UniqueQualifier digest hash fragment - {value}"))?; + let (digest_size, cachable) = rest .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey digest size fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier digest size fragment - {value}"))?; let digest = DigestInfo::try_new( digest_hash, digest_size .parse::() - .err_tip(|| format!("Invalid ActionInfoHashKey size value fragment - {value}"))?, + .err_tip(|| format!("Invalid UniqueQualifier size value fragment - {value}"))?, ) .err_tip(|| format!("Invalid DigestInfo digest hash - {value}"))?; - let salt = u64::from_str_radix(salt, 16) - .err_tip(|| format!("Invalid ActionInfoHashKey salt hex conversion - {value}"))?; - let unique_qualifier = ActionInfoHashKey { + let cachable = match cachable { + "u" => false, + "c" => true, + _ => { + return Err(make_input_err!( + "Invalid UniqueQualifier cachable value fragment - {value}" + )); + } + }; + let unique_key = ActionUniqueKey { instance_name: instance_name.to_string(), digest_function: digest_function.try_into()?, digest, - salt, + }; + let unique_qualifier = if cachable { + ActionUniqueQualifier::Cachable(unique_key) + } else { + ActionUniqueQualifier::Uncachable(unique_key) }; let id = Uuid::parse_str(id).map_err(|e| make_input_err!("Failed to parse {e} as uuid"))?; Ok(Self { @@ -154,26 +197,18 @@ impl TryFrom<&str> for OperationId { impl std::fmt::Display for OperationId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "{}/{}", - self.unique_qualifier.action_name(), - self.id - )) + f.write_fmt(format_args!("{}/{}", self.unique_qualifier, self.id)) } } impl std::fmt::Debug for OperationId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "{}:{}", - self.unique_qualifier.action_name(), - self.id - )) + std::fmt::Display::fmt(&self, f) } } /// Unique id of worker. -#[derive(Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)] +#[derive(Default, Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)] pub struct WorkerId(pub Uuid); impl std::fmt::Display for WorkerId { @@ -186,9 +221,7 @@ impl std::fmt::Display for WorkerId { impl std::fmt::Debug for WorkerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut buf = Uuid::encode_buffer(); - let worker_id_str = self.0.hyphenated().encode_lower(&mut buf); - f.write_fmt(format_args!("{worker_id_str}")) + std::fmt::Display::fmt(&self, f) } } @@ -197,58 +230,76 @@ impl TryFrom for WorkerId { fn try_from(s: String) -> Result { match Uuid::parse_str(&s) { Err(e) => Err(make_input_err!( - "Failed to convert string to WorkerId : {} : {:?}", - s, - e + "Failed to convert string to WorkerId : {s} : {e:?}", )), Ok(my_uuid) => Ok(WorkerId(my_uuid)), } } } + +/// Holds the information needed to uniquely identify an action +/// and if it is cachable or not. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub enum ActionUniqueQualifier { + /// The action is cachable. + Cachable(ActionUniqueKey), + /// The action is uncachable. + Uncachable(ActionUniqueKey), +} + +impl ActionUniqueQualifier { + /// Get the instance_name of the action. + pub const fn instance_name(&self) -> &String { + match self { + Self::Cachable(action) => &action.instance_name, + Self::Uncachable(action) => &action.instance_name, + } + } + + /// Get the digest function of the action. + pub const fn digest_function(&self) -> DigestHasherFunc { + match self { + Self::Cachable(action) => action.digest_function, + Self::Uncachable(action) => action.digest_function, + } + } + + /// Get the digest of the action. + pub const fn digest(&self) -> DigestInfo { + match self { + Self::Cachable(action) => action.digest, + Self::Uncachable(action) => action.digest, + } + } +} + +impl std::fmt::Display for ActionUniqueQualifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (cachable, unique_key) = match self { + Self::Cachable(action) => (true, action), + Self::Uncachable(action) => (false, action), + }; + f.write_fmt(format_args!( + "{}/{}/{}-{}/{}", + unique_key.instance_name, + unique_key.digest_function, + unique_key.digest.hash_str(), + unique_key.digest.size_bytes, + if cachable { 'c' } else { 'u' }, + )) + } +} + /// This is a utility struct used to make it easier to match `ActionInfos` in a /// `HashMap` without needing to construct an entire `ActionInfo`. -/// Since the hashing only needs the digest and salt we can just alias them here -/// and point the original `ActionInfo` structs to reference these structs for -/// it's hashing functions. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ActionInfoHashKey { +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct ActionUniqueKey { /// Name of instance group this action belongs to. pub instance_name: String, /// The digest function this action expects. pub digest_function: DigestHasherFunc, /// Digest of the underlying `Action`. pub digest: DigestInfo, - /// Salt that can be filled with a random number to ensure no `ActionInfo` will be a match - /// to another `ActionInfo` in the scheduler. When caching is wanted this value is usually - /// zero. - pub salt: u64, -} - -impl ActionInfoHashKey { - /// Utility function used to make a unique hash of the digest including the salt. - pub fn get_hash(&self) -> [u8; 32] { - Blake3Hasher::new() - .update(self.instance_name.as_bytes()) - .update(&i32::from(self.digest_function.proto_digest_func()).to_le_bytes()) - .update(&self.digest.packed_hash[..]) - .update(&self.digest.size_bytes.to_le_bytes()) - .update(&self.salt.to_le_bytes()) - .finalize() - .into() - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub fn action_name(&self) -> String { - format!( - "{}/{}/{}-{}/{:X}", - self.instance_name, - self.digest_function, - self.digest.hash_str(), - self.digest.size_bytes, - self.salt - ) - } } /// Information needed to execute an action. This struct is used over bazel's proto `Action` @@ -272,47 +323,43 @@ pub struct ActionInfo { pub load_timestamp: SystemTime, /// When this action was created. pub insert_timestamp: SystemTime, - - /// Info used to uniquely identify this ActionInfo. Normally the hash function would just - /// use the fields it needs and you wouldn't need to separate them, however we have a use - /// case where we sometimes want to lookup an entry in a HashMap, but we don't have the - /// info to construct an entire ActionInfo. In such case we construct only a ActionInfoHashKey - /// then use that object to lookup the entry in the map. The root problem is that HashMap - /// requires `ActionInfo :Borrow` in order for this to work, which means - /// we need to be able to return a &ActionInfoHashKey from ActionInfo, but since we cannot - /// return a temporary reference we must have an object tied to ActionInfo's lifetime and - /// return it's reference. - pub unique_qualifier: ActionInfoHashKey, - - /// Whether to try looking up this action in the cache. - pub skip_cache_lookup: bool, + /// Info used to uniquely identify this ActionInfo and if it is cachable. + /// This is primarily used to join actions/operations together using this key. + pub unique_qualifier: ActionUniqueQualifier, } impl ActionInfo { #[inline] pub const fn instance_name(&self) -> &String { - &self.unique_qualifier.instance_name + self.unique_qualifier.instance_name() } /// Returns the underlying digest of the `Action`. #[inline] - pub const fn digest(&self) -> &DigestInfo { - &self.unique_qualifier.digest - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub const fn salt(&self) -> &u64 { - &self.unique_qualifier.salt + pub const fn digest(&self) -> DigestInfo { + self.unique_qualifier.digest() } - pub fn try_from_action_and_execute_request_with_salt( + pub fn try_from_action_and_execute_request( execute_request: ExecuteRequest, action: Action, - salt: u64, load_timestamp: SystemTime, queued_timestamp: SystemTime, ) -> Result { + let unique_key = ActionUniqueKey { + instance_name: execute_request.instance_name, + digest_function: DigestHasherFunc::try_from(execute_request.digest_function) + .err_tip(|| format!("Could not find digest_function in try_from_action_and_execute_request {:?}", execute_request.digest_function))?, + digest: execute_request + .action_digest + .err_tip(|| "Expected action_digest to exist on ExecuteRequest")? + .try_into()?, + }; + let unique_qualifier = if execute_request.skip_cache_lookup { + ActionUniqueQualifier::Uncachable(unique_key) + } else { + ActionUniqueQualifier::Cachable(unique_key) + }; Ok(Self { command_digest: action .command_digest @@ -328,20 +375,13 @@ impl ActionInfo { .try_into() .map_err(|_| make_input_err!("Failed convert proto duration to system duration"))?, platform_properties: action.platform.unwrap_or_default().into(), - priority: execute_request.execution_policy.unwrap_or_default().priority, + priority: execute_request + .execution_policy + .unwrap_or_default() + .priority, load_timestamp, insert_timestamp: queued_timestamp, - unique_qualifier: ActionInfoHashKey { - instance_name: execute_request.instance_name, - digest_function: DigestHasherFunc::try_from(execute_request.digest_function) - .err_tip(|| format!("Could not find digest_function in try_from_action_and_execute_request_with_salt {:?}", execute_request.digest_function))?, - digest: execute_request - .action_digest - .err_tip(|| "Expected action_digest to exist on ExecuteRequest")? - .try_into()?, - salt, - }, - skip_cache_lookup: execute_request.skip_cache_lookup, + unique_qualifier, }) } } @@ -349,92 +389,21 @@ impl ActionInfo { impl From for ExecuteRequest { fn from(val: ActionInfo) -> Self { let digest = val.digest().into(); + let (skip_cache_lookup, unique_qualifier) = match val.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_qualifier) => (false, unique_qualifier), + ActionUniqueQualifier::Uncachable(unique_qualifier) => (true, unique_qualifier), + }; Self { - instance_name: val.unique_qualifier.instance_name, + instance_name: unique_qualifier.instance_name, action_digest: Some(digest), - skip_cache_lookup: true, // The worker should never cache lookup. - execution_policy: None, // Not used in the worker. + skip_cache_lookup, + execution_policy: None, // Not used in the worker. results_cache_policy: None, // Not used in the worker. - digest_function: val - .unique_qualifier - .digest_function - .proto_digest_func() - .into(), + digest_function: unique_qualifier.digest_function.proto_digest_func().into(), } } } -// Note: Hashing, Eq, and Ord matching on this struct is unique. Normally these functions -// must play well with each other, but in our case the following rules apply: -// * Hash - Hashing must be unique on the exact command being run and must never match -// when do_not_cache is enabled, but must be consistent between identical data -// hashes. -// * Eq - Same as hash. -// * Ord - Used when sorting `ActionInfo` together. The only major sorting is priority and -// insert_timestamp, everything else is undefined, but must be deterministic. -impl Hash for ActionInfo { - fn hash(&self, state: &mut H) { - ActionInfoHashKey::hash(&self.unique_qualifier, state); - } -} - -impl PartialEq for ActionInfo { - fn eq(&self, other: &Self) -> bool { - ActionInfoHashKey::eq(&self.unique_qualifier, &other.unique_qualifier) - } -} - -impl Eq for ActionInfo {} - -impl Ord for ActionInfo { - fn cmp(&self, other: &Self) -> Ordering { - // Want the highest priority on top, but the lowest insert_timestamp. - self.priority - .cmp(&other.priority) - .then_with(|| other.insert_timestamp.cmp(&self.insert_timestamp)) - .then_with(|| self.salt().cmp(other.salt())) - .then_with(|| self.digest().size_bytes.cmp(&other.digest().size_bytes)) - .then_with(|| self.digest().packed_hash.cmp(&other.digest().packed_hash)) - .then_with(|| { - self.unique_qualifier - .digest_function - .cmp(&other.unique_qualifier.digest_function) - }) - } -} - -impl PartialOrd for ActionInfo { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Borrow for Arc { - #[inline] - fn borrow(&self) -> &ActionInfoHashKey { - &self.unique_qualifier - } -} - -impl Hash for ActionInfoHashKey { - fn hash(&self, state: &mut H) { - // Digest is unique, so hashing it is all we need. - self.digest_function.hash(state); - self.digest.hash(state); - self.salt.hash(state); - } -} - -impl PartialEq for ActionInfoHashKey { - fn eq(&self, other: &Self) -> bool { - self.digest == other.digest - && self.salt == other.salt - && self.digest_function == other.digest_function - } -} - -impl Eq for ActionInfoHashKey {} - /// Simple utility struct to determine if a string is representing a full path or /// just the name of the file. /// This is in order to be able to reuse the same struct instead of building different @@ -728,9 +697,6 @@ impl TryFrom for ExecutionMetadata { } } -/// Exit code sent if there is an internal error. -pub const INTERNAL_ERROR_EXIT_CODE: i32 = -178; - /// Represents the results of an execution. /// This struct must be 100% compatible with `ActionResult` in `remote_execution.proto`. #[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)] @@ -813,6 +779,20 @@ impl ActionStage { pub const fn is_finished(&self) -> bool { self.has_action_result() } + + /// Returns if the stage enum is the same as the other stage enum, but + /// does not compare the values of the enum. + pub const fn is_same_stage(&self, other: &Self) -> bool { + matches!( + (self, other), + (Self::Unknown, Self::Unknown) + | (Self::CacheCheck, Self::CacheCheck) + | (Self::Queued, Self::Queued) + | (Self::Executing, Self::Executing) + | (Self::Completed(_), Self::Completed(_)) + | (Self::CompletedFromCache(_), Self::CompletedFromCache(_)) + ) + } } impl MetricsComponent for ActionStage { @@ -1093,10 +1073,19 @@ where } } -impl TryFrom for ActionState { - type Error = Error; +/// Current state of the action. +/// This must be 100% compatible with `Operation` in `google/longrunning/operations.proto`. +#[derive(PartialEq, Debug, Clone)] +pub struct ActionState { + pub stage: ActionStage, + pub id: OperationId, +} - fn try_from(operation: Operation) -> Result { +impl ActionState { + pub fn try_from_operation( + operation: Operation, + operation_id: OperationId, + ) -> Result { let metadata = from_any::( &operation .metadata @@ -1135,51 +1124,23 @@ impl TryFrom for ActionState { } }; - // NOTE: This will error if we are forwarding an operation from - // one remote execution system to another that does not use our operation name - // format (ie: very unlikely, but possible). - let id = OperationId::try_from(operation.name.as_str())?; - Ok(Self { id, stage }) - } -} - -/// Current state of the action. -/// This must be 100% compatible with `Operation` in `google/longrunning/operations.proto`. -#[derive(PartialEq, Debug, Clone)] -pub struct ActionState { - pub stage: ActionStage, - pub id: OperationId, -} - -impl ActionState { - #[inline] - pub fn unique_qualifier(&self) -> &ActionInfoHashKey { - &self.id.unique_qualifier - } - #[inline] - pub fn action_digest(&self) -> &DigestInfo { - &self.id.unique_qualifier.digest - } -} - -impl MetricsComponent for ActionState { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish("stage", &self.stage, ""); + Ok(Self { + id: operation_id, + stage, + }) } -} -impl From for Operation { - fn from(val: ActionState) -> Self { - let stage = Into::::into(&val.stage) as i32; - let name = val.id.to_string(); + pub fn as_operation(&self, client_operation_id: ClientOperationId) -> Operation { + let stage = Into::::into(&self.stage) as i32; + let name = client_operation_id.into_string(); - let result = if val.stage.has_action_result() { - let execute_response: ExecuteResponse = val.stage.into(); + let result = if self.stage.has_action_result() { + let execute_response: ExecuteResponse = self.stage.clone().into(); Some(LongRunningResult::Response(to_any(&execute_response))) } else { None }; - let digest = Some(val.id.unique_qualifier.digest.into()); + let digest = Some(self.id.unique_qualifier.digest().into()); let metadata = ExecuteOperationMetadata { stage, @@ -1190,7 +1151,7 @@ impl From for Operation { partial_execution_metadata: None, }; - Self { + Operation { name, metadata: Some(to_any(&metadata)), done: result.is_some(), @@ -1198,3 +1159,9 @@ impl From for Operation { } } } + +impl MetricsComponent for ActionState { + fn gather_metrics(&self, c: &mut CollectorState) { + c.publish("stage", &self.stage, ""); + } +} diff --git a/nativelink-util/src/chunked_stream.rs b/nativelink-util/src/chunked_stream.rs new file mode 100644 index 000000000..e3665562d --- /dev/null +++ b/nativelink-util/src/chunked_stream.rs @@ -0,0 +1,110 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::ops::Bound; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Future, Stream}; +use pin_project::pin_project; + +#[pin_project(project = StreamStateProj)] +enum StreamState { + Future(#[pin] Fut), + Next, +} + +/// Takes a range of keys and a function that returns a future that yields +/// an iterator of key-value pairs. The stream will yield all key-value pairs +/// in the range, in order and buffered. A great use case is where you need +/// to implement Stream, but to access the underlying data requires a lock, +/// but API does not require the data to be in sync with data already received. +#[pin_project] +pub struct ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + chunk_fn: F, + buffer: VecDeque, + start_key: Option>, + end_key: Option>, + #[pin] + stream_state: StreamState, +} + +impl ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + pub fn new(start_key: Bound, end_key: Bound, chunk_fn: F) -> Self { + Self { + chunk_fn, + buffer: VecDeque::new(), + start_key: Some(start_key), + end_key: Some(end_key), + stream_state: StreamState::Next, + } + } +} + +impl Stream for ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + if let Some(item) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(item))); + } + match this.stream_state.as_mut().project() { + StreamStateProj::Future(fut) => { + match futures::ready!(fut.poll(cx)) { + Ok(Some(((start, end), mut buffer))) => { + *this.start_key = Some(start); + *this.end_key = Some(end); + std::mem::swap(&mut buffer, this.buffer); + } + Ok(None) => return Poll::Ready(None), // End of stream. + Err(err) => return Poll::Ready(Some(Err(err))), + } + this.stream_state.set(StreamState::Next); + // Loop again. + } + StreamStateProj::Next => { + this.buffer.clear(); + // This trick is used to recycle capacity. + let buffer = std::mem::take(this.buffer); + let start_key = this + .start_key + .take() + .expect("start_key should never be None"); + let end_key = this.end_key.take().expect("end_key should never be None"); + let fut = (this.chunk_fn)(start_key, end_key, buffer); + this.stream_state.set(StreamState::Future(fut)); + // Loop again. + } + } + } + } +} diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 7fc457290..04d7571b8 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -14,6 +14,7 @@ pub mod action_messages; pub mod buf_channel; +pub mod chunked_stream; pub mod common; pub mod connection_manager; pub mod default_store_key_subscribe; @@ -23,6 +24,7 @@ pub mod fastcdc; pub mod fs; pub mod health_utils; pub mod metrics_utils; +pub mod operation_state_manager; pub mod origin_context; pub mod platform_properties; pub mod proto_stream_utils; diff --git a/nativelink-scheduler/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs similarity index 56% rename from nativelink-scheduler/src/operation_state_manager.rs rename to nativelink-util/src/operation_state_manager.rs index 2b7184d3f..cb1b331e3 100644 --- a/nativelink-scheduler/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -20,14 +20,15 @@ use async_trait::async_trait; use bitflags::bitflags; use futures::Stream; use nativelink_error::Error; -use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionStage, ActionState, OperationId, WorkerId, +use prometheus_client::registry::Registry; + +use crate::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, }; -use nativelink_util::common::DigestInfo; -use tokio::sync::watch; +use crate::common::DigestInfo; bitflags! { - #[derive(Debug, Clone, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct OperationStageFlags: u32 { const CacheCheck = 1 << 1; const Queued = 1 << 2; @@ -37,24 +38,39 @@ bitflags! { } } +impl Default for OperationStageFlags { + fn default() -> Self { + Self::Any + } +} + #[async_trait] pub trait ActionStateResult: Send + Sync + 'static { // Provides the current state of the action. async fn as_state(&self) -> Result, Error>; - // Subscribes to the state of the action, receiving updates as they are published. - async fn as_receiver(&self) -> Result<&'_ watch::Receiver>, Error>; + // Waits for the state of the action to change. + async fn changed(&mut self) -> Result, Error>; // Provide result as action info. This behavior will not be supported by all implementations. - // TODO(adams): Expectation is this to experimental and removed in the future. async fn as_action_info(&self) -> Result, Error>; } +/// The direction in which the results are ordered. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OrderDirection { + Asc, + Desc, +} + /// The filters used to query operations from the state manager. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] pub struct OperationFilter { // TODO(adams): create rust builder pattern? /// The stage(s) that the operation must be in. pub stages: OperationStageFlags, + /// The client operation id. + pub client_operation_id: Option, + /// The operation id. pub operation_id: Option, @@ -70,80 +86,67 @@ pub struct OperationFilter { /// The operation must have been completed before this time. pub completed_before: Option, - /// The operation must have it's last client update before this time. - pub last_client_update_before: Option, - /// The unique key for filtering specific action results. - pub unique_qualifier: Option, - - /// The order by in which results are returned by the filter operation. - pub order_by: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum OperationFields { - Priority, - Timestamp, -} + pub unique_key: Option, -/// The order in which results are returned by the filter operation. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct OrderBy { - /// The fields to order by, each field is ordered in the order they are provided. - pub fields: Vec, - /// The order of the fields, true for descending, false for ascending. - pub desc: bool, + /// If the results should be ordered by priority and in which direction. + pub order_by_priority_direction: Option, } -pub type ActionStateResultStream = Pin> + Send>>; +pub type ActionStateResultStream<'a> = + Pin> + Send + 'a>>; #[async_trait] -pub trait ClientStateManager { +pub trait ClientStateManager: Sync + Send { /// Add a new action to the queue or joins an existing action. async fn add_action( - &mut self, - action_info: ActionInfo, - ) -> Result, Error>; + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result, Error>; /// Returns a stream of operations that match the filter. - async fn filter_operations( - &self, + async fn filter_operations<'a>( + &'a self, filter: OperationFilter, - ) -> Result; + ) -> Result, Error>; + + /// Register metrics with the registry. + fn register_metrics(self: Arc, _registry: &mut Registry) {} } #[async_trait] -pub trait WorkerStateManager { +pub trait WorkerStateManager: Sync + Send { /// Update that state of an operation. /// The worker must also send periodic updates even if the state /// did not change with a modified timestamp in order to prevent /// the operation from being considered stale and being rescheduled. async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: WorkerId, + &self, + operation_id: &OperationId, + worker_id: &WorkerId, action_stage: Result, ) -> Result<(), Error>; + + /// Register metrics with the registry. + fn register_metrics(self: Arc, _registry: &mut Registry) {} } #[async_trait] -pub trait MatchingEngineStateManager { +pub trait MatchingEngineStateManager: Sync + Send { /// Returns a stream of operations that match the filter. - async fn filter_operations( - &self, + async fn filter_operations<'a>( + &'a self, filter: OperationFilter, - ) -> Result; + ) -> Result, Error>; - /// Update that state of an operation. - async fn update_operation( - &mut self, - operation_id: OperationId, - worker_id: Option, - action_stage: Result, + /// Assign an operation to a worker or unassign it. + async fn assign_operation( + &self, + operation_id: &OperationId, + worker_id_or_reason_for_unsassign: Result<&WorkerId, Error>, ) -> Result<(), Error>; - /// Remove an operation from the state manager. - /// It is important to use this function to remove operations - /// that are no longer needed to prevent memory leaks. - async fn remove_operation(&self, operation_id: OperationId) -> Result<(), Error>; + /// Register metrics with the registry. + fn register_metrics(self: Arc, _registry: &mut Registry) {} } diff --git a/nativelink-util/tests/operation_id_tests.rs b/nativelink-util/tests/operation_id_tests.rs index e1e8b5e30..a2513c8f7 100644 --- a/nativelink-util/tests/operation_id_tests.rs +++ b/nativelink-util/tests/operation_id_tests.rs @@ -19,22 +19,28 @@ use pretty_assertions::assert_eq; #[nativelink_test] async fn parse_operation_id() -> Result<(), Error> { - let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap(); - assert_eq!( - operation_id.to_string(), - "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); - assert_eq!( - operation_id.action_name(), - "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0" - ); - assert_eq!( - operation_id.id.to_string(), - "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52" - ); - assert_eq!( - hex::encode(operation_id.get_hash()), - "5a36f0db39e27667c4b91937cd29c1df8799ba468f2de6810c6865be05517644" - ); + { + // Check no cached. + let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap(); + assert_eq!( + operation_id.to_string(), + "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!( + operation_id.id.to_string(), + "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52" + ); + } + { + // Check cached. + let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/c/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap(); + assert_eq!( + operation_id.to_string(), + "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/c/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!( + operation_id.id.to_string(), + "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52" + ); + } Ok(()) } @@ -53,7 +59,7 @@ async fn parse_empty_failure() -> Result<(), Error> { assert_eq!(operation_id.messages.len(), 1); assert_eq!( operation_id.messages[0], - "Invalid ActionInfoHashKey instance name fragment - /" + "Invalid UniqueQualifier instance name fragment - /" ); let operation_id = OperationId::try_from("main").err().unwrap(); @@ -64,7 +70,7 @@ async fn parse_empty_failure() -> Result<(), Error> { "Invalid OperationId unique_qualifier / id fragment - main" ); - let operation_id = OperationId::try_from("main/nohashfn/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); + let operation_id = OperationId::try_from("main/nohashfn/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); assert_eq!(operation_id.code, Code::InvalidArgument); assert_eq!(operation_id.messages.len(), 1); assert_eq!( @@ -73,7 +79,7 @@ async fn parse_empty_failure() -> Result<(), Error> { ); let operation_id = - OperationId::try_from("main/SHA256/badhash-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52") + OperationId::try_from("main/SHA256/badhash-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52") .err() .unwrap(); assert_eq!(operation_id.messages.len(), 3); @@ -82,37 +88,35 @@ async fn parse_empty_failure() -> Result<(), Error> { assert_eq!(operation_id.messages[1], "Invalid sha256 hash: badhash"); assert_eq!( operation_id.messages[2], - "Invalid DigestInfo digest hash - main/SHA256/badhash-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52" + "Invalid DigestInfo digest hash - main/SHA256/badhash-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52" ); - let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); + let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); assert_eq!(operation_id.messages.len(), 2); assert_eq!(operation_id.code, Code::InvalidArgument); assert_eq!( operation_id.messages[0], "cannot parse integer from empty string" ); - assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!(operation_id.messages[1], "Invalid UniqueQualifier size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); - let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); + let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); assert_eq!(operation_id.code, Code::InvalidArgument); assert_eq!(operation_id.messages.len(), 2); assert_eq!(operation_id.messages[0], "invalid digit found in string"); - assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!(operation_id.messages[1], "Invalid UniqueQualifier size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); - assert_eq!(operation_id.messages.len(), 2); + assert_eq!(operation_id.messages.len(), 1); assert_eq!(operation_id.code, Code::InvalidArgument); - assert_eq!(operation_id.messages[0], "invalid digit found in string"); - assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey salt hex conversion - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier cachable value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap(); - assert_eq!(operation_id.messages.len(), 2); + assert_eq!(operation_id.messages.len(), 1); assert_eq!(operation_id.code, Code::InvalidArgument); - assert_eq!(operation_id.messages[0], "invalid digit found in string"); - assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey salt hex conversion - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); + assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier cachable value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"); - let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/baduuid").err().unwrap(); + let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/baduuid").err().unwrap(); assert_eq!(operation_id.messages.len(), 1); assert_eq!(operation_id.code, Code::InvalidArgument); assert_eq!(operation_id.messages[0], "Failed to parse invalid character: expected an optional prefix of `urn:uuid:` followed by [0-9a-fA-F-], found `u` at 4 as uuid"); @@ -124,7 +128,7 @@ async fn parse_empty_failure() -> Result<(), Error> { .unwrap(); assert_eq!(operation_id.messages.len(), 1); assert_eq!(operation_id.code, Code::Internal); - assert_eq!(operation_id.messages[0], "Invalid ActionInfoHashKey digest size fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0"); + assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier digest size fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0"); Ok(()) } diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs index 6ff17e3c0..f9316bfcf 100644 --- a/nativelink-worker/src/local_worker.rs +++ b/nativelink-worker/src/local_worker.rs @@ -211,20 +211,29 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, Update::KeepAlive(()) => { self.metrics.keep_alives_received.inc(); } - Update::KillActionRequest(kill_action_request) => { - let mut action_id = [0u8; 32]; - hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8]) - .map_err(|e| make_input_err!( - "KillActionRequest failed to decode ActionId hex with error {}", - e - ))?; - - if let Err(err) = self.running_actions_manager.kill_action(action_id).await { + Update::KillOperationRequest(kill_operation_request) => { + let operation_id_res = kill_operation_request + .operation_id + .as_str() + .try_into(); + let operation_id = match operation_id_res { + Ok(operation_id) => operation_id, + Err(err) => { + event!( + Level::ERROR, + ?kill_operation_request, + ?err, + "Failed to convert string to operation_id" + ); + continue; + } + }; + if let Err(err) = self.running_actions_manager.kill_operation(&operation_id).await { event!( Level::ERROR, - action_id = hex::encode(action_id), + ?operation_id, ?err, - "Failed to send kill request for action" + "Failed to send kill request for operation" ); }; } @@ -232,7 +241,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, self.metrics.start_actions_received.inc(); let execute_request = start_execute.execute_request.as_ref(); - let salt = start_execute.salt; + let operation_id = start_execute.operation_id.clone(); let maybe_instance_name = execute_request.map(|v| v.instance_name.clone()); let action_digest = execute_request.and_then(|v| v.action_digest.clone()); let digest_hasher = execute_request @@ -257,7 +266,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, .and_then(|action| { event!( Level::INFO, - action_id = hex::encode(action.get_action_id()), + operation_id = ?action.get_operation_id(), "Received request to run action" ); action @@ -303,9 +312,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, ExecuteResult{ worker_id, instance_name, - action_digest, - salt, - digest_function: digest_hasher.proto_digest_func().into(), + operation_id, result: Some(execute_result::Result::ExecuteResponse(action_stage.into())), } ) @@ -316,9 +323,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a, grpc_client.execution_response(ExecuteResult{ worker_id, instance_name, - action_digest, - salt, - digest_function: digest_hasher.proto_digest_func().into(), + operation_id, result: Some(execute_result::Result::InternalError(e.into())), }).await.err_tip(|| "Error calling execution_response with error")?; }, diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs index b6b502540..0d338118c 100644 --- a/nativelink-worker/src/running_actions_manager.rs +++ b/nativelink-worker/src/running_actions_manager.rs @@ -56,7 +56,7 @@ use nativelink_store::filesystem_store::{FileEntry, FilesystemStore}; use nativelink_store::grpc_store::GrpcStore; use nativelink_util::action_messages::{ to_execute_response, ActionInfo, ActionResult, DirectoryInfo, ExecutionMetadata, FileInfo, - NameOrPath, SymlinkInfo, + NameOrPath, OperationId, SymlinkInfo, }; use nativelink_util::common::{fs, DigestInfo}; use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc}; @@ -78,8 +78,6 @@ use tonic::Request; use tracing::{enabled, event, Level}; use uuid::Uuid; -pub type ActionId = [u8; 32]; - /// For simplicity we use a fixed exit code for cases when our program is terminated /// due to a signal. const EXIT_CODE_FOR_SIGNAL: i32 = 9; @@ -531,7 +529,7 @@ async fn process_side_channel_file( async fn do_cleanup( running_actions_manager: &RunningActionsManagerImpl, - action_id: &ActionId, + operation_id: &OperationId, action_directory: &str, ) -> Result<(), Error> { event!(Level::INFO, "Worker cleaning up"); @@ -539,10 +537,10 @@ async fn do_cleanup( let remove_dir_result = fs::remove_dir_all(action_directory) .await .err_tip(|| format!("Could not remove working directory {action_directory}")); - if let Err(err) = running_actions_manager.cleanup_action(action_id) { + if let Err(err) = running_actions_manager.cleanup_action(operation_id) { event!( Level::ERROR, - action_id = hex::encode(action_id), + ?operation_id, ?err, "Error cleaning up action" ); @@ -551,7 +549,7 @@ async fn do_cleanup( if let Err(err) = remove_dir_result { event!( Level::ERROR, - action_id = hex::encode(action_id), + ?operation_id, ?err, "Error removing working directory" ); @@ -562,7 +560,7 @@ async fn do_cleanup( pub trait RunningAction: Sync + Send + Sized + Unpin + 'static { /// Returns the action id of the action. - fn get_action_id(&self) -> &ActionId; + fn get_operation_id(&self) -> &OperationId; /// Anything that needs to execute before the actions is actually executed should happen here. fn prepare_action(self: Arc) -> impl Future, Error>> + Send; @@ -611,7 +609,7 @@ struct RunningActionImplState { } pub struct RunningActionImpl { - action_id: ActionId, + operation_id: OperationId, action_directory: String, work_directory: String, action_info: ActionInfo, @@ -624,7 +622,7 @@ pub struct RunningActionImpl { impl RunningActionImpl { fn new( execution_metadata: ExecutionMetadata, - action_id: ActionId, + operation_id: OperationId, action_directory: String, action_info: ActionInfo, timeout: Duration, @@ -633,7 +631,7 @@ impl RunningActionImpl { let work_directory = format!("{}/{}", action_directory, "work"); let (kill_channel_tx, kill_channel_rx) = oneshot::channel(); Self { - action_id, + operation_id, action_directory, work_directory, action_info, @@ -988,14 +986,14 @@ impl RunningActionImpl { if let Err(err) = child_process_guard.start_kill() { event!( Level::ERROR, - action_id = hex::encode(self.action_id), + operation_id = ?self.operation_id, ?err, "Could not kill process", ); } else { event!( Level::ERROR, - action_id = hex::encode(self.action_id), + operation_id = ?self.operation_id, "Could not get child process id, maybe already dead?", ); } @@ -1034,7 +1032,7 @@ impl RunningActionImpl { ) }; let cas_store = self.running_actions_manager.cas_store.as_ref(); - let hasher = self.action_info.unique_qualifier.digest_function; + let hasher = self.action_info.unique_qualifier.digest_function(); enum OutputType { None, File(FileInfo), @@ -1250,23 +1248,23 @@ impl Drop for RunningActionImpl { if self.did_cleanup.load(Ordering::Acquire) { return; } + let operation_id = self.operation_id.clone(); event!( Level::ERROR, - action_id = hex::encode(self.action_id), + ?operation_id, "RunningActionImpl did not cleanup. This is a violation of the requirements, will attempt to do it in the background." ); let running_actions_manager = self.running_actions_manager.clone(); - let action_id = self.action_id; let action_directory = self.action_directory.clone(); background_spawn!("running_action_impl_drop", async move { let Err(err) = - do_cleanup(&running_actions_manager, &action_id, &action_directory).await + do_cleanup(&running_actions_manager, &operation_id, &action_directory).await else { return; }; event!( Level::ERROR, - action_id = hex::encode(action_id), + ?operation_id, ?action_directory, ?err, "Error cleaning up action" @@ -1276,8 +1274,8 @@ impl Drop for RunningActionImpl { } impl RunningAction for RunningActionImpl { - fn get_action_id(&self) -> &ActionId { - &self.action_id + fn get_operation_id(&self) -> &OperationId { + &self.operation_id } async fn prepare_action(self: Arc) -> Result, Error> { @@ -1311,7 +1309,7 @@ impl RunningAction for RunningActionImpl { .wrap(async move { let result = do_cleanup( &self.running_actions_manager, - &self.action_id, + &self.operation_id, &self.action_directory, ) .await; @@ -1352,7 +1350,10 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static { fn kill_all(&self) -> impl Future + Send; - fn kill_action(&self, action_id: ActionId) -> impl Future> + Send; + fn kill_operation( + &self, + operation_id: &OperationId, + ) -> impl Future> + Send; fn metrics(&self) -> &Arc; } @@ -1643,7 +1644,7 @@ pub struct RunningActionsManagerImpl { upload_action_results: UploadActionResults, max_action_timeout: Duration, timeout_handled_externally: bool, - running_actions: Mutex>>, + running_actions: Mutex>>, // Note: We don't use Notify because we need to support a .wait_for()-like function, which // Notify does not support. action_done_tx: watch::Sender<()>, @@ -1699,11 +1700,10 @@ impl RunningActionsManagerImpl { fn make_action_directory<'a>( &'a self, - action_id: &'a ActionId, + operation_id: &'a OperationId, ) -> impl Future> + 'a { self.metrics.make_action_directory.wrap(async move { - let action_directory = - format!("{}/{}", self.root_action_directory, hex::encode(action_id)); + let action_directory = format!("{}/{}", self.root_action_directory, operation_id.id); fs::create_dir(&action_directory) .await .err_tip(|| format!("Error creating action directory {action_directory}"))?; @@ -1730,10 +1730,9 @@ impl RunningActionsManagerImpl { get_and_decode_digest::(self.cas_store.as_ref(), action_digest.into()) .await .err_tip(|| "During start_action")?; - let action_info = ActionInfo::try_from_action_and_execute_request_with_salt( + let action_info = ActionInfo::try_from_action_and_execute_request( execute_request, action, - start_execute.salt, load_start_timestamp, queued_timestamp, ) @@ -1742,10 +1741,10 @@ impl RunningActionsManagerImpl { }) } - fn cleanup_action(&self, action_id: &ActionId) -> Result<(), Error> { + fn cleanup_action(&self, operation_id: &OperationId) -> Result<(), Error> { let mut running_actions = self.running_actions.lock(); - let result = running_actions.remove(action_id).err_tip(|| { - format!("Expected action id '{action_id:?}' to exist in RunningActionsManagerImpl") + let result = running_actions.remove(operation_id).err_tip(|| { + format!("Expected action id '{operation_id:?}' to exist in RunningActionsManagerImpl") }); // No need to copy anything, we just are telling the receivers an event happened. self.action_done_tx.send_modify(|_| {}); @@ -1754,11 +1753,11 @@ impl RunningActionsManagerImpl { // Note: We do not capture metrics on this call, only `.kill_all()`. // Important: When the future returns the process may still be running. - async fn kill_action(action: Arc) { + async fn kill_operation(action: Arc) { event!( Level::WARN, - action_id = ?hex::encode(action.action_id), - "Sending kill to running action", + operation_id = ?action.operation_id, + "Sending kill to running operation", ); let kill_channel_tx = { let mut action_state = action.state.lock(); @@ -1768,8 +1767,8 @@ impl RunningActionsManagerImpl { if kill_channel_tx.send(()).is_err() { event!( Level::ERROR, - action_id = ?hex::encode(action.action_id), - "Error sending kill to running action", + operation_id = ?action.operation_id, + "Error sending kill to running operation", ); } } @@ -1792,14 +1791,18 @@ impl RunningActionsManager for RunningActionsManagerImpl { .clone() .and_then(|time| time.try_into().ok()) .unwrap_or(SystemTime::UNIX_EPOCH); + let operation_id: OperationId = start_execute + .operation_id + .as_str() + .try_into() + .err_tip(|| "Could not convert to operation_id in RunningActionsManager::create_and_add_action")?; let action_info = self.create_action_info(start_execute, queued_timestamp).await?; event!( Level::INFO, ?action_info, "Worker received action", ); - let action_id = action_info.unique_qualifier.get_hash(); - let action_directory = self.make_action_directory(&action_id).await?; + let action_directory = self.make_action_directory(&operation_id).await?; let execution_metadata = ExecutionMetadata { worker: worker_id, queued_timestamp: action_info.insert_timestamp, @@ -1827,7 +1830,7 @@ impl RunningActionsManager for RunningActionsManagerImpl { } let running_action = Arc::new(RunningActionImpl::new( execution_metadata, - action_id, + operation_id.clone(), action_directory, action_info, timeout, @@ -1835,7 +1838,7 @@ impl RunningActionsManager for RunningActionsManagerImpl { )); { let mut running_actions = self.running_actions.lock(); - running_actions.insert(action_id, Arc::downgrade(&running_action)); + running_actions.insert(operation_id, Arc::downgrade(&running_action)); } Ok(running_action) }) @@ -1858,17 +1861,15 @@ impl RunningActionsManager for RunningActionsManagerImpl { .await } - async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> { + async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> { let running_action = { let running_actions = self.running_actions.lock(); running_actions - .get(&action_id) + .get(operation_id) .and_then(|action| action.upgrade()) - .ok_or_else(|| { - make_input_err!("Failed to get running action {}", hex::encode(action_id)) - })? + .ok_or_else(|| make_input_err!("Failed to get running action {operation_id}"))? }; - Self::kill_action(running_action).await; + Self::kill_operation(running_action).await; Ok(()) } @@ -1877,15 +1878,15 @@ impl RunningActionsManager for RunningActionsManagerImpl { self.metrics .kill_all .wrap_no_capture_result(async move { - let kill_actions: Vec> = { + let kill_operations: Vec> = { let running_actions = self.running_actions.lock(); running_actions .iter() - .filter_map(|(_action_id, action)| action.upgrade()) + .filter_map(|(_operation_id, action)| action.upgrade()) .collect() }; - for action in kill_actions { - Self::kill_action(action).await; + for action in kill_operations { + Self::kill_operation(action).await; } }) .await; diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs index 5d094fe3c..1036db07d 100644 --- a/nativelink-worker/tests/local_worker_test.rs +++ b/nativelink-worker/tests/local_worker_test.rs @@ -32,18 +32,18 @@ mod utils { use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty}; use nativelink_error::{make_err, make_input_err, Code, Error}; use nativelink_macro::nativelink_test; -use nativelink_proto::build::bazel::remote::execution::v2::digest_function; use nativelink_proto::build::bazel::remote::execution::v2::platform::Property; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ - execute_result, ConnectionResult, ExecuteResult, KillActionRequest, StartExecute, + execute_result, ConnectionResult, ExecuteResult, KillOperationRequest, StartExecute, SupportedProperties, UpdateForWorker, }; use nativelink_store::fast_slow_store::FastSlowStore; use nativelink_store::filesystem_store::FilesystemStore; use nativelink_store::memory_store::MemoryStore; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ExecutionMetadata, + ActionInfo, ActionResult, ActionStage, ActionUniqueKey, ActionUniqueQualifier, + ExecutionMetadata, OperationId, }; use nativelink_util::common::{encode_stream_proto, fs, DigestInfo}; use nativelink_util::digest_hasher::DigestHasherFunc; @@ -195,8 +195,6 @@ async fn kill_all_called_on_disconnect() -> Result<(), Box Result<(), Box> { - const SALT: u64 = 1000; - let mut test_context = setup_local_worker(HashMap::new()).await; let streaming_response = test_context.maybe_streaming_response.take().unwrap(); @@ -233,13 +231,11 @@ async fn blake3_digest_function_registerd_properly() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { - const SALT: u64 = 1000; - let mut test_context = setup_local_worker(HashMap::new()).await; let streaming_response = test_context.maybe_streaming_response.take().unwrap(); @@ -319,13 +313,11 @@ async fn simple_worker_start_action_test() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { - const SALT: u64 = 1000; - let mut test_context = setup_local_worker(HashMap::new()).await; let streaming_response = test_context.maybe_streaming_response.take().unwrap(); @@ -677,22 +660,21 @@ async fn kill_action_request_kills_action() -> Result<(), Box Result<(), Box SystemTime { previous_time } +fn make_operation_id(execute_request: &ExecuteRequest) -> OperationId { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { + instance_name: execute_request.instance_name.clone(), + digest_function: execute_request.digest_function.try_into().unwrap(), + digest: execute_request + .action_digest + .clone() + .unwrap() + .try_into() + .unwrap(), + }); + OperationId::new(unique_qualifier) +} + #[nativelink_test] async fn download_to_directory_file_download_test() -> Result<(), Box> { const FILE1_NAME: &str = "file1.txt"; @@ -443,7 +459,6 @@ async fn ensure_output_files_full_directories_are_created_no_working_directory_t }, )?); { - const SALT: u64 = 55; let command = Command { arguments: vec!["touch".to_string(), "./some/path/test.txt".to_string()], output_files: vec!["some/path/test.txt".to_string()], @@ -487,16 +502,18 @@ async fn ensure_output_files_full_directories_are_created_no_working_directory_t ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: None, }, ) @@ -557,7 +574,6 @@ async fn ensure_output_files_full_directories_are_created_test( }, )?); { - const SALT: u64 = 55; let working_directory = "some_cwd"; let command = Command { arguments: vec!["touch".to_string(), "./some/path/test.txt".to_string()], @@ -603,16 +619,18 @@ async fn ensure_output_files_full_directories_are_created_test( ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: None, }, ) @@ -673,7 +691,6 @@ async fn blake3_upload_files() -> Result<(), Box> { }, )?); let action_result = { - const SALT: u64 = 55; #[cfg(target_family = "unix")] let arguments = vec![ "sh".to_string(), @@ -734,16 +751,19 @@ async fn blake3_upload_files() -> Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + digest_function: ProtoDigestFunction::Blake3.into(), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Blake3.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: None, }, ) @@ -844,7 +864,6 @@ async fn upload_files_from_above_cwd_test() -> Result<(), Box Result<(), Box Result<(), Box> )?); let queued_timestamp = make_system_time(1000); let action_result = { - const SALT: u64 = 55; let command = Command { arguments: vec![ "sh".to_string(), @@ -1060,16 +1080,18 @@ async fn upload_dir_and_symlink_test() -> Result<(), Box> ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(queued_timestamp.into()), }, ) @@ -1223,7 +1245,6 @@ async fn cleanup_happens_on_job_failure() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 55; let (_, _, cas_store, ac_store) = setup_stores().await?; let root_action_directory = make_temp_path("root_action_directory"); @@ -1383,17 +1405,19 @@ async fn kill_ends_action() -> Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .clone() .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -1445,7 +1469,6 @@ echo | set /p=\"Wrapper script did run\" 1>&2 exit 0 "; const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 66; const EXPECTED_STDOUT: &str = "Action did run"; let (_, _, cas_store, ac_store) = setup_stores().await?; @@ -1528,17 +1551,19 @@ exit 0 ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .clone() .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -1587,7 +1612,6 @@ echo | set /p=\"Wrapper script did run with property %PROPERTY% %VALUE% %INNER_T exit 0 "; const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 66; const EXPECTED_STDOUT: &str = "Action did run"; let (_, _, cas_store, ac_store) = setup_stores().await?; @@ -1694,17 +1718,19 @@ exit 0 ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .clone() .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -1751,7 +1777,6 @@ echo | set /p={\"failure\":\"timeout\"} 1>&2 > %SIDE_CHANNEL_FILE% exit 1 "; const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 66; let (_, _, cas_store, ac_store) = setup_stores().await?; let root_action_directory = make_temp_path("root_action_directory"); @@ -1833,17 +1858,19 @@ exit 1 ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .clone() .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -2346,16 +2373,18 @@ async fn ensure_worker_timeout_chooses_correct_values() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let execute_results_fut = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: 0, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -2740,18 +2775,20 @@ async fn kill_all_waits_for_all_tasks_to_finish() -> Result<(), Box Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), + execute_request: Some(execute_request), + operation_id, ..Default::default() }, ) @@ -2946,7 +2986,6 @@ async fn action_directory_contents_are_cleaned() -> Result<(), Box Result<(), Box Result<(), Box> { }, )?); let action_result = { - const SALT: u64 = 55; #[cfg(target_family = "unix")] let arguments = vec![ "sh".to_string(), @@ -3110,16 +3150,18 @@ async fn upload_with_single_permit() -> Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: None, }, ) @@ -3196,7 +3238,6 @@ async fn upload_with_single_permit() -> Result<(), Box> { async fn running_actions_manager_respects_action_timeout() -> Result<(), Box> { const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 66; let (_, _, cas_store, ac_store) = setup_stores().await?; let root_action_directory = make_temp_path("root_work_directory"); @@ -3282,17 +3323,19 @@ async fn running_actions_manager_respects_action_timeout() -> Result<(), Box>, tx_kill_all: mpsc::UnboundedSender<()>, - rx_kill_action: Mutex>, - tx_kill_action: mpsc::UnboundedSender, + rx_kill_operation: Mutex>, + tx_kill_operation: mpsc::UnboundedSender, metrics: Arc, } @@ -61,7 +59,7 @@ impl MockRunningActionsManager { let (tx_call, rx_call) = mpsc::unbounded_channel(); let (tx_resp, rx_resp) = mpsc::unbounded_channel(); let (tx_kill_all, rx_kill_all) = mpsc::unbounded_channel(); - let (tx_kill_action, rx_kill_action) = mpsc::unbounded_channel(); + let (tx_kill_operation, rx_kill_operation) = mpsc::unbounded_channel(); Self { rx_call: Mutex::new(rx_call), tx_call, @@ -69,8 +67,8 @@ impl MockRunningActionsManager { tx_resp, rx_kill_all: Mutex::new(rx_kill_all), tx_kill_all, - rx_kill_action: Mutex::new(rx_kill_action), - tx_kill_action, + rx_kill_operation: Mutex::new(rx_kill_operation), + tx_kill_operation, metrics: Arc::new(Metrics::default()), } } @@ -116,9 +114,9 @@ impl MockRunningActionsManager { .expect("Could not receive msg in mpsc"); } - pub async fn expect_kill_action(&self) -> ActionId { - let mut rx_kill_action_lock = self.rx_kill_action.lock().await; - rx_kill_action_lock + pub async fn expect_kill_operation(&self) -> OperationId { + let mut rx_kill_operation_lock = self.rx_kill_operation.lock().await; + rx_kill_operation_lock .recv() .await .expect("Could not receive msg in mpsc") @@ -165,9 +163,9 @@ impl RunningActionsManager for MockRunningActionsManager { Ok(()) } - async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> { - self.tx_kill_action - .send(action_id) + async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> { + self.tx_kill_operation + .send(operation_id.clone()) .expect("Could not send request to mpsc"); Ok(()) } @@ -344,7 +342,7 @@ impl MockRunningAction { } impl RunningAction for MockRunningAction { - fn get_action_id(&self) -> &ActionId { + fn get_operation_id(&self) -> &OperationId { unreachable!("not implemented for tests"); } diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs index 49142778f..c05c62b16 100644 --- a/src/bin/nativelink.rs +++ b/src/bin/nativelink.rs @@ -529,7 +529,7 @@ async fn inner_main( })? .clone() .set_drain_worker( - WorkerId::try_from(worker_id.clone())?, + &WorkerId::try_from(worker_id.clone())?, is_draining, ) .await?;