Skip to content

Commit

Permalink
fix Calling subscribe() multiple times on PubSubStream with the same …
Browse files Browse the repository at this point in the history
…channel results in an error on close() #59
  • Loading branch information
mcatanzariti committed Mar 16, 2024
1 parent e3c1b01 commit 4109e8a
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
24 changes: 19 additions & 5 deletions src/client/pub_sub_stream.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use crate::{
client::{Client, ClientPreparedCommand},
commands::InternalPubSubCommands,
network::PubSubSender,
resp::{ByteBufSeed, CommandArgs, SingleArg, SingleArgCollection},
PubSubReceiver, Result,
client::{Client, ClientPreparedCommand}, commands::InternalPubSubCommands, network::PubSubSender, resp::{ByteBufSeed, CommandArgs, SingleArg, SingleArgCollection}, Error, PubSubReceiver, Result
};
use futures_util::{Stream, StreamExt};
use serde::{
Expand Down Expand Up @@ -104,6 +100,12 @@ impl PubSubSplitSink {
{
let channels = CommandArgs::default().arg(channels).build();

for channel in &channels {
if self.channels.iter().any(|c| c == channel) {
return Err(Error::Client(format!("pub sub stream already subscribed to channel `{}`", String::from_utf8_lossy(channel))));
}
}

self.client
.subscribe_from_pub_sub_sender(&channels, &self.sender)
.await?;
Expand All @@ -121,6 +123,12 @@ impl PubSubSplitSink {
{
let patterns = CommandArgs::default().arg(patterns).build();

for pattern in &patterns {
if self.patterns.iter().any(|p| p == pattern) {
return Err(Error::Client(format!("pub sub stream already subscribed to pattern `{}`", String::from_utf8_lossy(pattern))));
}
}

self.client
.psubscribe_from_pub_sub_sender(&patterns, &self.sender)
.await?;
Expand All @@ -138,6 +146,12 @@ impl PubSubSplitSink {
{
let shardchannels = CommandArgs::default().arg(shardchannels).build();

for shardchannel in &shardchannels {
if self.shardchannels.iter().any(|c| c == shardchannel) {
return Err(Error::Client(format!("pub sub stream already subscribed to shard channel `{}`", String::from_utf8_lossy(shardchannel))));
}
}

self.client
.ssubscribe_from_pub_sub_sender(&shardchannels, &self.sender)
.await?;
Expand Down
23 changes: 23 additions & 0 deletions src/tests/pub_sub_commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,26 @@ async fn split() -> Result<()> {

Ok(())
}


#[cfg_attr(feature = "tokio-runtime", tokio::test)]
#[cfg_attr(feature = "async-std-runtime", async_std::test)]
#[serial]
async fn subscribe_twice() -> Result<()> {
let pub_sub_client = get_test_client().await?;
let regular_client = get_test_client().await?;

// cleanup
regular_client.flushdb(FlushingMode::Sync).await?;

let mut pub_sub_stream = pub_sub_client.subscribe("mychannel").await?;
assert!(pub_sub_stream.subscribe("mychannel").await.is_err());

pub_sub_stream.psubscribe("pattern").await?;
assert!(pub_sub_stream.psubscribe("pattern").await.is_err());

pub_sub_stream.ssubscribe("mychannel").await?;
assert!(pub_sub_stream.ssubscribe("mychannel").await.is_err());

Ok(())
}

0 comments on commit 4109e8a

Please sign in to comment.