From 1fab3ffd184c36c88d6199e6c8f650dbc8d15503 Mon Sep 17 00:00:00 2001 From: Marcel Gleeson Date: Thu, 21 Sep 2023 18:27:58 +0200 Subject: [PATCH] refactor: avoid breaking changes to api --- protocols/identify/src/handler.rs | 66 +++++++++++++---- protocols/identify/src/protocol.rs | 114 +++++++++++++++++++++-------- protocols/identify/tests/smoke.rs | 6 +- 3 files changed, 136 insertions(+), 50 deletions(-) diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 746864a57b9..91b988b1456 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -18,7 +18,8 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::protocol::{Identify, InboundPush, Info, OutboundPush, Push, UpgradeError}; +use crate::protocol::{Identify, InboundPush, OutboundPush, Push, UpgradeError}; +use crate::protocol::{Info, PartialInfo}; use either::Either; use futures::future::BoxFuture; use futures::prelude::*; @@ -48,7 +49,8 @@ use std::{io, task::Context, task::Poll, time::Duration}; /// permitting the underlying connection to be closed. pub struct Handler { remote_peer_id: PeerId, - inbound_identify_push: Option>>, + remote_public_key: Option, + inbound_identify_push: Option>>, /// Pending events to yield. events: SmallVec< [ConnectionHandlerEvent>, (), Event, io::Error>; 4], @@ -121,6 +123,7 @@ impl Handler { ) -> Self { Self { remote_peer_id, + remote_public_key: Default::default(), inbound_identify_push: Default::default(), events: SmallVec::new(), pending_replies: FuturesUnordered::new(), @@ -151,7 +154,7 @@ impl Handler { let info = self.build_info(); self.pending_replies - .push(crate::protocol::send(substream, info).boxed()); + .push(crate::protocol::send(substream, info.into()).boxed()); } future::Either::Right(fut) => { if self.inbound_identify_push.replace(fut).is_some() { @@ -177,10 +180,19 @@ impl Handler { match output { future::Either::Left(remote_info) => { self.update_supported_protocols_for_remote(&remote_info); - self.events - .push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified( - remote_info, - ))); + + match self.update_remote_public_key(remote_info).try_into() { + Ok(info) => { + self.events.push(ConnectionHandlerEvent::NotifyBehaviour( + Event::Identified(info), + )); + } + Err(error) => { + warn!( + "Failed to build remote info from inbound identify push stream from {:?}: {:?}", + self.remote_peer_id, error) + } + } } future::Either::Right(()) => self.events.push(ConnectionHandlerEvent::NotifyBehaviour( Event::IdentificationPushed, @@ -204,7 +216,7 @@ impl Handler { fn build_info(&mut self) -> Info { Info { - public_key: Some(self.public_key.clone()), + public_key: self.public_key.clone(), protocol_version: self.protocol_version.clone(), agent_version: self.agent_version.clone(), listen_addrs: Vec::from_iter(self.external_addresses.iter().cloned()), @@ -213,7 +225,17 @@ impl Handler { } } - fn update_supported_protocols_for_remote(&mut self, remote_info: &Info) { + fn update_remote_public_key(&mut self, mut remote_info: PartialInfo) -> PartialInfo { + if let Some(key) = &remote_info.public_key { + self.remote_public_key.replace(key.clone()); + remote_info + } else { + remote_info.public_key = self.remote_public_key.clone(); + remote_info + } + } + + fn update_supported_protocols_for_remote(&mut self, remote_info: &PartialInfo) { let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone()); let remote_added_protocols = new_remote_protocols @@ -274,7 +296,10 @@ impl ConnectionHandler for Handler { let info = self.build_info(); self.events .push(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()), + protocol: SubstreamProtocol::new( + Either::Right(Push::outbound(info.into())), + (), + ), }); } } @@ -318,11 +343,20 @@ impl ConnectionHandler for Handler { { self.inbound_identify_push.take(); - if let Ok(info) = res { - self.update_supported_protocols_for_remote(&info); - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified( - info, - ))); + if let Ok(remote_info) = res { + self.update_supported_protocols_for_remote(&remote_info); + match self.update_remote_public_key(remote_info).try_into() { + Ok(info) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::Identified(info), + )); + } + Err(error) => { + warn!( + "Failed to build remote info from inbound identify stream from {:?}: {:?}", + self.remote_peer_id, error) + } + } } } @@ -380,7 +414,7 @@ impl ConnectionHandler for Handler { self.events .push(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new( - Either::Right(Push::outbound(info)), + Either::Right(Push::outbound(info.into())), (), ), }); diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index bb0a4cc490b..8596d8fba51 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -49,7 +49,7 @@ pub struct Identify; #[derive(Debug, Clone)] pub struct Push(T); pub struct InboundPush(); -pub struct OutboundPush(Info); +pub struct OutboundPush(PartialInfo); impl Push { pub fn inbound() -> Self { @@ -58,16 +58,16 @@ impl Push { } impl Push { - pub fn outbound(info: Info) -> Self { + pub fn outbound(info: PartialInfo) -> Self { Push(OutboundPush(info)) } } -/// Information of a peer sent in protocol messages. +/// Identify information of a peer sent in protocol messages. #[derive(Debug, Clone)] pub struct Info { /// The public key of the local peer. - pub public_key: Option, + pub public_key: PublicKey, /// Application-specific version of the protocol family used by the peer, /// e.g. `ipfs/1.0.0` or `polkadot/1.0.0`. pub protocol_version: String, @@ -82,6 +82,53 @@ pub struct Info { pub observed_addr: Multiaddr, } +impl From for PartialInfo { + fn from(val: Info) -> Self { + PartialInfo { + public_key: Some(val.public_key), + protocol_version: Some(val.protocol_version), + agent_version: Some(val.agent_version), + listen_addrs: val.listen_addrs, + protocols: val.protocols, + observed_addr: Some(val.observed_addr), + } + } +} + +/// Partial identify information of a peer sent in protocol messages. +/// Note that missing fields should be ignored, as peers may choose to send partial updates containing only the fields whose values have changed. +#[derive(Debug, Clone)] +pub struct PartialInfo { + pub public_key: Option, + pub protocol_version: Option, + pub agent_version: Option, + pub listen_addrs: Vec, + pub protocols: Vec, + pub observed_addr: Option, +} + +#[derive(Debug)] +pub enum IdentifyError { + MissingRemotePublicKey, +} + +impl TryFrom for Info { + type Error = IdentifyError; + + fn try_from(info: PartialInfo) -> Result { + Ok(Info { + public_key: info + .public_key + .ok_or(IdentifyError::MissingRemotePublicKey)?, + protocol_version: info.protocol_version.unwrap_or_default(), + agent_version: info.agent_version.unwrap_or_default(), + listen_addrs: info.listen_addrs, + protocols: info.protocols, + observed_addr: info.observed_addr.unwrap_or_else(Multiaddr::empty), + }) + } +} + impl UpgradeInfo for Identify { type Info = StreamProtocol; type InfoIter = iter::Once; @@ -105,7 +152,7 @@ impl OutboundUpgrade for Identify where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Output = Info; + type Output = PartialInfo; type Error = UpgradeError; type Future = Pin> + Send>>; @@ -127,7 +174,7 @@ impl InboundUpgrade for Push where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Output = BoxFuture<'static, Result>; + type Output = BoxFuture<'static, Result>; // @todo -- Different future result? type Error = Void; type Future = future::Ready>; @@ -150,7 +197,7 @@ where } } -pub(crate) async fn send(io: T, info: Info) -> Result<(), UpgradeError> +pub(crate) async fn send(io: T, info: PartialInfo) -> Result<(), UpgradeError> where T: AsyncWrite + Unpin, { @@ -165,11 +212,11 @@ where let pubkey_bytes = info.public_key.map(|key| key.encode_protobuf()); let message = proto::Identify { - agentVersion: Some(info.agent_version), - protocolVersion: Some(info.protocol_version), + agentVersion: info.agent_version, + protocolVersion: info.protocol_version, publicKey: pubkey_bytes, listenAddrs: listen_addrs, - observedAddr: Some(info.observed_addr.to_vec()), + observedAddr: info.observed_addr.map(|addr| addr.to_vec()), protocols: info.protocols.into_iter().map(|p| p.to_string()).collect(), }; @@ -184,7 +231,7 @@ where Ok(()) } -async fn recv(socket: T) -> Result +async fn recv(socket: T) -> Result where T: AsyncRead + AsyncWrite + Unpin, { @@ -207,7 +254,7 @@ where Ok(info) } -impl TryFrom for Info { +impl TryFrom for PartialInfo { type Error = UpgradeError; fn try_from(msg: proto::Identify) -> Result { @@ -221,32 +268,37 @@ impl TryFrom for Info { match parse_multiaddr(addr) { Ok(a) => addrs.push(a), Err(e) => { - debug!("Unable to parse multiaddr: {e:?}"); + debug!("Unable to parse listen multiaddr: {e:?}"); } } } addrs }; - let public_key = msg.publicKey.and_then(|key| match PublicKey::try_decode_protobuf(&key) { - Ok(k) => Some(k), - Err(e) => { - debug!("Unable to decode public key: {e:?}"); - None - } - }); + let public_key = msg + .publicKey + .and_then(|key| match PublicKey::try_decode_protobuf(&key) { + Ok(k) => Some(k), + Err(e) => { + debug!("Unable to decode public key: {e:?}"); + None + } + }); + + let observed_addr = msg + .observedAddr + .and_then(|bytes| match parse_multiaddr(bytes) { + Ok(a) => Some(a), + Err(e) => { + debug!("Unable to parse observed multiaddr: {e:?}"); + None + } + }); - let observed_addr = match parse_multiaddr(msg.observedAddr.unwrap_or_default()) { - Ok(a) => a, - Err(e) => { - debug!("Unable to parse multiaddr: {e:?}"); - Multiaddr::empty() - } - }; - let info = Info { + let info = PartialInfo { public_key, - protocol_version: msg.protocolVersion.unwrap_or_default(), - agent_version: msg.agentVersion.unwrap_or_default(), + protocol_version: msg.protocolVersion, + agent_version: msg.agentVersion, listen_addrs, protocols: msg .protocols @@ -309,7 +361,7 @@ mod tests { ), }; - let info = Info::try_from(payload).expect("not to fail"); + let info = PartialInfo::try_from(payload).expect("not to fail"); assert_eq!(info.listen_addrs, vec![valid_multiaddr]) } diff --git a/protocols/identify/tests/smoke.rs b/protocols/identify/tests/smoke.rs index 4cfde7620dd..c70ab3181b4 100644 --- a/protocols/identify/tests/smoke.rs +++ b/protocols/identify/tests/smoke.rs @@ -48,7 +48,7 @@ async fn periodic_identify() { [BehaviourEvent::Identify(Sent { .. }), BehaviourEvent::Identify(Received { info: s1_info, .. })], [BehaviourEvent::Identify(Received { info: s2_info, .. }), BehaviourEvent::Identify(Sent { .. })], ) => { - assert_eq!(s1_info.public_key.unwrap().to_peer_id(), swarm2_peer_id); + assert_eq!(s1_info.public_key.to_peer_id(), swarm2_peer_id); assert_eq!(s1_info.protocol_version, "c"); assert_eq!(s1_info.agent_version, "d"); assert!(!s1_info.protocols.is_empty()); @@ -61,7 +61,7 @@ async fn periodic_identify() { assert!(s1_info.listen_addrs.contains(&swarm2_tcp_listen_addr)); assert!(s1_info.listen_addrs.contains(&swarm2_memory_listen)); - assert_eq!(s2_info.public_key.unwrap().to_peer_id(), swarm1_peer_id); + assert_eq!(s2_info.public_key.to_peer_id(), swarm1_peer_id); assert_eq!(s2_info.protocol_version, "a"); assert_eq!(s2_info.agent_version, "b"); assert!(!s2_info.protocols.is_empty()); @@ -127,7 +127,7 @@ async fn identify_push() { }; assert_eq!( - swarm1_received_info.public_key.unwrap().to_peer_id(), + swarm1_received_info.public_key.to_peer_id(), *swarm2.local_peer_id() ); assert_eq!(swarm1_received_info.protocol_version, "a");