Skip to content

Commit

Permalink
feature: move wait_for_xxx util into metrics.
Browse files Browse the repository at this point in the history
Introduce struct `Wait` as a wrapper of the metrics channel to impl
wait-for utils:
- `log()`:  wait for log to apply.
- `current_leader()`: wait for known leader.
- `state()`: wait for the role.
- `members()`: wait for membership_config.members.
- `next_members()`: wait for membership_config.members_after_consensus.

E.g.:

```rust
// wait for ever for raft node's current leader to become 3:
r.wait(None).current_leader(2).await?;
```

The timeout is now an option arg to all wait_for_xxx functions in
fixtures. wait_for_xxx_timeout are all removed.
  • Loading branch information
drmingdrmer committed Jun 21, 2021
1 parent 919d91c commit 1ad17e8
Show file tree
Hide file tree
Showing 22 changed files with 530 additions and 142 deletions.
2 changes: 2 additions & 0 deletions async-raft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub mod config;
mod core;
pub mod error;
pub mod metrics;
#[cfg(test)]
mod metrics_wait_test;
pub mod network;
pub mod raft;
mod replication;
Expand Down
157 changes: 157 additions & 0 deletions async-raft/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,18 @@
//! Metrics are observed on a running Raft node via the `Raft::metrics()` method, which will
//! return a stream of metrics.
use std::collections::HashSet;

use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use tokio::sync::watch;
use tokio::time::Duration;

use crate::core::State;
use crate::raft::MembershipConfig;
use crate::NodeId;
use crate::RaftError;

/// A set of metrics describing the current state of a Raft node.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
Expand Down Expand Up @@ -47,3 +53,154 @@ impl RaftMetrics {
}
}
}

// Error variants related to metrics.
#[derive(Debug, Error)]
pub enum WaitError {
#[error("timeout after {0:?} when {1}")]
Timeout(Duration, String),

#[error("{0}")]
RaftError(#[from] RaftError),
}

/// Wait is a wrapper of RaftMetrics channel that impls several utils to wait for metrics to satisfy some condition.
pub struct Wait {
pub timeout: Duration,
pub rx: watch::Receiver<RaftMetrics>,
}

impl Wait {
/// Wait for metrics to satisfy some condition or timeout.
#[tracing::instrument(level = "debug", skip(self, func), fields(msg=msg.to_string().as_str()))]
pub async fn metrics<T>(&self, func: T, msg: impl ToString) -> Result<RaftMetrics, WaitError>
where T: Fn(&RaftMetrics) -> bool + Send {
let mut rx = self.rx.clone();
loop {
let latest = rx.borrow().clone();

tracing::debug!(
"id={} wait {:} latest: {:?}",
latest.id,
msg.to_string(),
latest
);

if func(&latest) {
tracing::debug!(
"id={} done wait {:} latest: {:?}",
latest.id,
msg.to_string(),
latest
);
return Ok(latest);
}

let delay = tokio::time::sleep(self.timeout);

tokio::select! {
_ = delay => {
tracing::debug!( "id={} timeout wait {:} latest: {:?}", latest.id, msg.to_string(), latest );
return Err(WaitError::Timeout(self.timeout, format!("{} latest: {:?}", msg.to_string(), latest)));
}
changed = rx.changed() => {
match changed {
Ok(_) => {
// metrics changed, continue the waiting loop
},
Err(err) => {
tracing::debug!(
"id={} error: {:?}; wait {:} latest: {:?}",
latest.id,
err,
msg.to_string(),
latest
);
return Err(WaitError::RaftError(RaftError::ShuttingDown));
}
}
}
};
}
}

/// Wait for `current_leader` to become `Some(leader_id)` until timeout.
#[tracing::instrument(level = "debug", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn current_leader(
&self,
leader_id: NodeId,
msg: impl ToString,
) -> Result<RaftMetrics, WaitError> {
self.metrics(
|x| x.current_leader == Some(leader_id),
&format!("{} .current_leader -> {}", msg.to_string(), leader_id),
)
.await
}

/// Wait until applied upto `want_log`(inclusive) logs or timeout.
#[tracing::instrument(level = "debug", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn log(&self, want_log: u64, msg: impl ToString) -> Result<RaftMetrics, WaitError> {
self.metrics(
|x| x.last_log_index == want_log,
&format!("{} .last_log_index -> {}", msg.to_string(), want_log),
)
.await?;

self.metrics(
|x| x.last_applied == want_log,
&format!("{} .last_applied -> {}", msg.to_string(), want_log),
)
.await
}

/// Wait for `state` to become `want_state` or timeout.
#[tracing::instrument(level = "debug", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn state(
&self,
want_state: State,
msg: impl ToString,
) -> Result<RaftMetrics, WaitError> {
self.metrics(
|x| x.state == want_state,
&format!("{} .state -> {:?}", msg.to_string(), want_state),
)
.await
}

/// Wait for `membership_config.members` to become expected node set or timeout.
#[tracing::instrument(level = "debug", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn members(
&self,
want_members: HashSet<NodeId>,
msg: impl ToString,
) -> Result<RaftMetrics, WaitError> {
self.metrics(
|x| x.membership_config.members == want_members,
&format!(
"{} .membership_config.members -> {:?}",
msg.to_string(),
want_members
),
)
.await
}

/// Wait for `membership_config.members_after_consensus` to become expected node set or timeout.
#[tracing::instrument(level = "debug", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn next_members(
&self,
want_members: Option<HashSet<NodeId>>,
msg: impl ToString,
) -> Result<RaftMetrics, WaitError> {
self.metrics(
|x| x.membership_config.members_after_consensus == want_members,
&format!(
"{} .membership_config.members_after_consensus -> {:?}",
msg.to_string(),
want_members
),
)
.await
}
}
149 changes: 149 additions & 0 deletions async-raft/src/metrics_wait_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
use std::time::Duration;

use maplit::hashset;
use tokio::sync::watch;
use tokio::time::sleep;

use crate::metrics::Wait;
use crate::metrics::WaitError;
use crate::raft::MembershipConfig;
use crate::RaftMetrics;
use crate::State;

/// Test wait for different state changes
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_wait() -> anyhow::Result<()> {
{
// wait for leader
let (init, w, tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(10)).await;
let mut update = init.clone();
update.current_leader = Some(3);
let rst = tx.send(update);
assert!(rst.is_ok());
});
let got = w.current_leader(3, "leader").await?;
h.await?;
assert_eq!(Some(3), got.current_leader);
}

{
// wait for log
let (init, w, tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(10)).await;
let mut update = init.clone();
update.last_log_index = 3;
update.last_applied = 3;
let rst = tx.send(update);
assert!(rst.is_ok());
});
let got = w.log(3, "log").await?;
h.await?;

assert_eq!(3, got.last_log_index);
assert_eq!(3, got.last_applied);
}

{
// wait for state
let (init, w, tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(10)).await;
let mut update = init.clone();
update.state = State::Leader;
let rst = tx.send(update);
assert!(rst.is_ok());
});
let got = w.state(State::Leader, "state").await?;
h.await?;

assert_eq!(State::Leader, got.state);
}

{
// wait for members
let (init, w, tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(10)).await;
let mut update = init.clone();
update.membership_config.members = hashset![1, 2];
let rst = tx.send(update);
assert!(rst.is_ok());
});
let got = w.members(hashset![1, 2], "members").await?;
h.await?;

assert_eq!(hashset![1, 2], got.membership_config.members);
}

{
// wait for next_members
let (init, w, tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(10)).await;
let mut update = init.clone();
update.membership_config.members_after_consensus = Some(hashset![1, 2]);
let rst = tx.send(update);
assert!(rst.is_ok());
});
let got = w.next_members(Some(hashset![1, 2]), "next_members").await?;
h.await?;

assert_eq!(
Some(hashset![1, 2]),
got.membership_config.members_after_consensus
);
}

{
// timeout
let (_init, w, _tx) = init_wait_test();

let h = tokio::spawn(async move {
sleep(Duration::from_millis(200)).await;
});
let got = w.state(State::Follower, "timeout").await;
h.await?;

match got.unwrap_err() {
WaitError::Timeout(t, _) => {
assert_eq!(Duration::from_millis(100), t);
}
_ => {
panic!("expect WaitError::Timeout");
}
}
}

Ok(())
}

/// Build a initial state for testing of Wait:
/// Returns init metrics, Wait, and the tx to send an updated metrics.
fn init_wait_test() -> (RaftMetrics, Wait, watch::Sender<RaftMetrics>) {
let init = RaftMetrics {
id: 0,
state: State::NonVoter,
current_term: 0,
last_log_index: 0,
last_applied: 0,
current_leader: None,
membership_config: MembershipConfig {
members: Default::default(),
members_after_consensus: None,
},
};
let (tx, rx) = watch::channel(init.clone());
let w = Wait {
timeout: Duration::from_millis(100),
rx,
};
return (init, w, tx);
}
33 changes: 32 additions & 1 deletion async-raft/src/raft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;

use serde::Deserialize;
use serde::Serialize;
Expand All @@ -20,6 +21,7 @@ use crate::error::InitializeError;
use crate::error::RaftError;
use crate::error::RaftResult;
use crate::metrics::RaftMetrics;
use crate::metrics::Wait;
use crate::AppData;
use crate::AppDataResponse;
use crate::NodeId;
Expand Down Expand Up @@ -311,6 +313,35 @@ impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> Ra
self.inner.rx_metrics.clone()
}

/// Get a handle to wait for the metrics to satisfy some condition.
///
/// ```ignore
/// # use std::time::Duration;
/// # use async_raft::{State, Raft};
///
/// let timeout = Duration::from_millis(200);
///
/// // wait for raft log-3 to be received and applied:
/// r.wait(Some(timeout)).log(3).await?;
///
/// // wait for ever for raft node's current leader to become 3:
/// r.wait(None).current_leader(2).await?;
///
/// // wait for raft state to become a follower
/// r.wait(None).state(State::Follower).await?;
///
/// ```
pub fn wait(&self, timeout: Option<Duration>) -> Wait {
let timeout = match timeout {
Some(t) => t,
None => Duration::from_millis(500),
};
Wait {
timeout,
rx: self.inner.rx_metrics.clone(),
}
}

/// Shutdown this Raft node.
pub async fn shutdown(&self) -> anyhow::Result<()> {
if let Some(tx) = self.inner.tx_shutdown.lock().await.take() {
Expand Down Expand Up @@ -512,7 +543,7 @@ pub struct EntrySnapshotPointer {
/// The membership configuration of the cluster.
/// Unlike original raft, the membership always a joint.
/// It could be a joint of one, two or more members, i.e., a quorum requires a majority of every members
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Clone, Default, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct MembershipConfig {
/// All members of the Raft cluster.
pub members: HashSet<NodeId>,
Expand Down
Loading

0 comments on commit 1ad17e8

Please sign in to comment.