Skip to content

Commit

Permalink
refactor: avoid breaking changes to api
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel-G committed Sep 21, 2023
1 parent b605bf2 commit 1fab3ff
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 50 deletions.
66 changes: 50 additions & 16 deletions protocols/identify/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::protocol::{Identify, InboundPush, Info, OutboundPush, Push, UpgradeError};
use crate::protocol::{Identify, InboundPush, OutboundPush, Push, UpgradeError};
use crate::protocol::{Info, PartialInfo};
use either::Either;
use futures::future::BoxFuture;
use futures::prelude::*;
Expand Down Expand Up @@ -48,7 +49,8 @@ use std::{io, task::Context, task::Poll, time::Duration};
/// permitting the underlying connection to be closed.
pub struct Handler {
remote_peer_id: PeerId,
inbound_identify_push: Option<BoxFuture<'static, Result<Info, UpgradeError>>>,
remote_public_key: Option<PublicKey>,
inbound_identify_push: Option<BoxFuture<'static, Result<PartialInfo, UpgradeError>>>,
/// Pending events to yield.
events: SmallVec<
[ConnectionHandlerEvent<Either<Identify, Push<OutboundPush>>, (), Event, io::Error>; 4],
Expand Down Expand Up @@ -121,6 +123,7 @@ impl Handler {
) -> Self {
Self {
remote_peer_id,
remote_public_key: Default::default(),
inbound_identify_push: Default::default(),
events: SmallVec::new(),
pending_replies: FuturesUnordered::new(),
Expand Down Expand Up @@ -151,7 +154,7 @@ impl Handler {
let info = self.build_info();

self.pending_replies
.push(crate::protocol::send(substream, info).boxed());
.push(crate::protocol::send(substream, info.into()).boxed());
}
future::Either::Right(fut) => {
if self.inbound_identify_push.replace(fut).is_some() {
Expand All @@ -177,10 +180,19 @@ impl Handler {
match output {
future::Either::Left(remote_info) => {
self.update_supported_protocols_for_remote(&remote_info);
self.events
.push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
remote_info,
)));

match self.update_remote_public_key(remote_info).try_into() {
Ok(info) => {
self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify push stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
}
future::Either::Right(()) => self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::IdentificationPushed,
Expand All @@ -204,7 +216,7 @@ impl Handler {

fn build_info(&mut self) -> Info {
Info {
public_key: Some(self.public_key.clone()),
public_key: self.public_key.clone(),
protocol_version: self.protocol_version.clone(),
agent_version: self.agent_version.clone(),
listen_addrs: Vec::from_iter(self.external_addresses.iter().cloned()),
Expand All @@ -213,7 +225,17 @@ impl Handler {
}
}

fn update_supported_protocols_for_remote(&mut self, remote_info: &Info) {
fn update_remote_public_key(&mut self, mut remote_info: PartialInfo) -> PartialInfo {
if let Some(key) = &remote_info.public_key {
self.remote_public_key.replace(key.clone());
remote_info
} else {
remote_info.public_key = self.remote_public_key.clone();
remote_info
}
}

fn update_supported_protocols_for_remote(&mut self, remote_info: &PartialInfo) {
let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone());

let remote_added_protocols = new_remote_protocols
Expand Down Expand Up @@ -274,7 +296,10 @@ impl ConnectionHandler for Handler {
let info = self.build_info();
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()),
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info.into())),
(),
),
});
}
}
Expand Down Expand Up @@ -318,11 +343,20 @@ impl ConnectionHandler for Handler {
{
self.inbound_identify_push.take();

if let Ok(info) = res {
self.update_supported_protocols_for_remote(&info);
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
info,
)));
if let Ok(remote_info) = res {
self.update_supported_protocols_for_remote(&remote_info);
match self.update_remote_public_key(remote_info).try_into() {
Ok(info) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
}
}

Expand Down Expand Up @@ -380,7 +414,7 @@ impl ConnectionHandler for Handler {
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info)),
Either::Right(Push::outbound(info.into())),
(),
),
});
Expand Down
114 changes: 83 additions & 31 deletions protocols/identify/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct Identify;
#[derive(Debug, Clone)]
pub struct Push<T>(T);
pub struct InboundPush();
pub struct OutboundPush(Info);
pub struct OutboundPush(PartialInfo);

impl Push<InboundPush> {
pub fn inbound() -> Self {
Expand All @@ -58,16 +58,16 @@ impl Push<InboundPush> {
}

impl Push<OutboundPush> {
pub fn outbound(info: Info) -> Self {
pub fn outbound(info: PartialInfo) -> Self {
Push(OutboundPush(info))
}
}

/// Information of a peer sent in protocol messages.
/// Identify information of a peer sent in protocol messages.
#[derive(Debug, Clone)]
pub struct Info {
/// The public key of the local peer.
pub public_key: Option<PublicKey>,
pub public_key: PublicKey,
/// Application-specific version of the protocol family used by the peer,
/// e.g. `ipfs/1.0.0` or `polkadot/1.0.0`.
pub protocol_version: String,
Expand All @@ -82,6 +82,53 @@ pub struct Info {
pub observed_addr: Multiaddr,
}

impl From<Info> for PartialInfo {
fn from(val: Info) -> Self {
PartialInfo {
public_key: Some(val.public_key),
protocol_version: Some(val.protocol_version),
agent_version: Some(val.agent_version),
listen_addrs: val.listen_addrs,
protocols: val.protocols,
observed_addr: Some(val.observed_addr),
}
}
}

/// Partial identify information of a peer sent in protocol messages.
/// Note that missing fields should be ignored, as peers may choose to send partial updates containing only the fields whose values have changed.
#[derive(Debug, Clone)]
pub struct PartialInfo {
pub public_key: Option<PublicKey>,
pub protocol_version: Option<String>,
pub agent_version: Option<String>,
pub listen_addrs: Vec<Multiaddr>,
pub protocols: Vec<StreamProtocol>,
pub observed_addr: Option<Multiaddr>,
}

#[derive(Debug)]
pub enum IdentifyError {
MissingRemotePublicKey,
}

impl TryFrom<PartialInfo> for Info {
type Error = IdentifyError;

fn try_from(info: PartialInfo) -> Result<Self, Self::Error> {
Ok(Info {
public_key: info
.public_key
.ok_or(IdentifyError::MissingRemotePublicKey)?,
protocol_version: info.protocol_version.unwrap_or_default(),
agent_version: info.agent_version.unwrap_or_default(),
listen_addrs: info.listen_addrs,
protocols: info.protocols,
observed_addr: info.observed_addr.unwrap_or_else(Multiaddr::empty),
})
}
}

impl UpgradeInfo for Identify {
type Info = StreamProtocol;
type InfoIter = iter::Once<Self::Info>;
Expand All @@ -105,7 +152,7 @@ impl<C> OutboundUpgrade<C> for Identify
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = Info;
type Output = PartialInfo;
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

Expand All @@ -127,7 +174,7 @@ impl<C> InboundUpgrade<C> for Push<InboundPush>
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = BoxFuture<'static, Result<Info, UpgradeError>>;
type Output = BoxFuture<'static, Result<PartialInfo, UpgradeError>>; // @todo -- Different future result?
type Error = Void;
type Future = future::Ready<Result<Self::Output, Self::Error>>;

Expand All @@ -150,7 +197,7 @@ where
}
}

pub(crate) async fn send<T>(io: T, info: Info) -> Result<(), UpgradeError>
pub(crate) async fn send<T>(io: T, info: PartialInfo) -> Result<(), UpgradeError>
where
T: AsyncWrite + Unpin,
{
Expand All @@ -165,11 +212,11 @@ where
let pubkey_bytes = info.public_key.map(|key| key.encode_protobuf());

let message = proto::Identify {
agentVersion: Some(info.agent_version),
protocolVersion: Some(info.protocol_version),
agentVersion: info.agent_version,
protocolVersion: info.protocol_version,
publicKey: pubkey_bytes,
listenAddrs: listen_addrs,
observedAddr: Some(info.observed_addr.to_vec()),
observedAddr: info.observed_addr.map(|addr| addr.to_vec()),
protocols: info.protocols.into_iter().map(|p| p.to_string()).collect(),
};

Expand All @@ -184,7 +231,7 @@ where
Ok(())
}

async fn recv<T>(socket: T) -> Result<Info, UpgradeError>
async fn recv<T>(socket: T) -> Result<PartialInfo, UpgradeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -207,7 +254,7 @@ where
Ok(info)
}

impl TryFrom<proto::Identify> for Info {
impl TryFrom<proto::Identify> for PartialInfo {
type Error = UpgradeError;

fn try_from(msg: proto::Identify) -> Result<Self, Self::Error> {
Expand All @@ -221,32 +268,37 @@ impl TryFrom<proto::Identify> for Info {
match parse_multiaddr(addr) {
Ok(a) => addrs.push(a),
Err(e) => {
debug!("Unable to parse multiaddr: {e:?}");
debug!("Unable to parse listen multiaddr: {e:?}");
}
}
}
addrs
};

let public_key = msg.publicKey.and_then(|key| match PublicKey::try_decode_protobuf(&key) {
Ok(k) => Some(k),
Err(e) => {
debug!("Unable to decode public key: {e:?}");
None
}
});
let public_key = msg
.publicKey
.and_then(|key| match PublicKey::try_decode_protobuf(&key) {
Ok(k) => Some(k),
Err(e) => {
debug!("Unable to decode public key: {e:?}");
None
}
});

let observed_addr = msg
.observedAddr
.and_then(|bytes| match parse_multiaddr(bytes) {
Ok(a) => Some(a),
Err(e) => {
debug!("Unable to parse observed multiaddr: {e:?}");
None
}
});

let observed_addr = match parse_multiaddr(msg.observedAddr.unwrap_or_default()) {
Ok(a) => a,
Err(e) => {
debug!("Unable to parse multiaddr: {e:?}");
Multiaddr::empty()
}
};
let info = Info {
let info = PartialInfo {
public_key,
protocol_version: msg.protocolVersion.unwrap_or_default(),
agent_version: msg.agentVersion.unwrap_or_default(),
protocol_version: msg.protocolVersion,
agent_version: msg.agentVersion,
listen_addrs,
protocols: msg
.protocols
Expand Down Expand Up @@ -309,7 +361,7 @@ mod tests {
),
};

let info = Info::try_from(payload).expect("not to fail");
let info = PartialInfo::try_from(payload).expect("not to fail");

assert_eq!(info.listen_addrs, vec![valid_multiaddr])
}
Expand Down
6 changes: 3 additions & 3 deletions protocols/identify/tests/smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async fn periodic_identify() {
[BehaviourEvent::Identify(Sent { .. }), BehaviourEvent::Identify(Received { info: s1_info, .. })],
[BehaviourEvent::Identify(Received { info: s2_info, .. }), BehaviourEvent::Identify(Sent { .. })],
) => {
assert_eq!(s1_info.public_key.unwrap().to_peer_id(), swarm2_peer_id);
assert_eq!(s1_info.public_key.to_peer_id(), swarm2_peer_id);
assert_eq!(s1_info.protocol_version, "c");
assert_eq!(s1_info.agent_version, "d");
assert!(!s1_info.protocols.is_empty());
Expand All @@ -61,7 +61,7 @@ async fn periodic_identify() {
assert!(s1_info.listen_addrs.contains(&swarm2_tcp_listen_addr));
assert!(s1_info.listen_addrs.contains(&swarm2_memory_listen));

assert_eq!(s2_info.public_key.unwrap().to_peer_id(), swarm1_peer_id);
assert_eq!(s2_info.public_key.to_peer_id(), swarm1_peer_id);
assert_eq!(s2_info.protocol_version, "a");
assert_eq!(s2_info.agent_version, "b");
assert!(!s2_info.protocols.is_empty());
Expand Down Expand Up @@ -127,7 +127,7 @@ async fn identify_push() {
};

assert_eq!(
swarm1_received_info.public_key.unwrap().to_peer_id(),
swarm1_received_info.public_key.to_peer_id(),
*swarm2.local_peer_id()
);
assert_eq!(swarm1_received_info.protocol_version, "a");
Expand Down

0 comments on commit 1fab3ff

Please sign in to comment.