Skip to content

Commit

Permalink
Add futures::Sink<PublishMessage> to async_nats::Client
Browse files Browse the repository at this point in the history
Signed-off-by: Roman Volosatovs <rvolosatovs@riseup.net>
  • Loading branch information
rvolosatovs authored Aug 20, 2024
1 parent 17e5d65 commit 053944d
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 23 deletions.
1 change: 1 addition & 0 deletions async-nats/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ serde_repr = "0.1.16"
tokio = { version = "1.36", features = ["macros", "rt", "fs", "net", "sync", "time", "io-util"] }
url = { version = "2"}
tokio-rustls = { version = "0.26", default-features = false }
tokio-util = "0.7"
rustls-pemfile = "2"
nuid = "0.5"
serde_nanos = "0.1.3"
Expand Down
63 changes: 49 additions & 14 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use core::pin::Pin;
use core::task::{Context, Poll};

use crate::connection::State;
use crate::subject::ToSubject;
use crate::ServerInfo;
use crate::{PublishMessage, ServerInfo};

use super::{header::HeaderMap, status::StatusCode, Command, Message, Subscriber};
use crate::error::Error;
use bytes::Bytes;
use futures::future::TryFutureExt;
use futures::StreamExt;
use futures::{Sink, SinkExt as _, StreamExt};
use once_cell::sync::Lazy;
use portable_atomic::AtomicU64;
use regex::Regex;
Expand All @@ -29,6 +32,7 @@ use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tokio::sync::{mpsc, oneshot};
use tokio_util::sync::PollSender;
use tracing::trace;

static VERSION_RE: Lazy<Regex> =
Expand All @@ -44,6 +48,12 @@ impl From<tokio::sync::mpsc::error::SendError<Command>> for PublishError {
}
}

impl From<tokio_util::sync::PollSendError<Command>> for PublishError {
fn from(err: tokio_util::sync::PollSendError<Command>) -> Self {
PublishError::with_source(PublishErrorKind::Send, err)
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
pub enum PublishErrorKind {
MaxPayloadExceeded,
Expand All @@ -67,13 +77,36 @@ pub struct Client {
info: tokio::sync::watch::Receiver<ServerInfo>,
pub(crate) state: tokio::sync::watch::Receiver<State>,
pub(crate) sender: mpsc::Sender<Command>,
poll_sender: PollSender<Command>,
next_subscription_id: Arc<AtomicU64>,
subscription_capacity: usize,
inbox_prefix: Arc<str>,
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
}

impl Sink<PublishMessage> for Client {
type Error = PublishError;

fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_ready_unpin(cx).map_err(Into::into)
}

fn start_send(mut self: Pin<&mut Self>, msg: PublishMessage) -> Result<(), Self::Error> {
self.poll_sender
.start_send_unpin(Command::Publish(msg))
.map_err(Into::into)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_flush_unpin(cx).map_err(Into::into)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.poll_sender.poll_close_unpin(cx).map_err(Into::into)
}
}

impl Client {
pub(crate) fn new(
info: tokio::sync::watch::Receiver<ServerInfo>,
Expand All @@ -84,10 +117,12 @@ impl Client {
request_timeout: Option<Duration>,
max_payload: Arc<AtomicUsize>,
) -> Client {
let poll_sender = PollSender::new(sender.clone());
Client {
info,
state,
sender,
poll_sender,
next_subscription_id: Arc::new(AtomicU64::new(1)),
subscription_capacity: capacity,
inbox_prefix: inbox_prefix.into(),
Expand Down Expand Up @@ -191,12 +226,12 @@ impl Client {
}

self.sender
.send(Command::Publish {
.send(Command::Publish(PublishMessage {
subject,
payload,
respond: None,
reply: None,
headers: None,
})
}))
.await?;
Ok(())
}
Expand Down Expand Up @@ -229,12 +264,12 @@ impl Client {
let subject = subject.to_subject();

self.sender
.send(Command::Publish {
.send(Command::Publish(PublishMessage {
subject,
payload,
respond: None,
reply: None,
headers: Some(headers),
})
}))
.await?;
Ok(())
}
Expand Down Expand Up @@ -265,12 +300,12 @@ impl Client {
let reply = reply.to_subject();

self.sender
.send(Command::Publish {
.send(Command::Publish(PublishMessage {
subject,
payload,
respond: Some(reply),
reply: Some(reply),
headers: None,
})
}))
.await?;
Ok(())
}
Expand Down Expand Up @@ -304,12 +339,12 @@ impl Client {
let reply = reply.to_subject();

self.sender
.send(Command::Publish {
.send(Command::Publish(PublishMessage {
subject,
payload,
respond: Some(reply),
reply: Some(reply),
headers: Some(headers),
})
}))
.await?;
Ok(())
}
Expand Down
23 changes: 14 additions & 9 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,19 @@ pub(crate) enum ServerOp {
},
}

/// `PublishMessage` represents a message being published
#[derive(Debug)]
pub struct PublishMessage {
pub subject: Subject,
pub payload: Bytes,
pub reply: Option<Subject>,
pub headers: Option<HeaderMap>,
}

/// `Command` represents all commands that a [`Client`] can handle
#[derive(Debug)]
pub(crate) enum Command {
Publish {
subject: Subject,
payload: Bytes,
respond: Option<Subject>,
headers: Option<HeaderMap>,
},
Publish(PublishMessage),
Request {
subject: Subject,
payload: Bytes,
Expand Down Expand Up @@ -822,12 +827,12 @@ impl ConnectionHandler {
self.connection.enqueue_write_op(&pub_op);
}

Command::Publish {
Command::Publish(PublishMessage {
subject,
payload,
respond,
reply: respond,
headers,
} => {
}) => {
self.connection.enqueue_write_op(&ClientOp::Publish {
subject,
payload,
Expand Down

0 comments on commit 053944d

Please sign in to comment.