Skip to content

Commit

Permalink
Use Arc instead of Box for slot provider
Browse files Browse the repository at this point in the history
  • Loading branch information
cretz committed Apr 21, 2024
1 parent 07133d7 commit b380a6d
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 25 deletions.
17 changes: 9 additions & 8 deletions client/src/worker_registry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
use parking_lot::RwLock;
use slotmap::SlotMap;
use std::collections::{hash_map::Entry::Vacant, HashMap};
use std::sync::Arc;

use temporal_sdk_core_protos::temporal::api::workflowservice::v1::PollWorkflowTaskQueueResponse;

Expand Down Expand Up @@ -53,7 +54,7 @@ impl SlotKey {
#[derive(Default, Debug)]
struct SlotManagerImpl {
/// Maps keys, i.e., namespace#task_queue, to provider.
providers: HashMap<SlotKey, Box<dyn SlotProvider + Send + Sync>>,
providers: HashMap<SlotKey, Arc<dyn SlotProvider + Send + Sync>>,
/// Maps ids to keys in `providers`.
index: SlotMap<WorkerKey, SlotKey>,
}
Expand Down Expand Up @@ -81,7 +82,7 @@ impl SlotManagerImpl {
None
}

fn register(&mut self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
fn register(&mut self, provider: Arc<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
let key = SlotKey::new(
provider.namespace().to_string(),
provider.task_queue().to_string(),
Expand Down Expand Up @@ -135,7 +136,7 @@ impl SlotManager {
}

/// Register a local worker that can provide WFT processing slots.
pub fn register(&self, provider: Box<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
pub fn register(&self, provider: Arc<dyn SlotProvider + Send + Sync>) -> Option<WorkerKey> {
self.manager.write().register(provider)
}

Expand Down Expand Up @@ -196,8 +197,8 @@ mod tests {
let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true);

let manager = SlotManager::new();
let some_slots = manager.register(Box::new(mock_provider1));
let no_slots = manager.register(Box::new(mock_provider2));
let some_slots = manager.register(Arc::new(mock_provider1));
let no_slots = manager.register(Arc::new(mock_provider2));
assert!(no_slots.is_none());

let mut found = 0;
Expand All @@ -219,8 +220,8 @@ mod tests {
new_mock_provider("foo".to_string(), "bar_q".to_string(), false, false);
let mock_provider2 = new_mock_provider("foo".to_string(), "bar_q".to_string(), false, true);

let no_slots = manager.register(Box::new(mock_provider2));
let some_slots = manager.register(Box::new(mock_provider1));
let no_slots = manager.register(Arc::new(mock_provider2));
let some_slots = manager.register(Arc::new(mock_provider1));
assert!(some_slots.is_none());

let mut not_found = 0;
Expand All @@ -245,7 +246,7 @@ mod tests {
for i in 0..10 {
let namespace = format!("myId{}", i % 3);
let mock_provider = new_mock_provider(namespace, "bar_q".to_string(), false, false);
worker_keys.push(manager.register(Box::new(mock_provider)));
worker_keys.push(manager.register(Arc::new(mock_provider)));
}
assert_eq!((3, 3), manager.num_providers());

Expand Down
2 changes: 1 addition & 1 deletion core/src/worker/client/mocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub(crate) static DEFAULT_TEST_CAPABILITIES: &Capabilities = &Capabilities {
pub(crate) fn mock_workflow_client() -> MockWorkerClient {
let mut r = MockWorkerClient::new();
r.expect_capabilities()
.returning(|| DEFAULT_TEST_CAPABILITIES.clone());
.returning(|| Some(DEFAULT_TEST_CAPABILITIES.clone()));
r.expect_workers()
.returning(|| DEFAULT_WORKERS_REGISTRY.clone());
r.expect_is_mock().returning(|| true);
Expand Down
12 changes: 6 additions & 6 deletions core/src/worker/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ use {
pub struct Worker {
config: WorkerConfig,
wf_client: Arc<dyn WorkerClient>,
slot_provider: SlotProvider,
slot_provider: Arc<SlotProvider>,
/// Registration key to enable eager workflow start for this worker
worker_key: Mutex<Option<WorkerKey>>,
/// Manages all workflows and WFT processing
Expand Down Expand Up @@ -230,7 +230,7 @@ impl Worker {
*worker_key = self
.wf_client
.workers()
.register(Box::new(self.slot_provider.clone()));
.register(self.slot_provider.clone());
}

#[cfg(test)]
Expand Down Expand Up @@ -385,13 +385,13 @@ impl Worker {
info!("Activity polling is disabled for this worker");
};
let la_sink = LAReqSink::new(local_act_mgr.clone());
let slot_provider = SlotProvider::new(
let slot_provider = Arc::new(SlotProvider::new(
config.namespace.clone(),
config.task_queue.clone(),
wft_semaphore.clone(),
Arc::new(external_wft_tx),
);
let worker_key = Mutex::new(client.workers().register(Box::new(slot_provider.clone())));
external_wft_tx,
));
let worker_key = Mutex::new(client.workers().register(slot_provider.clone()));
Self {
slot_provider,
worker_key,
Expand Down
17 changes: 7 additions & 10 deletions core/src/worker/slot_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,21 @@ impl SlotTrait for Slot {
}
}

#[derive(derive_more::DebugCustom, Clone)]
#[derive(derive_more::DebugCustom)]
#[debug(fmt = "SlotProvider {{ namespace:{namespace}, task_queue: {task_queue} }}")]
pub struct SlotProvider {
namespace: String,
task_queue: String,
wft_semaphore: Arc<MeteredSemaphore>,
external_wft_tx: Arc<WFTStreamSender>,
external_wft_tx: WFTStreamSender,
}

impl SlotProvider {
pub(crate) fn new(
namespace: String,
task_queue: String,
wft_semaphore: Arc<MeteredSemaphore>,
external_wft_tx: Arc<WFTStreamSender>,
external_wft_tx: WFTStreamSender,
) -> Self {
Self {
namespace,
Expand All @@ -76,10 +76,7 @@ impl SlotProviderTrait for SlotProvider {
}
fn try_reserve_wft_slot(&self) -> Option<Box<dyn SlotTrait + Send>> {
match self.wft_semaphore.try_acquire_owned().ok() {
Some(permit) => Some(Box::new(Slot::new(
permit,
self.external_wft_tx.as_ref().clone(),
))),
Some(permit) => Some(Box::new(Slot::new(permit, self.external_wft_tx.clone()))),
None => None,
}
}
Expand Down Expand Up @@ -120,7 +117,7 @@ mod tests {
"my_namespace".to_string(),
"my_queue".to_string(),
wft_semaphore,
Arc::new(external_wft_tx),
external_wft_tx,
);

let slot = provider
Expand All @@ -145,7 +142,7 @@ mod tests {
"my_namespace".to_string(),
"my_queue".to_string(),
wft_semaphore,
Arc::new(external_wft_tx),
external_wft_tx,
);
assert!(provider.try_reserve_wft_slot().is_some());
}
Expand All @@ -166,7 +163,7 @@ mod tests {
"my_namespace".to_string(),
"my_queue".to_string(),
wft_semaphore.clone(),
Arc::new(external_wft_tx),
external_wft_tx,
);
let slot = provider.try_reserve_wft_slot();
assert!(slot.is_some());
Expand Down

0 comments on commit b380a6d

Please sign in to comment.