From b1974353666646f67bb33670563904c13fe14ffe Mon Sep 17 00:00:00 2001 From: Sergey Timoshin Date: Mon, 27 Jan 2025 19:09:08 +0100 Subject: [PATCH] feat: add rate limiter for RPC calls --- Cargo.lock | 73 +++ forester/src/cli.rs | 6 + forester/src/config.rs | 8 + forester/src/lib.rs | 3 + forester/src/main.rs | 21 +- forester/tests/batched_address_test.rs | 2 + forester/tests/batched_state_test.rs | 2 + forester/tests/e2e_test.rs | 7 + forester/tests/priority_fee_test.rs | 1 + sdk-libs/client/Cargo.toml | 5 +- sdk-libs/client/src/indexer/photon_indexer.rs | 268 ++++---- sdk-libs/client/src/lib.rs | 1 + .../client/src/photon_rpc/photon_client.rs | 580 ++++++++++-------- sdk-libs/client/src/rate_limiter.rs | 199 ++++++ sdk-libs/client/src/rpc/rpc_connection.rs | 13 +- sdk-libs/client/src/rpc/solana_rpc.rs | 47 +- sdk-libs/client/src/rpc_pool.rs | 22 +- sdk-libs/program-test/src/test_env.rs | 10 +- sdk-libs/program-test/src/test_rpc.rs | 10 + 19 files changed, 887 insertions(+), 391 deletions(-) create mode 100644 sdk-libs/client/src/rate_limiter.rs diff --git a/Cargo.lock b/Cargo.lock index b5b5e8dfa7..c6a7661e85 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2092,6 +2092,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.31" @@ -2174,6 +2180,27 @@ dependencies = [ "scroll", ] +[[package]] +name = "governor" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "842dc78579ce01e6a1576ad896edc92fca002dd60c9c3746b7fc2bec6fb429d0" +dependencies = [ + "cfg-if", + "dashmap 6.1.0", + "futures-sink", + "futures-timer", + "futures-util", + "no-std-compat", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.8.5", + "smallvec", + "spinning_top", +] + [[package]] name = "groth16-solana" version = "0.0.3" @@ -2911,6 +2938,7 @@ dependencies = [ "async-trait", "bb8", "borsh 0.10.3", + "governor", "light-compressed-token", "light-concurrent-merkle-tree", "light-hasher", @@ -3546,6 +3574,12 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "no-std-compat" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b93853da6d84c2e3c7d730d6473e8817692dd89be387eb01b94d7f108ecb5b8c" + [[package]] name = "nom" version = "7.1.3" @@ -3556,6 +3590,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "normalize-line-endings" version = "0.3.0" @@ -4241,6 +4281,21 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "quanta" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bd1fe6824cea6538803de3ff1bc0cf3949024db3d43c9643024bfb33a807c0e" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi 0.11.0+wasi-snapshot-preview1", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.10.2" @@ -4384,6 +4439,15 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "raw-cpuid" +version = "11.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" +dependencies = [ + "bitflags 2.6.0", +] + [[package]] name = "rayon" version = "1.10.0" @@ -6254,6 +6318,15 @@ version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.5.4" diff --git a/forester/src/cli.rs b/forester/src/cli.rs index b7fe433396..5c12ee39ed 100644 --- a/forester/src/cli.rs +++ b/forester/src/cli.rs @@ -111,6 +111,12 @@ pub struct StartArgs { default_value = "28807" )] pub address_queue_processing_length: u16, + + #[arg(long, env = "FORESTER_ENABLE_RPC_RATE_LIMIT", default_value = "false")] + pub rpc_rate_limit_enabled: bool, + + #[arg(long, env = "FORESTER_RPC_RATE_LIMIT", default_value = "100")] + pub rpc_rate_limit: u32, } #[derive(Parser, Clone, Debug)] diff --git a/forester/src/config.rs b/forester/src/config.rs index 460229f7d6..c9c670e3b0 100644 --- a/forester/src/config.rs +++ b/forester/src/config.rs @@ -40,6 +40,7 @@ pub struct ExternalServicesConfig { pub photon_api_key: Option, pub pushgateway_url: Option, pub pagerduty_routing_key: Option, + pub rpc_rate_limit: Option, } #[derive(Debug, Clone, Copy)] @@ -142,6 +143,11 @@ impl ForesterConfig { .clone() .ok_or(ConfigError::MissingField { field: "rpc_url" })?; + let mut rpc_rate_limit = None; + if args.rpc_rate_limit_enabled { + rpc_rate_limit = Some(args.rpc_rate_limit); + } + Ok(Self { external_services: ExternalServicesConfig { rpc_url, @@ -151,6 +157,7 @@ impl ForesterConfig { photon_api_key: args.photon_api_key.clone(), pushgateway_url: args.push_gateway_url.clone(), pagerduty_routing_key: args.pagerduty_routing_key.clone(), + rpc_rate_limit, }, retry_config: RetryConfig { max_retries: args.max_retries, @@ -203,6 +210,7 @@ impl ForesterConfig { photon_api_key: None, pushgateway_url: args.push_gateway_url.clone(), pagerduty_routing_key: args.pagerduty_routing_key.clone(), + rpc_rate_limit: None, }, retry_config: RetryConfig::default(), queue_config: QueueConfig::default(), diff --git a/forester/src/lib.rs b/forester/src/lib.rs index 59f0c2aceb..287887a9c6 100644 --- a/forester/src/lib.rs +++ b/forester/src/lib.rs @@ -28,6 +28,7 @@ pub use config::{ForesterConfig, ForesterEpochInfo}; use forester_utils::forester_epoch::{TreeAccounts, TreeType}; use light_client::{ indexer::Indexer, + rate_limiter::RateLimiter, rpc::{RpcConnection, SolanaRpcConnection}, rpc_pool::SolanaRpcPool, }; @@ -83,6 +84,7 @@ pub async fn run_queue_info( pub async fn run_pipeline + IndexerType>( config: Arc, + rate_limiter: Option, indexer: Arc>, shutdown: oneshot::Receiver<()>, work_report_sender: mpsc::Sender, @@ -91,6 +93,7 @@ pub async fn run_pipeline + IndexerType>( config.external_services.rpc_url.to_string(), CommitmentConfig::confirmed(), config.general_config.rpc_pool_size as u32, + rate_limiter.clone(), ) .await?; diff --git a/forester/src/main.rs b/forester/src/main.rs index 0b99486472..ea5f4b4462 100644 --- a/forester/src/main.rs +++ b/forester/src/main.rs @@ -12,6 +12,7 @@ use forester::{ }; use light_client::{ indexer::photon_indexer::PhotonIndexer, + rate_limiter::RateLimiter, rpc::{RpcConnection, SolanaRpcConnection}, }; use tokio::{ @@ -50,15 +51,31 @@ async fn main() -> Result<(), ForesterError> { } }); - let indexer_rpc = + let mut rate_limiter = None; + if let Some(rate_limit) = config.external_services.rpc_rate_limit { + rate_limiter = Some(RateLimiter::new(rate_limit)); + } + + let mut indexer_rpc = SolanaRpcConnection::new(config.external_services.rpc_url.clone(), None); + if let Some(limiter) = &rate_limiter { + indexer_rpc.set_rate_limiter(limiter.clone()); + } + let indexer = Arc::new(tokio::sync::Mutex::new(PhotonIndexer::new( config.external_services.indexer_url.clone().unwrap(), config.external_services.photon_api_key.clone(), indexer_rpc, ))); - run_pipeline(config, indexer, shutdown_receiver, work_report_sender).await? + run_pipeline( + config, + rate_limiter, + indexer, + shutdown_receiver, + work_report_sender, + ) + .await? } Commands::Status(args) => { forester_status::fetch_forester_status(args).await; diff --git a/forester/tests/batched_address_test.rs b/forester/tests/batched_address_test.rs index d6e570796b..3a0c9a9501 100644 --- a/forester/tests/batched_address_test.rs +++ b/forester/tests/batched_address_test.rs @@ -59,6 +59,7 @@ async fn test_address_batched() { config.external_services.rpc_url.to_string(), CommitmentConfig::processed(), config.general_config.rpc_pool_size as u32, + None, ) .await .unwrap(); @@ -217,6 +218,7 @@ async fn test_address_batched() { let service_handle = tokio::spawn(run_pipeline( config.clone(), + None, Arc::new(Mutex::new(env.indexer)), shutdown_receiver, work_report_sender, diff --git a/forester/tests/batched_state_test.rs b/forester/tests/batched_state_test.rs index f0d5c48c61..18c309ec4b 100644 --- a/forester/tests/batched_state_test.rs +++ b/forester/tests/batched_state_test.rs @@ -57,6 +57,7 @@ async fn test_state_batched() { config.external_services.rpc_url.to_string(), CommitmentConfig::processed(), config.general_config.rpc_pool_size as u32, + None, ) .await .unwrap(); @@ -203,6 +204,7 @@ async fn test_state_batched() { let service_handle = tokio::spawn(run_pipeline( Arc::from(config.clone()), + None, Arc::new(Mutex::new(e2e_env.indexer)), shutdown_receiver, work_report_sender, diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index 3c7e81c41f..be868adb37 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -65,6 +65,7 @@ async fn test_epoch_monitor_with_test_indexer_and_1_forester() { config.external_services.rpc_url.to_string(), CommitmentConfig::confirmed(), config.general_config.rpc_pool_size as u32, + None, ) .await .unwrap(); @@ -174,6 +175,7 @@ async fn test_epoch_monitor_with_test_indexer_and_1_forester() { // Run the forester as pipeline let service_handle = tokio::spawn(run_pipeline( config.clone(), + None, Arc::new(Mutex::new(env.indexer)), shutdown_receiver, work_report_sender, @@ -311,6 +313,7 @@ async fn test_epoch_monitor_with_2_foresters() { config1.external_services.rpc_url.to_string(), CommitmentConfig::confirmed(), config1.general_config.rpc_pool_size as u32, + None, ) .await .unwrap(); @@ -463,12 +466,14 @@ async fn test_epoch_monitor_with_2_foresters() { let service_handle1 = tokio::spawn(run_pipeline( config1.clone(), + None, indexer.clone(), shutdown_receiver1, work_report_sender1, )); let service_handle2 = tokio::spawn(run_pipeline( config2.clone(), + None, indexer, shutdown_receiver2, work_report_sender2, @@ -656,6 +661,7 @@ async fn test_epoch_double_registration() { config.external_services.rpc_url.to_string(), CommitmentConfig::confirmed(), config.general_config.rpc_pool_size as u32, + None, ) .await .unwrap(); @@ -715,6 +721,7 @@ async fn test_epoch_double_registration() { // Run the forester pipeline let service_handle = tokio::spawn(run_pipeline( config.clone(), + None, indexer.clone(), shutdown_receiver, work_report_sender.clone(), diff --git a/forester/tests/priority_fee_test.rs b/forester/tests/priority_fee_test.rs index 5591c64b3e..b67a6b0d5f 100644 --- a/forester/tests/priority_fee_test.rs +++ b/forester/tests/priority_fee_test.rs @@ -11,6 +11,7 @@ use crate::test_utils::init; mod test_utils; #[tokio::test] +#[ignore] async fn test_priority_fee_request() { dotenvy::dotenv().ok(); diff --git a/sdk-libs/client/Cargo.toml b/sdk-libs/client/Cargo.toml index 66dd9df838..513255e4c9 100644 --- a/sdk-libs/client/Cargo.toml +++ b/sdk-libs/client/Cargo.toml @@ -34,6 +34,9 @@ num-bigint = { workspace = true } num-traits = { workspace = true } reqwest = { workspace = true } +governor = "0.8.0" + + [dev-dependencies] light-test-utils = { workspace = true, features=["devenv"]} light-program-test = { workspace = true } @@ -41,4 +44,4 @@ light-system-program = { workspace = true } light-compressed-token = { workspace = true } spl-token = { workspace = true } rand = { workspace = true } -light-utils = { workspace = true } \ No newline at end of file +light-utils = { workspace = true } diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index 8ed737ed53..e660c140b7 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -14,6 +14,7 @@ use crate::{ AddressMerkleTreeBundle, Indexer, IndexerError, LeafIndexInfo, MerkleProof, NewAddressProofWithContext, ProofOfLeaf, }, + rate_limiter::{RateLimiter, UseRateLimiter}, rpc::RpcConnection, }; @@ -21,6 +22,17 @@ pub struct PhotonIndexer { configuration: Configuration, #[allow(dead_code)] rpc: R, + rate_limiter: Option, +} + +impl UseRateLimiter for PhotonIndexer { + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter) { + self.rate_limiter = Some(rate_limiter); + } + + fn rate_limiter(&self) -> Option<&RateLimiter> { + self.rate_limiter.as_ref() + } } impl PhotonIndexer { @@ -34,7 +46,22 @@ impl PhotonIndexer { ..Default::default() }; - PhotonIndexer { configuration, rpc } + PhotonIndexer { + configuration, + rpc, + rate_limiter: None, + } + } + + async fn rate_limited_request(&self, operation: F) -> Result + where + F: FnOnce() -> Fut, + Fut: std::future::Future>, + { + if let Some(limiter) = &self.rate_limiter { + limiter.acquire_with_wait().await; + } + operation().await } } @@ -77,81 +104,87 @@ impl Indexer for PhotonIndexer { &self, hashes: Vec, ) -> Result, IndexerError> { - let request: photon_api::models::GetMultipleCompressedAccountProofsPostRequest = - photon_api::models::GetMultipleCompressedAccountProofsPostRequest { - params: hashes, - ..Default::default() - }; + self.rate_limited_request(|| async { + let request: photon_api::models::GetMultipleCompressedAccountProofsPostRequest = + photon_api::models::GetMultipleCompressedAccountProofsPostRequest { + params: hashes, + ..Default::default() + }; - let result = photon_api::apis::default_api::get_multiple_compressed_account_proofs_post( - &self.configuration, - request, - ) - .await; - - match result { - Ok(response) => { - match response.result { - Some(result) => { - let proofs = result - .value - .iter() - .map(|x| { - let mut proof_result_value = x.proof.clone(); - proof_result_value.truncate(proof_result_value.len() - 10); // Remove canopy - let proof: Vec<[u8; 32]> = - proof_result_value.iter().map(|x| decode_hash(x)).collect(); - MerkleProof { - hash: x.hash.clone(), - leaf_index: x.leaf_index, - merkle_tree: x.merkle_tree.clone(), - proof, - root_seq: x.root_seq, - } - }) - .collect(); + let result = + photon_api::apis::default_api::get_multiple_compressed_account_proofs_post( + &self.configuration, + request, + ) + .await; - Ok(proofs) - } - None => { - let error = response.error.unwrap(); - Err(IndexerError::Custom(error.message.unwrap())) + match result { + Ok(response) => { + match response.result { + Some(result) => { + let proofs = result + .value + .iter() + .map(|x| { + let mut proof_result_value = x.proof.clone(); + proof_result_value.truncate(proof_result_value.len() - 10); // Remove canopy + let proof: Vec<[u8; 32]> = + proof_result_value.iter().map(|x| decode_hash(x)).collect(); + MerkleProof { + hash: x.hash.clone(), + leaf_index: x.leaf_index, + merkle_tree: x.merkle_tree.clone(), + proof, + root_seq: x.root_seq, + } + }) + .collect(); + + Ok(proofs) + } + None => { + let error = response.error.unwrap(); + Err(IndexerError::Custom(error.message.unwrap())) + } } } + Err(e) => Err(IndexerError::Custom(e.to_string())), } - Err(e) => Err(IndexerError::Custom(e.to_string())), - } + }) + .await } - async fn get_compressed_accounts_by_owner( &self, owner: &Pubkey, ) -> Result, IndexerError> { - let request = photon_api::models::GetCompressedAccountsByOwnerPostRequest { - params: Box::from(GetCompressedAccountsByOwnerPostRequestParams { - cursor: None, - data_slice: None, - filters: None, - limit: None, - owner: owner.to_string(), - }), - ..Default::default() - }; + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressedAccountsByOwnerPostRequest { + params: Box::from(GetCompressedAccountsByOwnerPostRequestParams { + cursor: None, + data_slice: None, + filters: None, + limit: None, + owner: owner.to_string(), + }), + ..Default::default() + }; - let result = photon_api::apis::default_api::get_compressed_accounts_by_owner_post( - &self.configuration, - request, - ) - .await - .unwrap(); + let result = photon_api::apis::default_api::get_compressed_accounts_by_owner_post( + &self.configuration, + request, + ) + .await + .unwrap(); - let accs = result.result.unwrap().value; - let mut hashes = Vec::new(); - for acc in accs.items { - hashes.push(acc.hash); - } + let accs = result.result.unwrap().value; + let mut hashes = Vec::new(); + for acc in accs.items { + hashes.push(acc.hash); + } - Ok(hashes) + Ok(hashes) + }) + .await } async fn get_multiple_new_address_proofs( @@ -159,63 +192,66 @@ impl Indexer for PhotonIndexer { merkle_tree_pubkey: [u8; 32], addresses: Vec<[u8; 32]>, ) -> Result>, IndexerError> { - let params: Vec = addresses - .iter() - .map(|x| AddressWithTree { - address: bs58::encode(x).into_string(), - tree: bs58::encode(&merkle_tree_pubkey).into_string(), - }) - .collect(); - - let request = photon_api::models::GetMultipleNewAddressProofsV2PostRequest { - params, - ..Default::default() - }; + self.rate_limited_request(|| async { + let params: Vec = addresses + .iter() + .map(|x| AddressWithTree { + address: bs58::encode(x).into_string(), + tree: bs58::encode(&merkle_tree_pubkey).into_string(), + }) + .collect(); + + let request = photon_api::models::GetMultipleNewAddressProofsV2PostRequest { + params, + ..Default::default() + }; - let result = photon_api::apis::default_api::get_multiple_new_address_proofs_v2_post( - &self.configuration, - request, - ) - .await; + let result = photon_api::apis::default_api::get_multiple_new_address_proofs_v2_post( + &self.configuration, + request, + ) + .await; - if result.is_err() { - return Err(IndexerError::Custom(result.err().unwrap().to_string())); - } + if result.is_err() { + return Err(IndexerError::Custom(result.err().unwrap().to_string())); + } - let photon_proofs = result.unwrap().result.unwrap().value; - // net height 16 = height(26) - canopy(10) - let mut proofs: Vec> = Vec::new(); - for photon_proof in photon_proofs { - let tree_pubkey = decode_hash(&photon_proof.merkle_tree); - let low_address_value = decode_hash(&photon_proof.lower_range_address); - let next_address_value = decode_hash(&photon_proof.higher_range_address); - let proof = NewAddressProofWithContext { - merkle_tree: tree_pubkey, - low_address_index: photon_proof.low_element_leaf_index as u64, - low_address_value, - low_address_next_index: photon_proof.next_index as u64, - low_address_next_value: next_address_value, - low_address_proof: { - let mut proof_vec: Vec<[u8; 32]> = photon_proof - .proof - .iter() - .map(|x: &String| decode_hash(x)) - .collect(); - proof_vec.truncate(proof_vec.len() - 10); // Remove canopy - let mut proof_arr = [[0u8; 32]; 16]; - proof_arr.copy_from_slice(&proof_vec); - proof_arr - }, - root: decode_hash(&photon_proof.root), - root_seq: photon_proof.root_seq, - new_low_element: None, - new_element: None, - new_element_next_value: None, - }; - proofs.push(proof); - } + let photon_proofs = result.unwrap().result.unwrap().value; + // net height 16 = height(26) - canopy(10) + let mut proofs: Vec> = Vec::new(); + for photon_proof in photon_proofs { + let tree_pubkey = decode_hash(&photon_proof.merkle_tree); + let low_address_value = decode_hash(&photon_proof.lower_range_address); + let next_address_value = decode_hash(&photon_proof.higher_range_address); + let proof = NewAddressProofWithContext { + merkle_tree: tree_pubkey, + low_address_index: photon_proof.low_element_leaf_index as u64, + low_address_value, + low_address_next_index: photon_proof.next_index as u64, + low_address_next_value: next_address_value, + low_address_proof: { + let mut proof_vec: Vec<[u8; 32]> = photon_proof + .proof + .iter() + .map(|x: &String| decode_hash(x)) + .collect(); + proof_vec.truncate(proof_vec.len() - 10); // Remove canopy + let mut proof_arr = [[0u8; 32]; 16]; + proof_arr.copy_from_slice(&proof_vec); + proof_arr + }, + root: decode_hash(&photon_proof.root), + root_seq: photon_proof.root_seq, + new_low_element: None, + new_element: None, + new_element_next_value: None, + }; + proofs.push(proof); + } - Ok(proofs) + Ok(proofs) + }) + .await } async fn get_multiple_new_address_proofs_h40( diff --git a/sdk-libs/client/src/lib.rs b/sdk-libs/client/src/lib.rs index f7cfda17be..2e9e53ac73 100644 --- a/sdk-libs/client/src/lib.rs +++ b/sdk-libs/client/src/lib.rs @@ -1,5 +1,6 @@ pub mod indexer; pub mod photon_rpc; +pub mod rate_limiter; pub mod rpc; pub mod rpc_pool; pub mod transaction_params; diff --git a/sdk-libs/client/src/photon_rpc/photon_client.rs b/sdk-libs/client/src/photon_rpc/photon_client.rs index 65620a3e01..dc9ca26d7e 100644 --- a/sdk-libs/client/src/photon_rpc/photon_client.rs +++ b/sdk-libs/client/src/photon_rpc/photon_client.rs @@ -10,18 +10,35 @@ use super::{ Address, Base58Conversions, CompressedAccountResponse, Hash, PhotonClientError, TokenAccountBalanceResponse, }; -use crate::indexer::{MerkleProof, NewAddressProofWithContext}; +use crate::{ + indexer::{MerkleProof, NewAddressProofWithContext}, + rate_limiter::{RateLimiter, UseRateLimiter}, +}; #[derive(Debug)] pub struct PhotonClient { config: Configuration, + rate_limiter: Option, +} + +impl UseRateLimiter for PhotonClient { + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter) { + self.rate_limiter = Some(rate_limiter); + } + + fn rate_limiter(&self) -> Option<&RateLimiter> { + self.rate_limiter.as_ref() + } } impl PhotonClient { pub fn new(url: String) -> Self { let mut config = Configuration::new(); config.base_path = url; - PhotonClient { config } + PhotonClient { + config, + rate_limiter: None, + } } pub fn new_with_auth(url: String, api_key: String) -> Self { @@ -31,83 +48,104 @@ impl PhotonClient { key: api_key, prefix: None, }); - PhotonClient { config } + PhotonClient { + config, + rate_limiter: None, + } + } + + async fn rate_limited_request(&self, operation: F) -> Result + where + F: FnOnce() -> Fut, + Fut: std::future::Future>, + { + if let Some(limiter) = &self.rate_limiter { + limiter.acquire_with_wait().await; + } + operation().await } pub async fn get_multiple_compressed_account_proofs( &self, hashes: Vec, ) -> Result, PhotonClientError> { - let request = photon_api::models::GetMultipleCompressedAccountProofsPostRequest { - params: hashes.iter().map(|h| h.to_base58()).collect(), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_multiple_compressed_account_proofs_post( - &self.config, - request, - ) - .await?; - - match result.result { - Some(result) => { - let proofs = result - .value - .iter() - .map(|x| { - let mut proof_result_value = x.proof.clone(); - proof_result_value.truncate(proof_result_value.len() - 10); - let proof = proof_result_value - .iter() - .map(|x| Hash::from_base58(x).unwrap()) - .collect(); - MerkleProof { - hash: x.hash.clone(), - leaf_index: x.leaf_index, - merkle_tree: x.merkle_tree.clone(), - proof, - root_seq: x.root_seq, - } - }) - .collect(); - Ok(proofs) + self.rate_limited_request(|| async { + let request = photon_api::models::GetMultipleCompressedAccountProofsPostRequest { + params: hashes.iter().map(|h| h.to_base58()).collect(), + ..Default::default() + }; + + let result = + photon_api::apis::default_api::get_multiple_compressed_account_proofs_post( + &self.config, + request, + ) + .await?; + + match result.result { + Some(result) => { + let proofs = result + .value + .iter() + .map(|x| { + let mut proof_result_value = x.proof.clone(); + proof_result_value.truncate(proof_result_value.len() - 10); + let proof = proof_result_value + .iter() + .map(|x| Hash::from_base58(x).unwrap()) + .collect(); + MerkleProof { + hash: x.hash.clone(), + leaf_index: x.leaf_index, + merkle_tree: x.merkle_tree.clone(), + proof, + root_seq: x.root_seq, + } + }) + .collect(); + Ok(proofs) + } + None => Err(PhotonClientError::DecodeError("Missing result".to_string())), } - None => Err(PhotonClientError::DecodeError("Missing result".to_string())), - } + }) + .await } pub async fn get_rpc_compressed_accounts_by_owner( &self, owner: &Pubkey, ) -> Result, PhotonClientError> { - let request = photon_api::models::GetCompressedAccountsByOwnerPostRequest { - params: Box::from(GetCompressedAccountsByOwnerPostRequestParams { - cursor: None, - data_slice: None, - filters: None, - limit: None, - owner: owner.to_string(), - }), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compressed_accounts_by_owner_post( - &self.config, - request, - ) - .await - .unwrap(); + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressedAccountsByOwnerPostRequest { + params: Box::from(GetCompressedAccountsByOwnerPostRequestParams { + cursor: None, + data_slice: None, + filters: None, + limit: None, + owner: owner.to_string(), + }), + ..Default::default() + }; - let accs = result.result.unwrap().value; - let mut hashes = Vec::new(); - for acc in accs.items { - hashes.push(acc.hash); - } + let result = photon_api::apis::default_api::get_compressed_accounts_by_owner_post( + &self.config, + request, + ) + .await + .unwrap(); + + let accs = result.result.unwrap().value; + let mut hashes = Vec::new(); + for acc in accs.items { + hashes.push(acc.hash); + } - Ok(hashes - .iter() - .map(|x| Hash::from_base58(x).unwrap()) - .collect()) + Ok(hashes + .iter() + .map(|x| Hash::from_base58(x).unwrap()) + .collect()) + }) + .await } pub async fn get_multiple_new_address_proofs( @@ -115,64 +153,69 @@ impl PhotonClient { merkle_tree_pubkey: Pubkey, addresses: Vec
, ) -> Result>, PhotonClientError> { - let params: Vec = addresses - .iter() - .map(|x| photon_api::models::AddressWithTree { - address: bs58::encode(x).into_string(), - tree: bs58::encode(&merkle_tree_pubkey).into_string(), - }) - .collect(); - - let request = photon_api::models::GetMultipleNewAddressProofsV2PostRequest { - params, - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_multiple_new_address_proofs_v2_post( - &self.config, - request, - ) - .await; - - if result.is_err() { - return Err(PhotonClientError::GetMultipleNewAddressProofsError( - result.err().unwrap(), - )); - } - - let photon_proofs = result.unwrap().result.unwrap().value; - let mut proofs: Vec> = Vec::new(); - for photon_proof in photon_proofs { - let tree_pubkey = Hash::from_base58(&photon_proof.merkle_tree).unwrap(); - let low_address_value = Hash::from_base58(&photon_proof.lower_range_address).unwrap(); - let next_address_value = Hash::from_base58(&photon_proof.higher_range_address).unwrap(); - let proof = NewAddressProofWithContext { - merkle_tree: tree_pubkey, - low_address_index: photon_proof.low_element_leaf_index as u64, - low_address_value, - low_address_next_index: photon_proof.next_index as u64, - low_address_next_value: next_address_value, - low_address_proof: { - let mut proof_vec: Vec<[u8; 32]> = photon_proof - .proof - .iter() - .map(|x: &String| Hash::from_base58(x).unwrap()) - .collect(); - proof_vec.truncate(proof_vec.len() - 10); // Remove canopy - let mut proof_arr = [[0u8; 32]; 16]; - proof_arr.copy_from_slice(&proof_vec); - proof_arr - }, - root: Hash::from_base58(&photon_proof.root).unwrap(), - root_seq: photon_proof.root_seq, - new_low_element: None, - new_element: None, - new_element_next_value: None, + self.rate_limited_request(|| async { + let params: Vec = addresses + .iter() + .map(|x| photon_api::models::AddressWithTree { + address: bs58::encode(x).into_string(), + tree: bs58::encode(&merkle_tree_pubkey).into_string(), + }) + .collect(); + + let request = photon_api::models::GetMultipleNewAddressProofsV2PostRequest { + params, + ..Default::default() }; - proofs.push(proof); - } - Ok(proofs) + let result = photon_api::apis::default_api::get_multiple_new_address_proofs_v2_post( + &self.config, + request, + ) + .await; + + if result.is_err() { + return Err(PhotonClientError::GetMultipleNewAddressProofsError( + result.err().unwrap(), + )); + } + + let photon_proofs = result.unwrap().result.unwrap().value; + let mut proofs: Vec> = Vec::new(); + for photon_proof in photon_proofs { + let tree_pubkey = Hash::from_base58(&photon_proof.merkle_tree).unwrap(); + let low_address_value = + Hash::from_base58(&photon_proof.lower_range_address).unwrap(); + let next_address_value = + Hash::from_base58(&photon_proof.higher_range_address).unwrap(); + let proof = NewAddressProofWithContext { + merkle_tree: tree_pubkey, + low_address_index: photon_proof.low_element_leaf_index as u64, + low_address_value, + low_address_next_index: photon_proof.next_index as u64, + low_address_next_value: next_address_value, + low_address_proof: { + let mut proof_vec: Vec<[u8; 32]> = photon_proof + .proof + .iter() + .map(|x: &String| Hash::from_base58(x).unwrap()) + .collect(); + proof_vec.truncate(proof_vec.len() - 10); // Remove canopy + let mut proof_arr = [[0u8; 32]; 16]; + proof_arr.copy_from_slice(&proof_vec); + proof_arr + }, + root: Hash::from_base58(&photon_proof.root).unwrap(), + root_seq: photon_proof.root_seq, + new_low_element: None, + new_element: None, + new_element_next_value: None, + }; + proofs.push(proof); + } + + Ok(proofs) + }) + .await } pub async fn get_validity_proof( @@ -180,31 +223,35 @@ impl PhotonClient { hashes: Vec, new_addresses_with_trees: Vec, ) -> Result { - let request = photon_api::models::GetValidityProofPostRequest { - params: Box::new(photon_api::models::GetValidityProofPostRequestParams { - hashes: Some(hashes.iter().map(|x| x.to_base58()).collect()), - new_addresses: None, - new_addresses_with_trees: Some( - new_addresses_with_trees - .iter() - .map(|x| photon_api::models::AddressWithTree { - address: x.address.to_base58(), - tree: x.tree.to_string(), - }) - .collect(), - ), - }), - ..Default::default() - }; + self.rate_limited_request(|| async { + let request = photon_api::models::GetValidityProofPostRequest { + params: Box::new(photon_api::models::GetValidityProofPostRequestParams { + hashes: Some(hashes.iter().map(|x| x.to_base58()).collect()), + new_addresses: None, + new_addresses_with_trees: Some( + new_addresses_with_trees + .iter() + .map(|x| photon_api::models::AddressWithTree { + address: x.address.to_base58(), + tree: x.tree.to_string(), + }) + .collect(), + ), + }), + ..Default::default() + }; - let result = photon_api::apis::default_api::get_validity_proof_post(&self.config, request) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + let result = + photon_api::apis::default_api::get_validity_proof_post(&self.config, request) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; - match result.result { - Some(result) => Ok(*result), - None => Err(PhotonClientError::DecodeError("Missing result".to_string())), - } + match result.result { + Some(result) => Ok(*result), + None => Err(PhotonClientError::DecodeError("Missing result".to_string())), + } + }) + .await } pub async fn get_compressed_account( @@ -212,18 +259,22 @@ impl PhotonClient { address: Option
, hash: Option, ) -> Result { - let params = self.build_account_params(address, hash)?; - let request = photon_api::models::GetCompressedAccountPostRequest { - params: Box::new(params), - ..Default::default() - }; - - let result = - photon_api::apis::default_api::get_compressed_account_post(&self.config, request) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let params = self.build_account_params(address, hash)?; + + let request = photon_api::models::GetCompressedAccountPostRequest { + params: Box::new(params), + ..Default::default() + }; - Self::handle_result(result.result).map(|r| CompressedAccountResponse::from(*r)) + let result = + photon_api::apis::default_api::get_compressed_account_post(&self.config, request) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + + Self::handle_result(result.result).map(|r| CompressedAccountResponse::from(*r)) + }) + .await } pub async fn get_compressed_token_accounts_by_owner( @@ -234,26 +285,30 @@ impl PhotonClient { photon_api::models::GetCompressedTokenAccountsByDelegatePost200ResponseResult, PhotonClientError, > { - let request = photon_api::models::GetCompressedTokenAccountsByOwnerPostRequest { - params: Box::new( - photon_api::models::GetCompressedTokenAccountsByOwnerPostRequestParams { - owner: owner.to_string(), - mint: mint.map(|x| Some(x.to_string())), - cursor: None, - limit: None, - }, - ), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compressed_token_accounts_by_owner_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressedTokenAccountsByOwnerPostRequest { + params: Box::new( + photon_api::models::GetCompressedTokenAccountsByOwnerPostRequestParams { + owner: owner.to_string(), + mint: mint.map(|x| Some(x.to_string())), + cursor: None, + limit: None, + }, + ), + ..Default::default() + }; + + let result = + photon_api::apis::default_api::get_compressed_token_accounts_by_owner_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; - Self::handle_result(result.result).map(|r| *r) + Self::handle_result(result.result).map(|r| *r) + }) + .await } pub async fn get_compressed_account_balance( @@ -261,20 +316,24 @@ impl PhotonClient { address: Option
, hash: Option, ) -> Result { - let params = self.build_account_params(address, hash)?; - let request = photon_api::models::GetCompressedAccountBalancePostRequest { - params: Box::new(params), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compressed_account_balance_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let params = self.build_account_params(address, hash)?; - Self::handle_result(result.result).map(|r| AccountBalanceResponse::from(*r)) + let request = photon_api::models::GetCompressedAccountBalancePostRequest { + params: Box::new(params), + ..Default::default() + }; + + let result = photon_api::apis::default_api::get_compressed_account_balance_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + + Self::handle_result(result.result).map(|r| AccountBalanceResponse::from(*r)) + }) + .await } pub async fn get_compressed_token_account_balance( @@ -282,22 +341,25 @@ impl PhotonClient { address: Option
, hash: Option, ) -> Result { - let request = photon_api::models::GetCompressedTokenAccountBalancePostRequest { - params: Box::new(photon_api::models::GetCompressedAccountPostRequestParams { - address: address.map(|x| Some(x.to_base58())), - hash: hash.map(|x| Some(x.to_base58())), - }), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compressed_token_account_balance_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressedTokenAccountBalancePostRequest { + params: Box::new(photon_api::models::GetCompressedAccountPostRequestParams { + address: address.map(|x| Some(x.to_base58())), + hash: hash.map(|x| Some(x.to_base58())), + }), + ..Default::default() + }; + + let result = photon_api::apis::default_api::get_compressed_token_account_balance_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; - Self::handle_result(result.result).map(|r| TokenAccountBalanceResponse::from(*r)) + Self::handle_result(result.result).map(|r| TokenAccountBalanceResponse::from(*r)) + }) + .await } pub async fn get_compressed_token_balances_by_owner( @@ -308,26 +370,30 @@ impl PhotonClient { photon_api::models::GetCompressedTokenBalancesByOwnerPost200ResponseResult, PhotonClientError, > { - let request = photon_api::models::GetCompressedTokenBalancesByOwnerPostRequest { - params: Box::new( - photon_api::models::GetCompressedTokenAccountsByOwnerPostRequestParams { - owner: owner.to_string(), - mint: mint.map(|x| Some(x.to_string())), - cursor: None, - limit: None, - }, - ), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compressed_token_balances_by_owner_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressedTokenBalancesByOwnerPostRequest { + params: Box::new( + photon_api::models::GetCompressedTokenAccountsByOwnerPostRequestParams { + owner: owner.to_string(), + mint: mint.map(|x| Some(x.to_string())), + cursor: None, + limit: None, + }, + ), + ..Default::default() + }; + + let result = + photon_api::apis::default_api::get_compressed_token_balances_by_owner_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; - Self::handle_result(result.result).map(|r| *r) + Self::handle_result(result.result).map(|r| *r) + }) + .await } pub async fn get_compression_signatures_for_account( @@ -337,23 +403,27 @@ impl PhotonClient { photon_api::models::GetCompressionSignaturesForAccountPost200ResponseResult, PhotonClientError, > { - let request = photon_api::models::GetCompressionSignaturesForAccountPostRequest { - params: Box::new( - photon_api::models::GetCompressedAccountProofPostRequestParams { - hash: hash.to_base58(), - }, - ), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_compression_signatures_for_account_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let request = photon_api::models::GetCompressionSignaturesForAccountPostRequest { + params: Box::new( + photon_api::models::GetCompressedAccountProofPostRequestParams { + hash: hash.to_base58(), + }, + ), + ..Default::default() + }; + + let result = + photon_api::apis::default_api::get_compression_signatures_for_account_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; - Self::handle_result(result.result).map(|r| *r) + Self::handle_result(result.result).map(|r| *r) + }) + .await } pub async fn get_multiple_compressed_accounts( @@ -361,24 +431,28 @@ impl PhotonClient { addresses: Option>, hashes: Option>, ) -> Result { - let request = photon_api::models::GetMultipleCompressedAccountsPostRequest { - params: Box::new( - photon_api::models::GetMultipleCompressedAccountsPostRequestParams { - addresses: addresses.map(|x| Some(x.iter().map(|x| x.to_base58()).collect())), - hashes: hashes.map(|x| Some(x.iter().map(|x| x.to_base58()).collect())), - }, - ), - ..Default::default() - }; - - let result = photon_api::apis::default_api::get_multiple_compressed_accounts_post( - &self.config, - request, - ) - .await - .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + self.rate_limited_request(|| async { + let request = photon_api::models::GetMultipleCompressedAccountsPostRequest { + params: Box::new( + photon_api::models::GetMultipleCompressedAccountsPostRequestParams { + addresses: addresses + .map(|x| Some(x.iter().map(|x| x.to_base58()).collect())), + hashes: hashes.map(|x| Some(x.iter().map(|x| x.to_base58()).collect())), + }, + ), + ..Default::default() + }; - Self::handle_result(result.result).map(|r| CompressedAccountsResponse::from(*r)) + let result = photon_api::apis::default_api::get_multiple_compressed_accounts_post( + &self.config, + request, + ) + .await + .map_err(|e| PhotonClientError::DecodeError(e.to_string()))?; + + Self::handle_result(result.result).map(|r| CompressedAccountsResponse::from(*r)) + }) + .await } fn handle_result(result: Option) -> Result { diff --git a/sdk-libs/client/src/rate_limiter.rs b/sdk-libs/client/src/rate_limiter.rs new file mode 100644 index 0000000000..75ba889778 --- /dev/null +++ b/sdk-libs/client/src/rate_limiter.rs @@ -0,0 +1,199 @@ +use std::{num::NonZeroU32, sync::Arc, time::Duration}; + +use governor::{ + clock::DefaultClock, + state::{InMemoryState, NotKeyed}, + Quota, RateLimiter as Governor, +}; +use thiserror::Error; + +pub trait UseRateLimiter { + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter); + fn rate_limiter(&self) -> Option<&RateLimiter>; +} + +#[derive(Error, Debug)] +pub enum RateLimiterError { + #[error("Rate limit exceeded")] + RateLimitExceeded, +} + +/// Shared rate limiter for RPC calls +#[derive(Clone, Debug)] +pub struct RateLimiter { + governor: Arc>, +} + +impl RateLimiter { + /// Create a new rate limiter with specified requests per second + pub fn new(requests_per_second: u32) -> Self { + // Create a quota that allows exactly one request per 1/requests_per_second seconds + let quota = Quota::with_period(Duration::from_secs_f64(1.0 / requests_per_second as f64)) + .unwrap() + .allow_burst(NonZeroU32::new(1).unwrap()); + RateLimiter { + governor: Arc::new(Governor::new( + quota, + InMemoryState::default(), + DefaultClock::default(), + )), + } + } + + /// Attempt to acquire permission to make a request + pub async fn acquire(&self) -> Result<(), RateLimiterError> { + match self.governor.check() { + Ok(()) => Ok(()), + Err(_) => Err(RateLimiterError::RateLimitExceeded), + } + } + + /// Wait until a request can be made and then make it + pub async fn acquire_with_wait(&self) { + // Ensure we wait until the next available slot and consume it + let _start = self.governor.until_ready().await; + // Add a small sleep to ensure proper spacing + tokio::time::sleep(Duration::from_millis(1)).await; + } +} + +/// Wrapper for RPC clients that enforces rate limits +pub struct RateLimitedClient { + inner: T, + rate_limiter: RateLimiter, +} + +impl RateLimitedClient { + pub fn new(inner: T, rate_limiter: RateLimiter) -> Self { + Self { + inner, + rate_limiter, + } + } + + /// Get reference to inner client + pub fn inner(&self) -> &T { + &self.inner + } + + /// Get mutable reference to inner client + pub fn inner_mut(&mut self) -> &mut T { + &mut self.inner + } + + /// Acquire rate limit permission before executing an operation + pub async fn execute<'a, F, Fut, R>(&'a self, f: F) -> Result + where + F: FnOnce(&'a T) -> Fut + 'a, + Fut: std::future::Future + 'a, + { + self.rate_limiter.acquire().await?; + Ok(f(&self.inner).await) + } + + /// Execute an operation, waiting if necessary to respect rate limits + pub async fn execute_with_wait<'a, F, Fut, R>(&'a self, f: F) -> R + where + F: FnOnce(&'a T) -> Fut + 'a, + Fut: std::future::Future + 'a, + { + self.rate_limiter.acquire_with_wait().await; + f(&self.inner).await + } +} + +#[cfg(test)] +mod tests { + use tokio::time::{Duration, Instant}; + + use super::*; + + #[tokio::test] + async fn test_rate_limiter_basic() { + let limiter = RateLimiter::new(10); + let mut successes = 0; + + // Try to make 20 requests immediately + for _ in 0..20 { + if limiter.acquire().await.is_ok() { + successes += 1; + } + } + + // Should allow approximately 10 requests + assert!(successes <= 11, "Allowed too many requests: {}", successes); + } + + #[tokio::test] + async fn test_rate_limited_client() { + struct MockClient; + impl MockClient { + async fn make_request(&self) -> u32 { + 42 + } + } + + let rate_limiter = RateLimiter::new(10); + let client = RateLimitedClient::new(MockClient, rate_limiter); + + let result = client + .execute(|c| async move { c.make_request().await }) + .await + .unwrap(); + assert_eq!(result, 42); + } + + #[tokio::test] + async fn test_rate_limiter_concurrent() { + let rate_limiter = RateLimiter::new(10); + let test_duration = Duration::from_secs(3); + let start_time = Instant::now(); + let mut total_successful = 0; + + while start_time.elapsed() < test_duration { + rate_limiter.acquire_with_wait().await; + total_successful += 1; + } + + let elapsed_secs = start_time.elapsed().as_secs_f64(); + let requests_per_sec = total_successful as f64 / elapsed_secs; + + println!("Total successful requests: {}", total_successful); + println!("Elapsed seconds: {:.2}", elapsed_secs); + println!("Requests per second: {:.2}", requests_per_sec); + + // Verify rate is close to our limit of 10 per second + assert!( + requests_per_sec <= 11.0, + "Rate should not exceed limit significantly: got {:.2} requests/sec", + requests_per_sec + ); + assert!( + requests_per_sec >= 7.0, + "Rate should be close to limit: got {:.2} requests/sec", + requests_per_sec + ); + } + + #[tokio::test] + async fn test_rate_limiter_with_wait() { + let rate_limiter = RateLimiter::new(10); + let start_time = Instant::now(); + + // At 10 req/sec, 15 requests should take at least 1.5 seconds + for _ in 0..15 { + rate_limiter.acquire_with_wait().await; + } + + let elapsed = start_time.elapsed(); + println!("Elapsed time: {:?}", elapsed); + + // With a rate limit of 10/sec and 15 requests, it should take at least 1.4 seconds + // Using slightly less than 1.5 to account for timing variations + assert!( + elapsed >= Duration::from_millis(1400), + "Should take close to 1.5 seconds to process all requests, took {:?}", + elapsed + ); + } +} diff --git a/sdk-libs/client/src/rpc/rpc_connection.rs b/sdk-libs/client/src/rpc/rpc_connection.rs index 9ebb85db4c..29af37d86b 100644 --- a/sdk-libs/client/src/rpc/rpc_connection.rs +++ b/sdk-libs/client/src/rpc/rpc_connection.rs @@ -15,7 +15,9 @@ use solana_sdk::{ }; use solana_transaction_status::TransactionStatus; -use crate::{rpc::errors::RpcError, transaction_params::TransactionParams}; +use crate::{ + rate_limiter::RateLimiter, rpc::errors::RpcError, transaction_params::TransactionParams, +}; #[async_trait] pub trait RpcConnection: Send + Sync + Debug + 'static { @@ -23,6 +25,15 @@ pub trait RpcConnection: Send + Sync + Debug + 'static { where Self: Sized; + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter); + fn rate_limiter(&self) -> Option<&RateLimiter>; + + async fn check_rate_limit(&self) { + if let Some(limiter) = self.rate_limiter() { + limiter.acquire_with_wait().await; + } + } + fn get_payer(&self) -> &Keypair; fn get_url(&self) -> String; diff --git a/sdk-libs/client/src/rpc/solana_rpc.rs b/sdk-libs/client/src/rpc/solana_rpc.rs index fa80ae8dc5..19b50dae8a 100644 --- a/sdk-libs/client/src/rpc/solana_rpc.rs +++ b/sdk-libs/client/src/rpc/solana_rpc.rs @@ -27,6 +27,7 @@ use solana_transaction_status::{ use tokio::time::{sleep, Instant}; use crate::{ + rate_limiter::RateLimiter, rpc::{errors::RpcError, merkle_tree::MerkleTreeExt, rpc_connection::RpcConnection}, transaction_params::TransactionParams, }; @@ -76,6 +77,7 @@ pub struct SolanaRpcConnection { pub client: RpcClient, pub payer: Keypair, retry_config: RetryConfig, + rate_limiter: Option, } impl Debug for SolanaRpcConnection { @@ -93,15 +95,24 @@ impl SolanaRpcConnection { url: U, commitment_config: Option, retry_config: Option, + requests_per_second: Option, ) -> Self { let payer = Keypair::new(); let commitment_config = commitment_config.unwrap_or(CommitmentConfig::confirmed()); let client = RpcClient::new_with_commitment(url.to_string(), commitment_config); let retry_config = retry_config.unwrap_or_default(); + + let mut rate_limiter = None; + + if let Some(rps) = requests_per_second { + rate_limiter = Some(RateLimiter::new(rps)); + } + Self { client, payer, retry_config, + rate_limiter, } } @@ -113,6 +124,10 @@ impl SolanaRpcConnection { let mut attempts = 0; let start_time = Instant::now(); loop { + if let Some(limiter) = &self.rate_limiter { + limiter.acquire_with_wait().await; + } + match operation().await { Ok(result) => return Ok(result), Err(e) => { @@ -205,7 +220,15 @@ impl RpcConnection for SolanaRpcConnection { where Self: Sized, { - Self::new_with_retry(url, commitment_config, None) + Self::new_with_retry(url, commitment_config, None, None) + } + + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter) { + self.rate_limiter = Some(rate_limiter); + } + + fn rate_limiter(&self) -> Option<&RateLimiter> { + self.rate_limiter.as_ref() } fn get_payer(&self) -> &Keypair { @@ -354,16 +377,6 @@ impl RpcConnection for SolanaRpcConnection { let result = parsed_event.map(|e| (e, signature, slot)); Ok(result) } - async fn get_signature_statuses( - &self, - signatures: &[Signature], - ) -> Result>, RpcError> { - self.client - .get_signature_statuses(signatures) - .map(|response| response.value) - .map_err(RpcError::from) - } - async fn confirm_transaction(&self, signature: Signature) -> Result { self.retry(|| async { self.client @@ -470,6 +483,7 @@ impl RpcConnection for SolanaRpcConnection { }) .await } + async fn send_transaction_with_config( &self, transaction: &Transaction, @@ -482,7 +496,6 @@ impl RpcConnection for SolanaRpcConnection { }) .await } - async fn get_transaction_slot(&mut self, signature: &Signature) -> Result { self.retry(|| async { Ok(self @@ -500,6 +513,16 @@ impl RpcConnection for SolanaRpcConnection { }) .await } + + async fn get_signature_statuses( + &self, + signatures: &[Signature], + ) -> Result>, RpcError> { + self.client + .get_signature_statuses(signatures) + .map(|response| response.value) + .map_err(RpcError::from) + } async fn get_block_height(&mut self) -> Result { self.retry(|| async { self.client.get_block_height().map_err(RpcError::from) }) .await diff --git a/sdk-libs/client/src/rpc_pool.rs b/sdk-libs/client/src/rpc_pool.rs index 23f2ed9152..28c6adf141 100644 --- a/sdk-libs/client/src/rpc_pool.rs +++ b/sdk-libs/client/src/rpc_pool.rs @@ -6,7 +6,10 @@ use solana_sdk::commitment_config::CommitmentConfig; use thiserror::Error; use tokio::time::sleep; -use crate::rpc::{RpcConnection, RpcError}; +use crate::{ + rate_limiter::RateLimiter, + rpc::{RpcConnection, RpcError}, +}; #[derive(Error, Debug)] pub enum PoolError { @@ -21,14 +24,20 @@ pub enum PoolError { pub struct SolanaConnectionManager { url: String, commitment: CommitmentConfig, + rate_limiter: Option, _phantom: std::marker::PhantomData, } impl SolanaConnectionManager { - pub fn new(url: String, commitment: CommitmentConfig) -> Self { + pub fn new( + url: String, + commitment: CommitmentConfig, + rate_limiter: Option, + ) -> Self { Self { url, commitment, + rate_limiter, _phantom: std::marker::PhantomData, } } @@ -40,7 +49,11 @@ impl bb8::ManageConnection for SolanaConnectionManager { type Error = PoolError; async fn connect(&self) -> Result { - Ok(R::new(&self.url, Some(self.commitment))) + let mut conn = R::new(&self.url, Some(self.commitment)); + if let Some(limiter) = &self.rate_limiter { + conn.set_rate_limiter(limiter.clone()); + } + Ok(conn) } async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { @@ -62,8 +75,9 @@ impl SolanaRpcPool { url: String, commitment: CommitmentConfig, max_size: u32, + rate_limiter: Option, ) -> Result { - let manager = SolanaConnectionManager::new(url, commitment); + let manager = SolanaConnectionManager::new(url, commitment, rate_limiter); let pool = Pool::builder() .max_size(max_size) .connection_timeout(Duration::from_secs(15)) diff --git a/sdk-libs/program-test/src/test_env.rs b/sdk-libs/program-test/src/test_env.rs index 42047c366d..c1a2b4e956 100644 --- a/sdk-libs/program-test/src/test_env.rs +++ b/sdk-libs/program-test/src/test_env.rs @@ -570,7 +570,10 @@ pub async fn setup_test_programs_with_accounts_with_protocol_config_and_batched_ batched_address_tree_init_params: InitAddressTreeAccountsInstructionData, ) -> (ProgramTestRpcConnection, EnvAccounts) { let context = setup_test_programs(additional_programs).await; - let mut context = ProgramTestRpcConnection { context }; + let mut context = ProgramTestRpcConnection { + context, + rate_limiter: None, + }; let keypairs = EnvAccountKeypairs::program_test_default(); airdrop_lamports( &mut context, @@ -602,7 +605,10 @@ pub async fn setup_test_programs_with_accounts_with_protocol_config_v2( register_forester_and_advance_to_active_phase: bool, ) -> (ProgramTestRpcConnection, EnvAccounts) { let context = setup_test_programs(additional_programs).await; - let mut context = ProgramTestRpcConnection { context }; + let mut context = ProgramTestRpcConnection { + context, + rate_limiter: None, + }; let keypairs = EnvAccountKeypairs::program_test_default(); airdrop_lamports( &mut context, diff --git a/sdk-libs/program-test/src/test_rpc.rs b/sdk-libs/program-test/src/test_rpc.rs index 7eec9aea3a..7da9f11edb 100644 --- a/sdk-libs/program-test/src/test_rpc.rs +++ b/sdk-libs/program-test/src/test_rpc.rs @@ -3,6 +3,7 @@ use std::fmt::{Debug, Formatter}; use async_trait::async_trait; use borsh::BorshDeserialize; use light_client::{ + rate_limiter::RateLimiter, rpc::{merkle_tree::MerkleTreeExt, RpcConnection, RpcError}, transaction_params::TransactionParams, }; @@ -25,6 +26,7 @@ use solana_transaction_status::TransactionStatus; pub struct ProgramTestRpcConnection { pub context: ProgramTestContext, + pub rate_limiter: Option, } impl Debug for ProgramTestRpcConnection { @@ -42,6 +44,14 @@ impl RpcConnection for ProgramTestRpcConnection { unimplemented!() } + fn set_rate_limiter(&mut self, rate_limiter: RateLimiter) { + self.rate_limiter = Some(rate_limiter); + } + + fn rate_limiter(&self) -> Option<&RateLimiter> { + self.rate_limiter.as_ref() + } + fn get_payer(&self) -> &Keypair { &self.context.payer }