diff --git a/README.md b/README.md index bfa3448..168c188 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,53 @@ async fn main() -> Result<(), Box> { } ``` + +### Connecting + +Connect to Pusher: + +```rust +client.connect().await?; +``` + +### Subscribing to Channels + +Subscribe to a public channel: + +```rust +client.subscribe("my-channel").await?; +``` + +Subscribe to a private channel: + +```rust +client.subscribe("private-my-channel").await?; +``` + +Subscribe to a presence channel: + +```rust +client.subscribe("presence-my-channel").await?; +``` + +### Unsubscribing from Channels + +```rust +client.unsubscribe("my-channel").await?; +``` + +### Binding to Events + +Bind to a specific event on a channel: + +```rust +use pusher_rs::Event; + +client.bind("my-event", |event: Event| { + println!("Received event: {:?}", event); +}).await?; +``` + ### Subscribing to a channel ```rust @@ -133,6 +180,15 @@ The library supports four types of channels: Each channel type has specific features and authentication requirements. +### Handling Connection State + +Get the current connection state: + +```rust +let state = client.get_connection_state().await; +println!("Current connection state: {:?}", state); +``` + ## Error Handling The library uses a custom `PusherError` type for error handling. You can match on different error variants to handle specific error cases: @@ -148,6 +204,14 @@ match client.connect().await { } ``` +### Disconnecting + +When you're done, disconnect from Pusher: + +```rust +client.disconnect().await?; +``` + ## Advanced Usage ### Custom Configuration @@ -186,6 +250,25 @@ if let Some(channel) = channel_list.get("my-channel") { } ``` +### Presence Channels + +When subscribing to a presence channel, you can provide user information: + +```rust +use serde_json::json; + +let channel = "presence-my-channel"; +let socket_id = client.get_socket_id().await?; +let user_id = "user_123"; +let user_info = json!({ + "name": "John Doe", + "email": "john@example.com" +}); + +let auth = client.authenticate_presence_channel(&socket_id, channel, user_id, Some(&user_info))?; +client.subscribe_with_auth(channel, &auth).await?; +``` + ### Tests Integration tests live under `tests/integration_tests` diff --git a/src/events.rs b/src/events.rs index 0c4dbc6..00137da 100644 --- a/src/events.rs +++ b/src/events.rs @@ -5,6 +5,7 @@ use serde_json::Value; pub struct Event { pub event: String, pub channel: Option, + #[serde(with = "json_string")] pub data: Value, } @@ -113,6 +114,25 @@ impl SystemEvent { } } +mod json_string { + use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; + use serde_json::Value; + + pub fn serialize(value: &Value, serializer: S) -> Result + where + S: Serializer, + { + value.to_string().serialize(serializer) + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + serde_json::from_str(&s).map_err(D::Error::custom) + } +} #[cfg(test)] mod tests { use super::*; diff --git a/src/lib.rs b/src/lib.rs index b5387ec..bc829d1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ use cbc::{Decryptor, Encryptor}; use hmac::{Hmac, Mac}; use log::info; use rand::Rng; -use serde_json::json; +use serde_json::{json, Value}; use sha2::Sha256; use std::collections::HashMap; use std::sync::Arc; @@ -28,14 +28,15 @@ pub use config::PusherConfig; pub use error::{PusherError, PusherResult}; pub use events::{Event, SystemEvent}; -use websocket::WebSocketClient; +use websocket::{WebSocketClient, WebSocketCommand}; /// This struct provides methods for connecting to Pusher, subscribing to channels, /// triggering events, and handling incoming events. pub struct PusherClient { config: PusherConfig, auth: PusherAuth, - websocket: Option, + // websocket: Option, + websocket_command_tx: Option>, channels: Arc>>, event_handlers: Arc>>>>, state: Arc>, @@ -73,29 +74,44 @@ impl PusherClient { let auth = PusherAuth::new(&config.app_key, &config.app_secret); let (event_tx, event_rx) = mpsc::channel(100); let state = Arc::new(RwLock::new(ConnectionState::Disconnected)); - let event_handlers = Arc::new(RwLock::new(HashMap::new())); - let encrypted_channels = Arc::new(RwLock::new(HashMap::new())); + let event_handlers = Arc::new(RwLock::new(std::collections::HashMap::new())); + let encrypted_channels = Arc::new(RwLock::new(std::collections::HashMap::new())); let client = Self { config, auth, - websocket: None, - channels: Arc::new(RwLock::new(HashMap::new())), + websocket_command_tx: None, + channels: Arc::new(RwLock::new(std::collections::HashMap::new())), event_handlers: event_handlers.clone(), state: state.clone(), event_tx, encrypted_channels, }; - // Spawn the event handling task tokio::spawn(Self::handle_events(event_rx, event_handlers)); Ok(client) } + + async fn send(&self, message: String) -> PusherResult<()> { + if let Some(tx) = &self.websocket_command_tx { + tx.send(WebSocketCommand::Send(message)) + .await + .map_err(|e| { + PusherError::WebSocketError(format!("Failed to send command: {}", e)) + })?; + Ok(()) + } else { + Err(PusherError::ConnectionError("Not connected".into())) + } + } + async fn handle_events( mut event_rx: mpsc::Receiver, event_handlers: Arc< - RwLock>>>, + RwLock< + std::collections::HashMap>>, + >, >, ) { while let Some(event) = event_rx.recv().await { @@ -115,18 +131,24 @@ impl PusherClient { /// A `PusherResult` indicating success or failure. pub async fn connect(&mut self) -> PusherResult<()> { let url = self.get_websocket_url()?; - let mut websocket = - WebSocketClient::new(url.clone(), Arc::clone(&self.state), self.event_tx.clone()); + let (command_tx, command_rx) = mpsc::channel(100); + + let mut websocket = WebSocketClient::new( + url.clone(), + Arc::clone(&self.state), + self.event_tx.clone(), + command_rx, + ); + log::info!("Connecting to Pusher using URL: {}", url); websocket.connect().await?; - self.websocket = Some(websocket); - // Start the WebSocket event loop - let mut ws = self.websocket.take().unwrap(); tokio::spawn(async move { - ws.run().await; + websocket.run().await; }); + self.websocket_command_tx = Some(command_tx); + Ok(()) } @@ -136,11 +158,12 @@ impl PusherClient { /// /// A `PusherResult` indicating success or failure. pub async fn disconnect(&mut self) -> PusherResult<()> { - if let Some(websocket) = &self.websocket { - websocket.close().await?; + if let Some(tx) = self.websocket_command_tx.take() { + tx.send(WebSocketCommand::Close).await.map_err(|e| { + PusherError::WebSocketError(format!("Failed to send close command: {}", e)) + })?; } *self.state.write().await = ConnectionState::Disconnected; - self.websocket = None; Ok(()) } @@ -158,21 +181,17 @@ impl PusherClient { let mut channels = self.channels.write().await; channels.insert(channel_name.to_string(), channel); - if let Some(websocket) = &self.websocket { - let data = json!({ - "event": "pusher:subscribe", - "data": { - "channel": channel_name - } - }); - websocket.send(serde_json::to_string(&data)?).await?; - } else { - return Err(PusherError::ConnectionError("Not connected".into())); - } + let data = json!({ + "event": "pusher:subscribe", + "data": { + "channel": channel_name + } + }); - Ok(()) + self.send(serde_json::to_string(&data)?).await } + /// Subscribes to an encrypted channel. /// /// # Arguments @@ -208,6 +227,7 @@ impl PusherClient { /// # Returns /// /// A `PusherResult` indicating success or failure. + /// pub async fn unsubscribe(&mut self, channel_name: &str) -> PusherResult<()> { { let mut channels = self.channels.write().await; @@ -219,19 +239,14 @@ impl PusherClient { encrypted_channels.remove(channel_name); } - if let Some(websocket) = &self.websocket { - let data = json!({ - "event": "pusher:unsubscribe", - "data": { - "channel": channel_name - } - }); - websocket.send(serde_json::to_string(&data)?).await?; - } else { - return Err(PusherError::ConnectionError("Not connected".into())); - } + let data = json!({ + "event": "pusher:unsubscribe", + "data": { + "channel": channel_name + } + }); - Ok(()) + self.send(serde_json::to_string(&data)?).await } /// Triggers an event on a channel. @@ -251,10 +266,14 @@ impl PusherClient { self.config.cluster, self.config.app_id ); + // Validate that the data is valid JSON, but keep it as a string + serde_json::from_str::(data) + .map_err(|e| PusherError::JsonError(e))?; + let body = json!({ "name": event, "channel": channel, - "data": data + "data": data, // Keep data as a string }); let path = format!("/apps/{}/events", self.config.app_id); let auth_params = self.auth.authenticate_request("POST", &path, &body)?; @@ -371,6 +390,7 @@ impl PusherClient { /// # Returns /// /// A `PusherResult` indicating success or failure. + /// pub async fn bind(&self, event_name: &str, callback: F) -> PusherResult<()> where F: Fn(Event) + Send + Sync + 'static, @@ -535,7 +555,8 @@ mod tests { #[tokio::test] async fn test_trigger_batch() { - let config = PusherConfig::from_env().expect("Failed to load Pusher configuration from environment"); + let config = + PusherConfig::from_env().expect("Failed to load Pusher configuration from environment"); let client = PusherClient::new(config).unwrap(); let batch_events = vec![ diff --git a/src/websocket.rs b/src/websocket.rs index 6d5a6ee..6b43d76 100644 --- a/src/websocket.rs +++ b/src/websocket.rs @@ -6,19 +6,18 @@ use tokio_tungstenite::{ }; use tokio::net::TcpStream; use futures_util::{SinkExt, StreamExt}; -use tokio::time::{sleep, Duration, Instant}; +use tokio::time::{sleep, interval, Duration}; use url::Url; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; use log::{debug, error, info, warn}; +use std::pin::Pin; use crate::error::{PusherError, PusherResult}; use crate::{Event, SystemEvent, ConnectionState}; const PING_INTERVAL: Duration = Duration::from_secs(30); const PONG_TIMEOUT: Duration = Duration::from_secs(10); -const MAX_RECONNECTION_ATTEMPTS: u32 = 6; -const INITIAL_BACKOFF: Duration = Duration::from_secs(1); pub struct WebSocketClient { url: Url, @@ -26,7 +25,6 @@ pub struct WebSocketClient { state: Arc>, event_tx: mpsc::Sender, command_rx: mpsc::Receiver, - command_tx: mpsc::Sender, } pub enum WebSocketCommand { @@ -35,26 +33,24 @@ pub enum WebSocketCommand { } impl WebSocketClient { - pub fn new(url: Url, state: Arc>, event_tx: mpsc::Sender) -> Self { - let (command_tx, command_rx) = mpsc::channel(100); + pub fn new( + url: Url, + state: Arc>, + event_tx: mpsc::Sender, + command_rx: mpsc::Receiver, + ) -> Self { Self { url, socket: None, state, event_tx, command_rx, - command_tx, } } - pub fn get_command_tx(&self) -> mpsc::Sender { - self.command_tx.clone() - } - pub async fn connect(&mut self) -> PusherResult<()> { debug!("Connecting to WebSocket: {}", self.url); - let url_string = self.url.to_string(); - let (socket, _) = connect_async(url_string).await + let (socket, _) = connect_async(self.url.to_string()).await .map_err(|e| PusherError::WebSocketError(format!("Failed to connect: {}", e)))?; self.socket = Some(socket); self.set_state(ConnectionState::Connected).await; @@ -62,166 +58,95 @@ impl WebSocketClient { } pub async fn run(&mut self) { - let mut reconnection_attempts = 0; - let mut backoff = INITIAL_BACKOFF; - - loop { - match self.run_connection().await { - Ok(_) => { - info!("WebSocket connection closed normally"); - break; - } - Err(e) => { - error!("WebSocket error: {}", e); - self.handle_disconnect().await; + let mut ping_interval = interval(PING_INTERVAL); + let mut pong_timeout = Box::pin(sleep(Duration::from_secs(0))); + let mut waiting_for_pong = false; - if reconnection_attempts >= MAX_RECONNECTION_ATTEMPTS { - error!("Max reconnection attempts reached. Giving up."); + while let Some(socket) = &mut self.socket { + tokio::select! { + _ = ping_interval.tick() => { + if let Err(e) = socket.send(Message::Ping(vec![])).await { + error!("Failed to send ping: {}", e); break; } - - reconnection_attempts += 1; - info!("Attempting to reconnect in {:?} (attempt {})", backoff, reconnection_attempts); - tokio::time::sleep(backoff).await; - backoff *= 2; // Exponential backoff - - if let Err(e) = self.connect().await { - error!("Failed to reconnect: {}", e); - continue; - } + waiting_for_pong = true; + pong_timeout = Box::pin(sleep(PONG_TIMEOUT)); } - } - } - } - - async fn run_connection(&mut self) -> PusherResult<()> { - let mut ping_interval = Instant::now() + PING_INTERVAL; - let mut pong_timeout: Option = None; - - while let Some(socket) = &mut self.socket { - tokio::select! { - cmd = self.command_rx.recv() => { + Some(cmd) = self.command_rx.recv() => { match cmd { - Some(WebSocketCommand::Send(msg)) => { - socket.send(Message::Text(msg)).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send message: {}", e)))?; - } - Some(WebSocketCommand::Close) => { - debug!("Closing WebSocket connection"); - socket.close(None).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to close connection: {}", e)))?; - return Ok(()); + WebSocketCommand::Send(msg) => { + if let Err(e) = socket.send(Message::Text(msg)).await { + error!("Failed to send message: {}", e); + } } - None => { - return Err(PusherError::WebSocketError("Command channel closed".to_string())); + WebSocketCommand::Close => { + if let Err(e) = socket.close(None).await { + error!("Failed to close connection: {}", e); + } + break; } } } - message = socket.next() => { - match message { - Some(Ok(msg)) => self.handle_message(msg).await?, - Some(Err(e)) => return Err(PusherError::WebSocketError(format!("WebSocket error: {}", e))), + msg = socket.next() => { + match msg { + Some(Ok(msg)) => { + if let Message::Pong(_) = msg { + waiting_for_pong = false; + } + self.handle_message(msg).await; + } + Some(Err(e)) => { + error!("WebSocket error: {}", e); + break; + } None => { info!("WebSocket connection closed"); - return Ok(()); + break; } } } - _ = sleep(ping_interval.duration_since(Instant::now())) => { - debug!("Sending ping"); - socket.send(Message::Ping(vec![])).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send ping: {}", e)))?; - ping_interval = Instant::now() + PING_INTERVAL; - pong_timeout = Some(Instant::now() + PONG_TIMEOUT); - } - _ = async { - if let Some(timeout) = pong_timeout { - sleep(timeout.duration_since(Instant::now())).await; - } else { - std::future::pending::<()>().await; - } - } => { - if pong_timeout.is_some() { - warn!("Pong timeout reached"); - return Err(PusherError::WebSocketError("Pong timeout reached".to_string())); - } - } - } - - if let Some(timeout) = pong_timeout { - if Instant::now() >= timeout { - pong_timeout = None; + _ = &mut pong_timeout, if waiting_for_pong => { + error!("Pong timeout reached"); + break; } } } - Ok(()) + self.handle_disconnect().await; } - async fn handle_message(&mut self, msg: Message) -> PusherResult<()> { + + async fn handle_message(&mut self, msg: Message) { match msg { Message::Text(text) => self.handle_text_message(text).await, - Message::Ping(_) => self.handle_ping().await, - Message::Pong(_) => Ok(self.handle_pong()), - Message::Close(frame) => self.handle_close(frame).await, + Message::Ping(_) => { + if let Some(socket) = &mut self.socket { + if let Err(e) = socket.send(Message::Pong(vec![])).await { + error!("Failed to send pong: {}", e); + } + } + } + Message::Pong(_) => { + debug!("Received pong"); + } + Message::Close(frame) => { + info!("Received close frame: {:?}", frame); + self.handle_disconnect().await; + } _ => { debug!("Received unhandled message type"); - Ok(()) } } } - async fn handle_text_message(&self, text: String) -> PusherResult<()> { + async fn handle_text_message(&self, text: String) { debug!("Received text message: {}", text); - match serde_json::from_str::(&text) { - Ok(system_event) => self.handle_system_event(system_event).await, - Err(_) => { - if let Ok(event) = serde_json::from_str::(&text) { - self.event_tx.send(event).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send event to handler: {}", e)))?; - } else { - error!("Failed to parse message as Event or SystemEvent: {}", text); - } - Ok(()) - } - } - } - - async fn handle_system_event(&self, event: SystemEvent) -> PusherResult<()> { - match event.event.as_str() { - "pusher:connection_established" => { - info!("Connection established"); - self.set_state(ConnectionState::Connected).await; - } - "pusher:error" => { - error!("Received error event: {:?}", event.data); + if let Ok(event) = serde_json::from_str::(&text) { + if let Err(e) = self.event_tx.send(event).await { + error!("Failed to send event to handler: {}", e); } - _ => debug!("Received unhandled system event: {}", event.event), - } - Ok(()) - } - - async fn handle_ping(&mut self) -> PusherResult<()> { - debug!("Received ping, sending pong"); - if let Some(socket) = &mut self.socket { - socket.send(Message::Pong(vec![])).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send pong: {}", e)))?; - } - Ok(()) - } - - fn handle_pong(&mut self) { - debug!("Received pong"); - // TODO - Reset pong timeout - } - - async fn handle_close(&mut self, frame: Option>) -> PusherResult<()> { - if let Some(frame) = frame { - info!("WebSocket closed with code {} and reason: {}", frame.code, frame.reason); } else { - info!("WebSocket closed without close frame"); + error!("Failed to parse message as Event: {}", text); } - self.handle_disconnect().await; - Ok(()) } async fn handle_disconnect(&mut self) { @@ -234,33 +159,4 @@ impl WebSocketClient { *state = new_state.clone(); debug!("Connection state changed to: {:?}", new_state); } - - pub async fn send(&self, message: String) -> PusherResult<()> { - self.command_tx.send(WebSocketCommand::Send(message)).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send command: {}", e))) - } - - pub async fn close(&self) -> PusherResult<()> { - self.command_tx.send(WebSocketCommand::Close).await - .map_err(|e| PusherError::WebSocketError(format!("Failed to send close command: {}", e))) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::sync::mpsc; - use std::sync::Arc; - use tokio::sync::RwLock; - - #[tokio::test] - async fn test_websocket_client_creation() { - let url = Url::parse("wss://ws.pusherapp.com/app/1234?protocol=7").unwrap(); - let state = Arc::new(RwLock::new(ConnectionState::Disconnected)); - let (event_tx, _) = mpsc::channel(100); - - let client = WebSocketClient::new(url, state, event_tx); - assert!(client.socket.is_none()); - } - } \ No newline at end of file diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index b306ccb..be65773 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -42,21 +42,54 @@ async fn test_pusher_client_connection() { } #[tokio::test] -#[ignore] async fn test_channel_subscription() { let mut client = setup_client().await; - // client.connect().await.unwrap(); - client.subscribe("test-channel").await.unwrap(); + // Connect with a timeout + match timeout(Duration::from_secs(10), client.connect()).await { + Ok(result) => { + result.expect("Failed to connect to Pusher"); + } + Err(_) => panic!("Connection timed out"), + } + + // Ensure we're connected + assert_eq!( + client.get_connection_state().await, + ConnectionState::Connected + ); + + // Subscribe to the channel + match timeout(Duration::from_secs(5), client.subscribe("test-channel")).await { + Ok(result) => { + result.expect("Failed to subscribe to channel"); + } + Err(_) => panic!("Subscription timed out"), + } + + // Wait a bit for the subscription to be processed + tokio::time::sleep(Duration::from_secs(1)).await; let channels = client.get_subscribed_channels().await; log::info!("Subscribed channels: {:?}", channels); - assert!(channels.contains(&"test-channel".to_string())); + assert!(channels.contains(&"test-channel".to_string()), "Channel not found in subscribed channels"); + + // Unsubscribe from the channel + match timeout(Duration::from_secs(5), client.unsubscribe("test-channel")).await { + Ok(result) => { + result.expect("Failed to unsubscribe from channel"); + } + Err(_) => panic!("Unsubscription timed out"), + } - client.unsubscribe("test-channel").await.unwrap(); + // Wait a bit for the unsubscription to be processed + tokio::time::sleep(Duration::from_secs(1)).await; let channels = client.get_subscribed_channels().await; - assert!(!channels.contains(&"test-channel".to_string())); + assert!(!channels.contains(&"test-channel".to_string()), "Channel still present after unsubscription"); + + // Disconnect the client + client.disconnect().await.expect("Failed to disconnect"); } #[tokio::test] @@ -85,21 +118,23 @@ async fn test_event_binding() { assert!(*event_received.read().await); } -// #[tokio::test] -// async fn test_encrypted_channel() { -// let mut client = setup_client().await; +#[tokio::test] +#[ignore] +async fn test_encrypted_channel() { + let mut client = setup_client().await; + + client.connect().await.unwrap(); + client + .subscribe_encrypted("private-encrypted-channel") + .await + .unwrap(); -// client.connect().await.unwrap(); -// client -// .subscribe_encrypted("private-encrypted-channel") -// .await -// .unwrap(); + let channels = client.get_subscribed_channels().await; + assert!(channels.contains(&"private-encrypted-channel".to_string())); -// let channels = client.get_subscribed_channels().await; -// assert!(channels.contains(&"private-encrypted-channel".to_string())); + // TODO - Test sending and receiving encrypted messages +} -// // TODO - Test sending and receiving encrypted messages -// } #[tokio::test] async fn test_send_payload() { @@ -124,15 +159,15 @@ async fn test_send_payload() { let test_data = r#"{"message": "Hello, Pusher!"}"#; // Subscribe to the channel - // client - // .subscribe(test_channel) - // .await - // .expect("Failed to subscribe to channel"); + client + .subscribe(test_channel) + .await + .expect("Failed to subscribe to channel"); // Set up event binding to capture the triggered event let event_received = Arc::new(RwLock::new(false)); let event_received_clone = event_received.clone(); - let received_data = Arc::new(RwLock::new(String::new())); + let received_data = Arc::new(RwLock::new(None)); let received_data_clone = received_data.clone(); client @@ -143,7 +178,7 @@ async fn test_send_payload() { let mut flag = event_received.write().await; *flag = true; let mut data = received_data.write().await; - *data = serde_json::to_string(&event.data).unwrap(); + *data = Some(event.data); }); }) .await @@ -158,19 +193,21 @@ async fn test_send_payload() { // Wait for the event to be processed tokio::time::sleep(Duration::from_secs(2)).await; - // Assert that the event was received and processed - // assert!(*event_received.read().await, "Event was not received"); - - // Assert that the received data matches the sent data - // let received = received_data.read().await; - // assert_eq!( - // *received, test_data, - // "Received data does not match sent data" - // ); - - // client - // .unsubscribe(test_channel) - // .await - // .expect("Failed to unsubscribe from channel"); - // client.disconnect().await.expect("Failed to disconnect"); -} + // let's ssert that the event was received and processed + assert!(*event_received.read().await, "Event was not received"); + + // let's assert that the received data matches the sent data + let received = received_data.read().await; + let expected_data: serde_json::Value = serde_json::from_str(test_data).unwrap(); + assert_eq!( + received.as_ref().unwrap(), + &expected_data, + "Received data does not match sent data" + ); + + client + .unsubscribe(test_channel) + .await + .expect("Failed to unsubscribe from channel"); + client.disconnect().await.expect("Failed to disconnect"); +} \ No newline at end of file