Skip to content

Commit

Permalink
Merge pull request #3 from chmod77/feat/rewrite-connect-websocket
Browse files Browse the repository at this point in the history
feat - rewrite WS connection
  • Loading branch information
chmod77 authored Aug 31, 2024
2 parents 9b41da7 + 9d19219 commit 40a0164
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 255 deletions.
83 changes: 83 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,53 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
```


### Connecting

Connect to Pusher:

```rust
client.connect().await?;
```

### Subscribing to Channels

Subscribe to a public channel:

```rust
client.subscribe("my-channel").await?;
```

Subscribe to a private channel:

```rust
client.subscribe("private-my-channel").await?;
```

Subscribe to a presence channel:

```rust
client.subscribe("presence-my-channel").await?;
```

### Unsubscribing from Channels

```rust
client.unsubscribe("my-channel").await?;
```

### Binding to Events

Bind to a specific event on a channel:

```rust
use pusher_rs::Event;

client.bind("my-event", |event: Event| {
println!("Received event: {:?}", event);
}).await?;
```

### Subscribing to a channel

```rust
Expand Down Expand Up @@ -133,6 +180,15 @@ The library supports four types of channels:

Each channel type has specific features and authentication requirements.

### Handling Connection State

Get the current connection state:

```rust
let state = client.get_connection_state().await;
println!("Current connection state: {:?}", state);
```

## Error Handling

The library uses a custom `PusherError` type for error handling. You can match on different error variants to handle specific error cases:
Expand All @@ -148,6 +204,14 @@ match client.connect().await {
}
```

### Disconnecting

When you're done, disconnect from Pusher:

```rust
client.disconnect().await?;
```

## Advanced Usage

### Custom Configuration
Expand Down Expand Up @@ -186,6 +250,25 @@ if let Some(channel) = channel_list.get("my-channel") {
}
```

### Presence Channels

When subscribing to a presence channel, you can provide user information:

```rust
use serde_json::json;

let channel = "presence-my-channel";
let socket_id = client.get_socket_id().await?;
let user_id = "user_123";
let user_info = json!({
"name": "John Doe",
"email": "john@example.com"
});

let auth = client.authenticate_presence_channel(&socket_id, channel, user_id, Some(&user_info))?;
client.subscribe_with_auth(channel, &auth).await?;
```

### Tests

Integration tests live under `tests/integration_tests`
Expand Down
20 changes: 20 additions & 0 deletions src/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use serde_json::Value;
pub struct Event {
pub event: String,
pub channel: Option<String>,
#[serde(with = "json_string")]
pub data: Value,
}

Expand Down Expand Up @@ -113,6 +114,25 @@ impl SystemEvent {
}
}

mod json_string {
use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer};
use serde_json::Value;

pub fn serialize<S>(value: &Value, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
value.to_string().serialize(serializer)
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Value, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
serde_json::from_str(&s).map_err(D::Error::custom)
}
}
#[cfg(test)]
mod tests {
use super::*;
Expand Down
109 changes: 65 additions & 44 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use cbc::{Decryptor, Encryptor};
use hmac::{Hmac, Mac};
use log::info;
use rand::Rng;
use serde_json::json;
use serde_json::{json, Value};
use sha2::Sha256;
use std::collections::HashMap;
use std::sync::Arc;
Expand All @@ -28,14 +28,15 @@ pub use config::PusherConfig;
pub use error::{PusherError, PusherResult};
pub use events::{Event, SystemEvent};

use websocket::WebSocketClient;
use websocket::{WebSocketClient, WebSocketCommand};

/// This struct provides methods for connecting to Pusher, subscribing to channels,
/// triggering events, and handling incoming events.
pub struct PusherClient {
config: PusherConfig,
auth: PusherAuth,
websocket: Option<WebSocketClient>,
// websocket: Option<WebSocketClient>,
websocket_command_tx: Option<mpsc::Sender<WebSocketCommand>>,
channels: Arc<RwLock<HashMap<String, Channel>>>,
event_handlers: Arc<RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>>,
state: Arc<RwLock<ConnectionState>>,
Expand Down Expand Up @@ -73,29 +74,44 @@ impl PusherClient {
let auth = PusherAuth::new(&config.app_key, &config.app_secret);
let (event_tx, event_rx) = mpsc::channel(100);
let state = Arc::new(RwLock::new(ConnectionState::Disconnected));
let event_handlers = Arc::new(RwLock::new(HashMap::new()));
let encrypted_channels = Arc::new(RwLock::new(HashMap::new()));
let event_handlers = Arc::new(RwLock::new(std::collections::HashMap::new()));
let encrypted_channels = Arc::new(RwLock::new(std::collections::HashMap::new()));

let client = Self {
config,
auth,
websocket: None,
channels: Arc::new(RwLock::new(HashMap::new())),
websocket_command_tx: None,
channels: Arc::new(RwLock::new(std::collections::HashMap::new())),
event_handlers: event_handlers.clone(),
state: state.clone(),
event_tx,
encrypted_channels,
};

// Spawn the event handling task
tokio::spawn(Self::handle_events(event_rx, event_handlers));

Ok(client)
}

async fn send(&self, message: String) -> PusherResult<()> {
if let Some(tx) = &self.websocket_command_tx {
tx.send(WebSocketCommand::Send(message))
.await
.map_err(|e| {
PusherError::WebSocketError(format!("Failed to send command: {}", e))
})?;
Ok(())
} else {
Err(PusherError::ConnectionError("Not connected".into()))
}
}

async fn handle_events(
mut event_rx: mpsc::Receiver<Event>,
event_handlers: Arc<
RwLock<HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>>,
RwLock<
std::collections::HashMap<String, Vec<Box<dyn Fn(Event) + Send + Sync + 'static>>>,
>,
>,
) {
while let Some(event) = event_rx.recv().await {
Expand All @@ -115,18 +131,24 @@ impl PusherClient {
/// A `PusherResult` indicating success or failure.
pub async fn connect(&mut self) -> PusherResult<()> {
let url = self.get_websocket_url()?;
let mut websocket =
WebSocketClient::new(url.clone(), Arc::clone(&self.state), self.event_tx.clone());
let (command_tx, command_rx) = mpsc::channel(100);

let mut websocket = WebSocketClient::new(
url.clone(),
Arc::clone(&self.state),
self.event_tx.clone(),
command_rx,
);

log::info!("Connecting to Pusher using URL: {}", url);
websocket.connect().await?;
self.websocket = Some(websocket);

// Start the WebSocket event loop
let mut ws = self.websocket.take().unwrap();
tokio::spawn(async move {
ws.run().await;
websocket.run().await;
});

self.websocket_command_tx = Some(command_tx);

Ok(())
}

Expand All @@ -136,11 +158,12 @@ impl PusherClient {
///
/// A `PusherResult` indicating success or failure.
pub async fn disconnect(&mut self) -> PusherResult<()> {
if let Some(websocket) = &self.websocket {
websocket.close().await?;
if let Some(tx) = self.websocket_command_tx.take() {
tx.send(WebSocketCommand::Close).await.map_err(|e| {
PusherError::WebSocketError(format!("Failed to send close command: {}", e))
})?;
}
*self.state.write().await = ConnectionState::Disconnected;
self.websocket = None;
Ok(())
}

Expand All @@ -158,21 +181,17 @@ impl PusherClient {
let mut channels = self.channels.write().await;
channels.insert(channel_name.to_string(), channel);

if let Some(websocket) = &self.websocket {
let data = json!({
"event": "pusher:subscribe",
"data": {
"channel": channel_name
}
});
websocket.send(serde_json::to_string(&data)?).await?;
} else {
return Err(PusherError::ConnectionError("Not connected".into()));
}
let data = json!({
"event": "pusher:subscribe",
"data": {
"channel": channel_name
}
});

Ok(())
self.send(serde_json::to_string(&data)?).await
}


/// Subscribes to an encrypted channel.
///
/// # Arguments
Expand Down Expand Up @@ -208,6 +227,7 @@ impl PusherClient {
/// # Returns
///
/// A `PusherResult` indicating success or failure.
///
pub async fn unsubscribe(&mut self, channel_name: &str) -> PusherResult<()> {
{
let mut channels = self.channels.write().await;
Expand All @@ -219,19 +239,14 @@ impl PusherClient {
encrypted_channels.remove(channel_name);
}

if let Some(websocket) = &self.websocket {
let data = json!({
"event": "pusher:unsubscribe",
"data": {
"channel": channel_name
}
});
websocket.send(serde_json::to_string(&data)?).await?;
} else {
return Err(PusherError::ConnectionError("Not connected".into()));
}
let data = json!({
"event": "pusher:unsubscribe",
"data": {
"channel": channel_name
}
});

Ok(())
self.send(serde_json::to_string(&data)?).await
}

/// Triggers an event on a channel.
Expand All @@ -251,10 +266,14 @@ impl PusherClient {
self.config.cluster, self.config.app_id
);

// Validate that the data is valid JSON, but keep it as a string
serde_json::from_str::<serde_json::Value>(data)
.map_err(|e| PusherError::JsonError(e))?;

let body = json!({
"name": event,
"channel": channel,
"data": data
"data": data, // Keep data as a string
});
let path = format!("/apps/{}/events", self.config.app_id);
let auth_params = self.auth.authenticate_request("POST", &path, &body)?;
Expand Down Expand Up @@ -371,6 +390,7 @@ impl PusherClient {
/// # Returns
///
/// A `PusherResult` indicating success or failure.
///
pub async fn bind<F>(&self, event_name: &str, callback: F) -> PusherResult<()>
where
F: Fn(Event) + Send + Sync + 'static,
Expand Down Expand Up @@ -535,7 +555,8 @@ mod tests {

#[tokio::test]
async fn test_trigger_batch() {
let config = PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
let config =
PusherConfig::from_env().expect("Failed to load Pusher configuration from environment");
let client = PusherClient::new(config).unwrap();

let batch_events = vec![
Expand Down
Loading

0 comments on commit 40a0164

Please sign in to comment.