diff --git a/runtime/runtime/src/bandwidth_scheduler/distribute_remaining.rs b/runtime/runtime/src/bandwidth_scheduler/distribute_remaining.rs
new file mode 100644
index 00000000000..4158047e20a
--- /dev/null
+++ b/runtime/runtime/src/bandwidth_scheduler/distribute_remaining.rs
@@ -0,0 +1,355 @@
+use near_primitives::bandwidth_scheduler::Bandwidth;
+use near_primitives::shard_layout::ShardLayout;
+use near_primitives::types::ShardIndex;
+
+use super::scheduler::{ShardIndexMap, ShardLink, ShardLinkMap};
+
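For intuition on the `remaining_bandwidth / remaining_senders` guarantee described in the doc comment below: if a receiver has budget B left and k allowed senders are still unprocessed, the current sender is, by the sort order, the one proposing the least, so it is granted at most B / k. The k - 1 senders processed after it each propose at least as much, which is enough to consume the remaining B * (k - 1) / k of the budget, up to integer-division rounding.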
+/// After the bandwidth scheduler processes all of the bandwidth requests, there's usually some
+/// leftover budget for sending and receiving data between shards. This function is responsible for
+/// distributing the remaining bandwidth in a fair manner. It looks at how much more each shard
+/// could send/receive and grants bandwidth on links to fully utilize the available budget.
+///
+/// The `is_link_allowed` argument makes it possible to disallow granting bandwidth on some links.
+/// This usually happens when a shard is fully congested and is allowed to receive receipts only
+/// from a single allowed sender shard.
+///
+/// The algorithm processes senders and receivers in order of increasing average per-link budget.
+/// It grants a bit of bandwidth at a time, keeping in mind that others will also want to send some
+/// data there. Processing endpoints in this order guarantees that all senders processed later will
+/// send at least as much as the one being processed right now, so we can grant
+/// `remaining_bandwidth / remaining_senders` and be sure that utilization will be high.
+///
+/// The algorithm is safe because it never grants more than `remaining_bandwidth`, which ensures
+/// that the grants stay under the budget.
+///
+/// The algorithm isn't ideal; utilization can be a bit lower when there are a lot of disallowed
+/// links, but it's good enough for the bandwidth scheduler.
+pub fn distribute_remaining_bandwidth(
+    sender_budgets: &ShardIndexMap<Bandwidth>,
+    receiver_budgets: &ShardIndexMap<Bandwidth>,
+    is_link_allowed: &ShardLinkMap<bool>,
+    shard_layout: &ShardLayout,
+) -> ShardLinkMap<Bandwidth> {
+    let mut sender_infos: ShardIndexMap<EndpointInfo> = ShardIndexMap::new(shard_layout);
+    let mut receiver_infos: ShardIndexMap<EndpointInfo> = ShardIndexMap::new(shard_layout);
+
+    for shard_index in shard_layout.shard_indexes() {
+        let sender_budget = sender_budgets.get(&shard_index).copied().unwrap_or(0);
+        sender_infos
+            .insert(shard_index, EndpointInfo { links_num: 0, bandwidth_left: sender_budget });
+
+        let receiver_budget = receiver_budgets.get(&shard_index).copied().unwrap_or(0);
+        receiver_infos
+            .insert(shard_index, EndpointInfo { links_num: 0, bandwidth_left: receiver_budget });
+    }
+
+    for sender in shard_layout.shard_indexes() {
+        for receiver in shard_layout.shard_indexes() {
+            if *is_link_allowed.get(&ShardLink::new(sender, receiver)).unwrap_or(&false) {
+                sender_infos.get_mut(&sender).unwrap().links_num += 1;
+                receiver_infos.get_mut(&receiver).unwrap().links_num += 1;
+            }
+        }
+    }
+
+    let mut senders_by_avg_link_bandwidth: Vec<ShardIndex> =
+        shard_layout.shard_indexes().collect();
+    senders_by_avg_link_bandwidth
+        .sort_by_key(|shard| sender_infos.get(shard).unwrap().average_link_bandwidth());
+
+    let mut receivers_by_avg_link_bandwidth: Vec<ShardIndex> =
+        shard_layout.shard_indexes().collect();
+    receivers_by_avg_link_bandwidth
+        .sort_by_key(|shard| receiver_infos.get(shard).unwrap().average_link_bandwidth());
+
+    let mut bandwidth_grants: ShardLinkMap<Bandwidth> = ShardLinkMap::new(shard_layout);
+    for sender in senders_by_avg_link_bandwidth {
+        let sender_info = sender_infos.get_mut(&sender).unwrap();
+        for &receiver in &receivers_by_avg_link_bandwidth {
+            if !*is_link_allowed.get(&ShardLink::new(sender, receiver)).unwrap_or(&false) {
+                continue;
+            }
+
+            let receiver_info = receiver_infos.get_mut(&receiver).unwrap();
+
+            if sender_info.links_num == 0 || receiver_info.links_num == 0 {
+                break;
+            }
+
+            let sender_proposition = sender_info.link_proposition();
+            let receiver_proposition = receiver_info.link_proposition();
+            let granted_bandwidth = std::cmp::min(sender_proposition, receiver_proposition);
+            bandwidth_grants.insert(ShardLink::new(sender, receiver), granted_bandwidth);
+
+            sender_info.bandwidth_left -= granted_bandwidth;
+            sender_info.links_num -= 1;
+
+            receiver_info.bandwidth_left -= granted_bandwidth;
+            receiver_info.links_num -= 1;
+        }
+    }
+
+    bandwidth_grants
+}
+
+/// Information about a sender or receiver shard, used in `distribute_remaining_bandwidth`.
+struct EndpointInfo {
+    /// How much more bandwidth can be sent/received.
+    bandwidth_left: Bandwidth,
+    /// How many more links the bandwidth will be granted on.
+    links_num: u64,
+}
+
+impl EndpointInfo {
+    /// How much can be sent on every link on average.
+    fn average_link_bandwidth(&self) -> Bandwidth {
+        if self.links_num == 0 {
+            return 0;
+        }
+        self.bandwidth_left / self.links_num
+    }
+
+    /// Propose an amount of bandwidth to grant on the next link.
+    /// Both sides of the link propose something and the minimum of the two is granted on the link.
+    fn link_proposition(&self) -> Bandwidth {
+        self.bandwidth_left / self.links_num
+    }
+}
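To make the fairness argument concrete, here is a toy, self-contained sketch (not part of the patch; names and budgets are invented, and the disallowed-link and zero-budget handling is skipped) of the same greedy pass on plain integers:

```rust
// Minimal sketch of the grant loop above, with every link allowed.
fn distribute(mut send: Vec<u64>, mut recv: Vec<u64>) -> u64 {
    let n = send.len();
    // Every endpoint starts with n unprocessed links.
    let mut send_links = vec![n as u64; n];
    let mut recv_links = vec![n as u64; n];
    // Process endpoints in order of increasing average per-link budget,
    // mirroring senders_by_avg_link_bandwidth / receivers_by_avg_link_bandwidth.
    let mut send_order: Vec<usize> = (0..n).collect();
    send_order.sort_by_key(|&i| send[i] / send_links[i]);
    let mut recv_order: Vec<usize> = (0..n).collect();
    recv_order.sort_by_key(|&i| recv[i] / recv_links[i]);

    let mut total = 0;
    for &s in &send_order {
        for &r in &recv_order {
            // Each side proposes `bandwidth_left / links_num`; the minimum is granted.
            let grant = std::cmp::min(send[s] / send_links[s], recv[r] / recv_links[r]);
            send[s] -= grant;
            send_links[s] -= 1;
            recv[r] -= grant;
            recv_links[r] -= 1;
            total += grant;
        }
    }
    total
}

fn main() {
    let granted = distribute(vec![300, 500, 1000], vec![400, 600, 800]);
    // Both sides total 1800, and this toy run grants all of it.
    println!("granted {granted} of 1800");
}
```

With all links allowed the propositions line up almost perfectly; the randomized test below asserts at least 99% of the `min(total send, total recv)` bound for the general all-allowed case.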
+
+#[cfg(test)]
+mod tests {
+    use std::collections::{BTreeMap, BTreeSet};
+
+    use near_primitives::bandwidth_scheduler::Bandwidth;
+    use near_primitives::shard_layout::ShardLayout;
+    use near_primitives::types::ShardIndex;
+    use rand::{Rng, SeedableRng};
+    use rand_chacha::ChaCha20Rng;
+
+    use crate::bandwidth_scheduler::distribute_remaining::distribute_remaining_bandwidth;
+    use crate::bandwidth_scheduler::scheduler::{ShardIndexMap, ShardLink, ShardLinkMap};
+    use testlib::bandwidth_scheduler::estimate_link_throughputs;
+
+    fn run_distribute_remaining(
+        sender_budgets: &[Bandwidth],
+        receiver_budgets: &[Bandwidth],
+        allowed_links: AllowedLinks,
+    ) -> ShardLinkMap<Bandwidth> {
+        assert_eq!(sender_budgets.len(), receiver_budgets.len());
+        let shard_layout = ShardLayout::multi_shard(sender_budgets.len().try_into().unwrap(), 0);
+        let mut sender_budgets_map = ShardIndexMap::new(&shard_layout);
+        for (i, sender_budget) in sender_budgets.iter().enumerate() {
+            sender_budgets_map.insert(i, *sender_budget);
+        }
+        let mut receiver_budgets_map = ShardIndexMap::new(&shard_layout);
+        for (i, receiver_budget) in receiver_budgets.iter().enumerate() {
+            receiver_budgets_map.insert(i, *receiver_budget);
+        }
+
+        let mut is_link_allowed_map = ShardLinkMap::new(&shard_layout);
+        let default_allowed = match &allowed_links {
+            AllowedLinks::AllAllowed => true,
+            AllowedLinks::AllowedList(_) | AllowedLinks::NoneAllowed => false,
+        };
+        for sender_index in shard_layout.shard_indexes() {
+            for receiver_index in shard_layout.shard_indexes() {
+                is_link_allowed_map
+                    .insert(ShardLink::new(sender_index, receiver_index), default_allowed);
+            }
+        }
+        if let AllowedLinks::AllowedList(allowed_links_list) = allowed_links {
+            for (sender_index, receiver_index) in allowed_links_list {
+                is_link_allowed_map.insert(ShardLink::new(sender_index, receiver_index), true);
+            }
+        }
+
+        distribute_remaining_bandwidth(
+            &sender_budgets_map,
+            &receiver_budgets_map,
+            &is_link_allowed_map,
+            &shard_layout,
+        )
+    }
+
+    /// Convenient way to specify `is_link_allowed_map` in the tests.
+    enum AllowedLinks {
+        /// All links are allowed.
+        AllAllowed,
+        /// All links are forbidden.
+        NoneAllowed,
+        /// Only the links in the list are allowed.
+        AllowedList(Vec<(ShardIndex, ShardIndex)>),
+    }
+
+    fn assert_grants(
+        granted: &ShardLinkMap<Bandwidth>,
+        expected: &[(ShardIndex, ShardIndex, Bandwidth)],
+    ) {
+        let mut granted_map: BTreeMap<(ShardIndex, ShardIndex), Bandwidth> = BTreeMap::new();
+        for sender in 0..granted.num_indexes() {
+            for receiver in 0..granted.num_indexes() {
+                if let Some(grant) = granted.get(&ShardLink::new(sender, receiver)) {
+                    granted_map.insert((sender, receiver), *grant);
+                }
+            }
+        }
+
+        let expected_map: BTreeMap<(ShardIndex, ShardIndex), Bandwidth> = expected
+            .iter()
+            .map(|(sender, receiver, grant)| ((*sender, *receiver), *grant))
+            .collect();
+
+        assert_eq!(granted_map, expected_map);
+    }
+
+    /// A single link; the sender can send less than the receiver can receive.
+    #[test]
+    fn test_one_link() {
+        let granted = run_distribute_remaining(&[50], &[100], AllowedLinks::AllAllowed);
+        assert_grants(&granted, &[(0, 0, 50)]);
+    }
+
+    /// A single link which is not allowed. No bandwidth should be granted.
+    #[test]
+    fn test_one_link_not_allowed() {
+        let granted = run_distribute_remaining(&[50], &[100], AllowedLinks::NoneAllowed);
+        assert_grants(&granted, &[]);
+    }
+
+    /// Three shards, all links are allowed.
+    /// Bandwidth should be distributed equally.
+    #[test]
+    fn test_three_shards() {
+        let granted =
+            run_distribute_remaining(&[300, 300, 300], &[300, 300, 300], AllowedLinks::AllAllowed);
+        assert_grants(
+            &granted,
+            &[
+                (0, 0, 100),
+                (0, 1, 100),
+                (0, 2, 100),
+                (1, 0, 100),
+                (1, 1, 100),
+                (1, 2, 100),
+                (2, 0, 100),
+                (2, 1, 100),
+                (2, 2, 100),
+            ],
+        );
+    }
+
+    /// (1) can send to (0) and (2).
+    /// (0) and (2) can send to (1).
+    /// Each active link should get half of the available budget.
+    #[test]
+    fn test_two_to_one() {
+        let allowed_links = AllowedLinks::AllowedList(vec![(1, 0), (1, 2), (0, 1), (2, 1)]);
+        let granted = run_distribute_remaining(&[100, 100, 100], &[100, 100, 100], allowed_links);
+        assert_grants(&granted, &[(0, 1, 50), (1, 0, 50), (1, 2, 50), (2, 1, 50)]);
+    }
+
+    /// Three shards, two of them fully congested:
+    /// (1) can only receive receipts from (0),
+    /// (2) can only receive receipts from (1),
+    /// (0) can receive receipts from all shards.
+    #[test]
+    fn test_two_fully_congested() {
+        let allowed_links = AllowedLinks::AllowedList(vec![(0, 0), (0, 1), (1, 0), (1, 2), (2, 0)]);
+        let granted = run_distribute_remaining(&[300, 300, 300], &[300, 300, 300], allowed_links);
+        assert_grants(&granted, &[(0, 0, 100), (0, 1, 200), (1, 0, 100), (1, 2, 200), (2, 0, 100)]);
+    }
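A worked reading of `test_two_fully_congested` (derived from the expected grants above): shard 0 is the only receiver reachable from all three senders, so its 300-unit receive budget is split three ways into grants of 100 on (0,0), (1,0) and (2,0). Senders 0 and 1 then still have 200 each, which goes to their only other allowed receivers, 1 and 2 respectively, giving the 200-unit grants on (0,1) and (1,2). Sender 2 has no other allowed link, so 200 of its budget stays unused.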
+
+    /// Run `distribute_remaining_bandwidth` on a random test scenario.
+    fn randomized_test(seed: u64) {
+        println!("\n\n# Test with seed {}", seed);
+        let mut rng = ChaCha20Rng::seed_from_u64(seed);
+        let num_shards = rng.gen_range(1..10);
+        let all_links_allowed = rng.gen_bool(0.5);
+        println!("num_shards: {}", num_shards);
+        println!("all_links_allowed: {}", all_links_allowed);
+
+        let mut active_links: BTreeSet<(ShardIndex, ShardIndex)> = BTreeSet::new();
+        for sender in 0..num_shards {
+            for receiver in 0..num_shards {
+                if !all_links_allowed && rng.gen_bool(0.5) {
+                    continue;
+                }
+                active_links.insert((sender, receiver));
+            }
+        }
+        println!("active_links: {:?}", active_links);
+
+        fn generate_budget(rng: &mut ChaCha20Rng) -> Bandwidth {
+            if rng.gen_bool(0.1) {
+                0
+            } else {
+                rng.gen_range(0..1000)
+            }
+        }
+        let sender_budgets: Vec<Bandwidth> =
+            (0..num_shards).map(|_| generate_budget(&mut rng)).collect();
+        let receiver_budgets: Vec<Bandwidth> =
+            (0..num_shards).map(|_| generate_budget(&mut rng)).collect();
+        println!("sender_budgets: {:?}", sender_budgets);
+        println!("receiver_budgets: {:?}", receiver_budgets);
+
+        let allowed_links = AllowedLinks::AllowedList(active_links.iter().copied().collect());
+        let grants = run_distribute_remaining(&sender_budgets, &receiver_budgets, allowed_links);
+
+        let mut total_incoming = vec![0; num_shards];
+        let mut total_outgoing = vec![0; num_shards];
+        let mut total_throughput = 0;
+        for sender in 0..num_shards {
+            for receiver in 0..num_shards {
+                if let Some(grant) = grants.get(&ShardLink::new(sender, receiver)) {
+                    total_outgoing[sender] += grant;
+                    total_incoming[receiver] += grant;
+                    total_throughput += grant;
+                }
+            }
+        }
+
+        // Assert that the granted bandwidth doesn't exceed sender or receiver budgets.
+        for i in 0..num_shards {
+            assert!(total_outgoing[i] <= sender_budgets[i]);
+            assert!(total_incoming[i] <= receiver_budgets[i]);
+        }
+
+        // Make sure that bandwidth utilization is high.
+        dbg!(total_throughput);
+        if all_links_allowed {
+            // When all links are allowed, the algorithm achieves 99% bandwidth utilization.
+            let total_sending_budget: u64 = sender_budgets.iter().sum();
+            let total_receiver_budget: u64 = receiver_budgets.iter().sum();
+            let theoretical_throughput =
+                std::cmp::min(total_sending_budget, total_receiver_budget);
+            dbg!(theoretical_throughput);
+            assert!(total_throughput >= theoretical_throughput * 99 / 100);
+        } else {
+            // When some links are not allowed, the algorithm still achieves good enough bandwidth
+            // utilization (> 75% of the estimated possible throughput).
+            let estimated_link_throughputs =
+                estimate_link_throughputs(&active_links, &sender_budgets, &receiver_budgets);
+            let estimated_total_throughput: Bandwidth = estimated_link_throughputs
+                .iter()
+                .map(|(_link, throughput)| throughput.as_u64())
+                .sum();
+            dbg!(estimated_total_throughput);
+            assert!(total_throughput >= estimated_total_throughput * 75 / 100);
+        }
+
+        // Ensure that bandwidth is not granted on forbidden links.
+        for sender in 0..num_shards {
+            for receiver in 0..num_shards {
+                let granted = *grants.get(&ShardLink::new(sender, receiver)).unwrap_or(&0);
+                if !active_links.contains(&(sender, receiver)) {
+                    assert_eq!(granted, 0);
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn test_randomized() {
+        for i in 0..1000 {
+            randomized_test(i);
+        }
+    }
+}
diff --git a/runtime/runtime/src/bandwidth_scheduler/mod.rs b/runtime/runtime/src/bandwidth_scheduler/mod.rs
index 282542dd3af..f4d3c90725e 100644
--- a/runtime/runtime/src/bandwidth_scheduler/mod.rs
+++ b/runtime/runtime/src/bandwidth_scheduler/mod.rs
@@ -14,6 +14,7 @@ use scheduler::{BandwidthScheduler, GrantedBandwidth, ShardStatus};
 
 use crate::ApplyState;
 
+mod distribute_remaining;
 mod scheduler;
 #[cfg(test)]
 mod simulator;
diff --git a/runtime/runtime/src/bandwidth_scheduler/scheduler.rs b/runtime/runtime/src/bandwidth_scheduler/scheduler.rs
index 3ef961e4b24..2d9ba7c6529 100644
--- a/runtime/runtime/src/bandwidth_scheduler/scheduler.rs
+++ b/runtime/runtime/src/bandwidth_scheduler/scheduler.rs
@@ -169,9 +169,9 @@ pub struct BandwidthScheduler {
     shard_layout: ShardLayout,
     /// Configuration parameters for the algorithm.
     params: BandwidthSchedulerParams,
-    /// ShardStatus for each shard.
-    /// (ShardIndex -> ShardStatus)
-    shards_status: ShardIndexMap<ShardStatus>,
+    /// For every link, keeps track of whether sending receipts on that link is allowed.
+    /// Stores the result of `Self::calculate_is_link_allowed()` for every pair of shards.
+    is_link_allowed_map: ShardLinkMap<bool>,
     /// Each shard can send and receive at most `max_shard_bandwidth` bytes of receipts.
     /// This is tracked in the `sender_budget` and `receiver_budget` fields, which keep
     /// track of how much more a shard can send or receive before hitting the limit.
@@ -228,7 +228,7 @@
             }
         }
 
-        // Translate shard statuses to the internal representation.
+        // Initialize the allowed-link map based on shard statuses.
         let mut shard_status_by_index: ShardIndexMap<ShardStatus> =
             ShardIndexMap::new(&shard_layout);
         for (shard_id, status) in shards_status {
@@ -237,6 +237,19 @@
             }
         }
 
+        let mut is_link_allowed_map: ShardLinkMap<bool> = ShardLinkMap::new(&shard_layout);
+        for sender_index in shard_layout.shard_indexes() {
+            for receiver_index in shard_layout.shard_indexes() {
+                let is_allowed = Self::calculate_is_link_allowed(
+                    sender_index,
+                    receiver_index,
+                    &shard_status_by_index,
+                );
+                is_link_allowed_map
+                    .insert(ShardLink::new(sender_index, receiver_index), is_allowed);
+            }
+        }
+
         // Convert bandwidth requests to representation used in the algorithm.
         let mut scheduler_bandwidth_requests: Vec = Vec::new();
         for (sender_shard, shard_bandwidth_requests) in
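`ShardLinkMap<bool>` is a dense pair-indexed table, so each precomputed answer is one flat array read away on the hot path. A rough standalone sketch of the assumed layout (hypothetical names; the real `data_index_for_link` helper appears near the end of this scheduler.rs diff):

```rust
// Sketch of a dense per-link map: a flat Vec indexed by
// sender * num_indexes + receiver, so lookups need no hashing.
struct PairMap<T> {
    num_indexes: usize,
    data: Vec<Option<T>>,
}

impl<T> PairMap<T> {
    fn new(num_indexes: usize) -> Self {
        // (0..n*n).map(|_| None) avoids requiring T: Clone, unlike vec![None; n*n].
        PairMap { num_indexes, data: (0..num_indexes * num_indexes).map(|_| None).collect() }
    }

    fn insert(&mut self, sender: usize, receiver: usize, value: T) {
        let idx = sender * self.num_indexes + receiver;
        self.data[idx] = Some(value);
    }

    fn get(&self, sender: usize, receiver: usize) -> Option<&T> {
        self.data[sender * self.num_indexes + receiver].as_ref()
    }
}

fn main() {
    let mut allowed: PairMap<bool> = PairMap::new(4);
    allowed.insert(2, 3, true);
    assert_eq!(allowed.get(2, 3), Some(&true));
    assert_eq!(allowed.get(3, 2), None);
}
```

Precomputing the O(num_shards²) flags once per scheduler run also lets `distribute_remaining_bandwidth` take the map as a plain argument instead of re-deriving link permissions from shard statuses.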
@@ -269,7 +282,7 @@
         // Init the scheduler state
         let mut scheduler = BandwidthScheduler {
             shard_layout,
-            shards_status: shard_status_by_index,
+            is_link_allowed_map,
             sender_budget,
             receiver_budget,
             link_allowances,
@@ -306,7 +319,6 @@
             self.receiver_budget.insert(receiver, self.params.max_shard_bandwidth);
         }
     }
-
     /// Give every link a fair amount of allowance at every height.
     fn increase_allowances(&mut self) {
         // In an ideal, fair world, every link would send the same amount of bandwidth.
@@ -379,7 +391,18 @@
     /// remaining unused bandwidth that could be granted on the links. This function distributes the
     /// remaining bandwidth over all the links in a fair manner to improve bandwidth utilization.
     fn distribute_remaining_bandwidth(&mut self) {
-        // TODO(bandwidth_scheduler) - will be added in a future PR
+        let remaining_bandwidth_grants =
+            super::distribute_remaining::distribute_remaining_bandwidth(
+                &self.sender_budget,
+                &self.receiver_budget,
+                &self.is_link_allowed_map,
+                &self.shard_layout,
+            );
+        for link in self.iter_links() {
+            if let Some(remaining_grant) = remaining_bandwidth_grants.get(&link) {
+                self.grant_more_bandwidth(&link, *remaining_grant);
+            }
+        }
     }
 
     /// Convert granted bandwidth from internal representation to the representation returned by scheduler.
@@ -459,20 +482,33 @@
         self.sender_budget.insert(link.sender, sender_budget - bandwidth);
         self.receiver_budget.insert(link.receiver, receiver_budget - bandwidth);
         self.decrease_allowance(link, bandwidth);
+        self.grant_more_bandwidth(link, bandwidth);
+
+        TryGrantOutcome::Granted
+    }
+    /// Add new granted bandwidth to the link. Doesn't adjust allowance or budgets.
+    fn grant_more_bandwidth(&mut self, link: &ShardLink, bandwidth: Bandwidth) {
         let current_granted = self.granted_bandwidth.get(link).copied().unwrap_or(0);
         let new_granted = current_granted.checked_add(bandwidth).unwrap_or_else(|| {
             tracing::warn!(target: "runtime", "Granting bandwidth on link {:?} would overflow, this is unexpected. Granting max bandwidth instead", link);
             Bandwidth::MAX
         });
         self.granted_bandwidth.insert(*link, new_granted);
 
-        TryGrantOutcome::Granted
+    }
+
+    fn is_link_allowed(&self, link: &ShardLink) -> bool {
+        *self.is_link_allowed_map.get(link).unwrap_or(&false)
     }
 
     /// Decide if it's allowed to send receipts on the link, based on shard statuses.
     /// Makes sure that receipts are not sent to fully congested shards or shards with missing chunks.
-    fn is_link_allowed(&self, link: &ShardLink) -> bool {
-        let Some(receiver_status) = self.shards_status.get(&link.receiver) else {
+    fn calculate_is_link_allowed(
+        sender_index: ShardIndex,
+        receiver_index: ShardIndex,
+        shards_status: &ShardIndexMap<ShardStatus>,
+    ) -> bool {
+        let Some(receiver_status) = shards_status.get(&receiver_index) else {
             // Receiver shard status unknown - don't send anything on the link, just to be safe.
             return false;
         };
@@ -483,7 +519,7 @@
             return false;
         }
 
-        let sender_status_opt = self.shards_status.get(&link.sender);
+        let sender_status_opt = shards_status.get(&sender_index);
         if let Some(sender_status) = sender_status_opt {
             if sender_status.last_chunk_missing {
                 // The chunk on sender's shard is missing. Don't grant any bandwidth on links from a shard
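Condensed, the per-link rule assembled by `calculate_is_link_allowed` (summarizing this hunk and the congestion hunk just below) reads: unknown receiver status denies the link; a receiver or a sender whose last chunk is missing denies it; a fully congested receiver admits only its designated `allowed_sender_shard_index`; everything else is allowed.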
@@ -497,7 +533,7 @@
 
         // Only the "allowed shard" is allowed to send receipts to a fully congested shard.
         if receiver_status.is_fully_congested {
-            if Some(link.sender) == receiver_status.allowed_sender_shard_index {
+            if Some(sender_index) == receiver_status.allowed_sender_shard_index {
                 return true;
             } else {
                 return false;
@@ -644,6 +680,10 @@ impl<T> ShardIndexMap<T> {
         self.data[*index].as_ref()
     }
 
+    pub fn get_mut(&mut self, index: &ShardIndex) -> Option<&mut T> {
+        self.data[*index].as_mut()
+    }
+
     pub fn insert(&mut self, index: ShardIndex, value: T) {
         self.data[index] = Some(value);
     }
@@ -678,6 +718,11 @@ impl<T> ShardLinkMap<T> {
         self.data[data_index] = Some(value);
     }
 
+    #[cfg(test)]
+    pub fn num_indexes(&self) -> usize {
+        self.num_indexes
+    }
+
     fn data_index_for_link(&self, link: &ShardLink) -> usize {
         debug_assert!(
             link.sender < self.num_indexes,
@@ -797,14 +842,13 @@ mod tests {
     /// Run with:
     /// cargo test -p node-runtime --release test_scheduler_worst_case_performance -- --nocapture
     ///
-    /// Example output on an n2d-standard-8 GCP VM with AMD EPYC 7B13 CPU:
-    /// Running scheduler with 6 shards: 0.10 ms
-    /// Running scheduler with 10 shards: 0.16 ms
-    /// Running scheduler with 32 shards: 1.76 ms
-    /// Running scheduler with 64 shards: 5.74 ms
-    /// Running scheduler with 128 shards: 23.63 ms
-    /// Running scheduler with 256 shards: 93.15 ms
-    /// Running scheduler with 512 shards: 371.76 ms
+    /// Running scheduler with 6 shards: 0.13 ms
+    /// Running scheduler with 10 shards: 0.19 ms
+    /// Running scheduler with 32 shards: 1.85 ms
+    /// Running scheduler with 64 shards: 5.80 ms
+    /// Running scheduler with 128 shards: 23.98 ms
+    /// Running scheduler with 256 shards: 97.44 ms
+    /// Running scheduler with 512 shards: 385.97 ms
     #[test]
    fn test_scheduler_worst_case_performance() {
         for num_shards in [6, 10, 32, 64, 128, 256, 512] {
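Assuming the new timings were collected on hardware comparable to the deleted reference machine (the VM description line was dropped), the added remaining-bandwidth pass costs only a few percent of scheduler runtime, e.g. 371.76 ms to 385.97 ms at 512 shards, roughly +4%.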
diff --git a/runtime/runtime/src/bandwidth_scheduler/simulator.rs b/runtime/runtime/src/bandwidth_scheduler/simulator.rs
index 4c6d19f248f..36875b103aa 100644
--- a/runtime/runtime/src/bandwidth_scheduler/simulator.rs
+++ b/runtime/runtime/src/bandwidth_scheduler/simulator.rs
@@ -560,7 +560,7 @@ fn test_bandwidth_scheduler_simulator_small_vs_big() {
         .build();
     let summary = run_scenario(scenario);
     assert!(summary.bandwidth_utilization > 0.90); // 90% utilization
-    assert!(summary.link_imbalance_ratio < 1.05); // < 5% difference on links
+    assert!(summary.link_imbalance_ratio < 1.06); // < 6% difference on links
     assert!(summary.worst_link_estimation_ratio > 0.90); // 90% of estimated link throughput
     assert!(summary.max_incoming <= summary.max_shard_bandwidth); // Incoming max_shard_bandwidth is respected
     assert!(summary.max_outgoing <= summary.max_shard_bandwidth); // Outgoing max_shard_bandwidth is respected
diff --git a/test-utils/testlib/src/bandwidth_scheduler.rs b/test-utils/testlib/src/bandwidth_scheduler.rs
index 508ce67963a..79b689cfaa4 100644
--- a/test-utils/testlib/src/bandwidth_scheduler.rs
+++ b/test-utils/testlib/src/bandwidth_scheduler.rs
@@ -360,8 +360,9 @@ impl TestBandwidthStats {
         let link_imbalance_ratio =
             max_sent_on_link.as_u64() as f64 / min_sent_on_link.as_u64() as f64;
 
+        let max_budget = vec![self.scheduler_params.max_shard_bandwidth; 1000];
         let estimated_link_throughputs =
-            estimate_link_throughputs(active_links, self.scheduler_params.max_shard_bandwidth);
+            estimate_link_throughputs(active_links, &max_budget, &max_budget);
 
         let mut estimated_throughput = ByteSize::b(1);
         for (_link, link_throughput) in &estimated_link_throughputs {
@@ -506,28 +507,35 @@ impl std::fmt::Display for TestSummary {
 /// Ideally this would be done with some sort of network flow algorithm, but for now this will do.
 /// I guess granting a bit on all links is like a poor man's Ford-Fulkerson.
 /// TODO(bandwidth_scheduler) - make this better.
-fn estimate_link_throughputs(
+pub fn estimate_link_throughputs(
     active_links: &BTreeSet<(ShardIndex, ShardIndex)>,
-    max_shard_bandwidth: Bandwidth,
+    sender_budgets: &[Bandwidth],
+    receiver_budgets: &[Bandwidth],
 ) -> BTreeMap<(ShardIndex, ShardIndex), ByteSize> {
     if active_links.is_empty() {
         return BTreeMap::new();
     }
 
-    let max_shard_bandwidth = ByteSize::b(max_shard_bandwidth);
     let max_index = active_links.iter().map(|(a, b)| std::cmp::max(*a, *b)).max().unwrap();
     let num_shards = max_index + 1;
-    let mut sender_granted = vec![ByteSize::b(0); num_shards];
-    let mut receiver_granted = vec![ByteSize::b(0); num_shards];
-    let mut link_granted = vec![vec![ByteSize::b(0); num_shards]; num_shards];
+    let min_nonzero_budget = sender_budgets
+        .iter()
+        .chain(receiver_budgets.iter())
+        .filter(|b| **b > 0)
+        .min()
+        .unwrap_or(&0);
+    let single_increase = std::cmp::max(1, min_nonzero_budget / num_shards as u64);
+
+    let mut sender_granted = vec![0; num_shards];
+    let mut receiver_granted = vec![0; num_shards];
+    let mut link_granted = vec![vec![0; num_shards]; num_shards];
 
-    let single_increase = ByteSize::b(max_shard_bandwidth.as_u64() / num_shards as u64 / 4);
     let mut links: Vec<(ShardIndex, ShardIndex)> = active_links.iter().copied().collect();
     while !links.is_empty() {
         let mut next_links = Vec::new();
         for link in links {
-            if sender_granted[link.0] + single_increase <= max_shard_bandwidth
-                && receiver_granted[link.1] + single_increase <= max_shard_bandwidth
+            if sender_granted[link.0] + single_increase <= sender_budgets[link.0]
+                && receiver_granted[link.1] + single_increase <= receiver_budgets[link.1]
             {
                 sender_granted[link.0] += single_increase;
                 receiver_granted[link.1] += single_increase;
@@ -540,7 +548,7 @@ fn estimate_link_throughputs(
 
     let mut res = BTreeMap::new();
     for link in active_links {
-        res.insert(*link, link_granted[link.0][link.1]);
+        res.insert(*link, ByteSize::b(link_granted[link.0][link.1]));
    }
     res
 }
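The stats helper above passes `vec![max_shard_bandwidth; 1000]` for both budget slices to emulate the old single-budget behavior for any realistic shard count. For reference, a hypothetical direct call to the generalized estimator (assuming `ShardIndex = usize`, as in `near_primitives`; the link set and budgets are made up):

```rust
use std::collections::BTreeSet;

use testlib::bandwidth_scheduler::estimate_link_throughputs;

fn main() {
    // Two shards: shard 0 can send to shard 1; shard 1 can send to both.
    let active_links: BTreeSet<(usize, usize)> = [(0, 1), (1, 0), (1, 1)].into_iter().collect();
    // Per-shard budget slices replace the old single `max_shard_bandwidth` argument.
    let estimates = estimate_link_throughputs(&active_links, &[100, 300], &[200, 200]);
    for ((sender, receiver), throughput) in &estimates {
        println!("{sender} -> {receiver}: {throughput}");
    }
}
```

With these inputs, `single_increase` becomes `max(1, 100 / 2) = 50`, and the round-robin loop keeps granting 50-byte increments on each link until a sender or receiver budget would be exceeded.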