diff --git a/protocols/identify/src/handler.rs b/protocols/identify/src/handler.rs index 91b988b1456f..9ba1db6886ce 100644 --- a/protocols/identify/src/handler.rs +++ b/protocols/identify/src/handler.rs @@ -19,7 +19,7 @@ // DEALINGS IN THE SOFTWARE. use crate::protocol::{Identify, InboundPush, OutboundPush, Push, UpgradeError}; -use crate::protocol::{Info, PartialInfo}; +use crate::protocol::{Info, PushInfo}; use either::Either; use futures::future::BoxFuture; use futures::prelude::*; @@ -49,8 +49,7 @@ use std::{io, task::Context, task::Poll, time::Duration}; /// permitting the underlying connection to be closed. pub struct Handler { remote_peer_id: PeerId, - remote_public_key: Option, - inbound_identify_push: Option>>, + inbound_identify_push: Option>>, /// Pending events to yield. events: SmallVec< [ConnectionHandlerEvent>, (), Event, io::Error>; 4], @@ -82,6 +81,9 @@ pub struct Handler { /// Address observed by or for the remote. observed_addr: Multiaddr, + /// Identify information about the remote peer. + remote_info: Option, + local_supported_protocols: SupportedProtocols, remote_supported_protocols: HashSet, external_addresses: HashSet, @@ -123,7 +125,6 @@ impl Handler { ) -> Self { Self { remote_peer_id, - remote_public_key: Default::default(), inbound_identify_push: Default::default(), events: SmallVec::new(), pending_replies: FuturesUnordered::new(), @@ -136,6 +137,7 @@ impl Handler { observed_addr, local_supported_protocols: SupportedProtocols::default(), remote_supported_protocols: HashSet::default(), + remote_info: Default::default(), external_addresses, } } @@ -154,7 +156,7 @@ impl Handler { let info = self.build_info(); self.pending_replies - .push(crate::protocol::send(substream, info.into()).boxed()); + .push(crate::protocol::send(substream, info).boxed()); } future::Either::Right(fut) => { if self.inbound_identify_push.replace(fut).is_some() { @@ -178,21 +180,8 @@ impl Handler { >, ) { match output { - future::Either::Left(remote_info) => { - self.update_supported_protocols_for_remote(&remote_info); - - match self.update_remote_public_key(remote_info).try_into() { - Ok(info) => { - self.events.push(ConnectionHandlerEvent::NotifyBehaviour( - Event::Identified(info), - )); - } - Err(error) => { - warn!( - "Failed to build remote info from inbound identify push stream from {:?}: {:?}", - self.remote_peer_id, error) - } - } + future::Either::Left(info) => { + self.handle_incoming_info(info); } future::Either::Right(()) => self.events.push(ConnectionHandlerEvent::NotifyBehaviour( Event::IdentificationPushed, @@ -225,18 +214,38 @@ impl Handler { } } - fn update_remote_public_key(&mut self, mut remote_info: PartialInfo) -> PartialInfo { - if let Some(key) = &remote_info.public_key { - self.remote_public_key.replace(key.clone()); - remote_info + fn handle_incoming_info(&mut self, info: Info) { + self.remote_info.replace(info.clone()); + + self.update_supported_protocols_for_remote(&info); + + self.events + .push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified( + info, + ))); + } + + fn handle_incoming_push_info(&mut self, push_info: PushInfo) { + if let Some(mut info) = self.remote_info.take() { + info.merge(push_info); + self.remote_info.replace(info.clone()); + + self.update_supported_protocols_for_remote(&info); + + self.events + .push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified( + info, + ))); } else { - remote_info.public_key = self.remote_public_key.clone(); - remote_info + warn!( + "Failed to process push from {:?} because no identify info was received before", + self.remote_peer_id + ) } } - fn update_supported_protocols_for_remote(&mut self, remote_info: &PartialInfo) { - let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone()); + fn update_supported_protocols_for_remote(&mut self, info: &Info) { + let new_remote_protocols = HashSet::from_iter(info.protocols.clone()); let remote_added_protocols = new_remote_protocols .difference(&self.remote_supported_protocols) @@ -296,10 +305,7 @@ impl ConnectionHandler for Handler { let info = self.build_info(); self.events .push(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - Either::Right(Push::outbound(info.into())), - (), - ), + protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()), }); } } @@ -344,19 +350,7 @@ impl ConnectionHandler for Handler { self.inbound_identify_push.take(); if let Ok(remote_info) = res { - self.update_supported_protocols_for_remote(&remote_info); - match self.update_remote_public_key(remote_info).try_into() { - Ok(info) => { - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( - Event::Identified(info), - )); - } - Err(error) => { - warn!( - "Failed to build remote info from inbound identify stream from {:?}: {:?}", - self.remote_peer_id, error) - } - } + self.handle_incoming_push_info(remote_info); } } @@ -414,7 +408,7 @@ impl ConnectionHandler for Handler { self.events .push(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol: SubstreamProtocol::new( - Either::Right(Push::outbound(info.into())), + Either::Right(Push::outbound(info)), (), ), }); diff --git a/protocols/identify/src/protocol.rs b/protocols/identify/src/protocol.rs index 8596d8fba511..69a2f5b15f48 100644 --- a/protocols/identify/src/protocol.rs +++ b/protocols/identify/src/protocol.rs @@ -49,7 +49,7 @@ pub struct Identify; #[derive(Debug, Clone)] pub struct Push(T); pub struct InboundPush(); -pub struct OutboundPush(PartialInfo); +pub struct OutboundPush(Info); impl Push { pub fn inbound() -> Self { @@ -58,7 +58,7 @@ impl Push { } impl Push { - pub fn outbound(info: PartialInfo) -> Self { + pub fn outbound(info: Info) -> Self { Push(OutboundPush(info)) } } @@ -82,23 +82,33 @@ pub struct Info { pub observed_addr: Multiaddr, } -impl From for PartialInfo { - fn from(val: Info) -> Self { - PartialInfo { - public_key: Some(val.public_key), - protocol_version: Some(val.protocol_version), - agent_version: Some(val.agent_version), - listen_addrs: val.listen_addrs, - protocols: val.protocols, - observed_addr: Some(val.observed_addr), +impl Info { + pub fn merge(&mut self, info: PushInfo) { + if let Some(public_key) = info.public_key { + self.public_key = public_key; + } + if let Some(protocol_version) = info.protocol_version { + self.protocol_version = protocol_version; + } + if let Some(agent_version) = info.agent_version { + self.agent_version = agent_version; + } + if !info.listen_addrs.is_empty() { + self.listen_addrs = info.listen_addrs; + } + if !info.protocols.is_empty() { + self.protocols = info.protocols; + } + if let Some(observed_addr) = info.observed_addr { + self.observed_addr = observed_addr; } } } -/// Partial identify information of a peer sent in protocol messages. +/// Identify push information of a peer sent in protocol messages. /// Note that missing fields should be ignored, as peers may choose to send partial updates containing only the fields whose values have changed. #[derive(Debug, Clone)] -pub struct PartialInfo { +pub struct PushInfo { pub public_key: Option, pub protocol_version: Option, pub agent_version: Option, @@ -108,18 +118,16 @@ pub struct PartialInfo { } #[derive(Debug)] -pub enum IdentifyError { - MissingRemotePublicKey, +pub enum MissingInfoError { + PublicKey, } -impl TryFrom for Info { - type Error = IdentifyError; +impl TryFrom for Info { + type Error = MissingInfoError; - fn try_from(info: PartialInfo) -> Result { + fn try_from(info: PushInfo) -> Result { Ok(Info { - public_key: info - .public_key - .ok_or(IdentifyError::MissingRemotePublicKey)?, + public_key: info.public_key.ok_or(MissingInfoError::PublicKey)?, protocol_version: info.protocol_version.unwrap_or_default(), agent_version: info.agent_version.unwrap_or_default(), listen_addrs: info.listen_addrs, @@ -152,12 +160,16 @@ impl OutboundUpgrade for Identify where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Output = PartialInfo; + type Output = Info; type Error = UpgradeError; type Future = Pin> + Send>>; fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future { - recv(socket).boxed() + recv(socket) + .map(|result| { + result.and_then(|info| info.try_into().map_err(UpgradeError::MissingInfo)) + }) + .boxed() } } @@ -174,7 +186,7 @@ impl InboundUpgrade for Push where C: AsyncRead + AsyncWrite + Unpin + Send + 'static, { - type Output = BoxFuture<'static, Result>; // @todo -- Different future result? + type Output = BoxFuture<'static, Result>; // @todo -- Different future result? type Error = Void; type Future = future::Ready>; @@ -197,7 +209,7 @@ where } } -pub(crate) async fn send(io: T, info: PartialInfo) -> Result<(), UpgradeError> +pub(crate) async fn send(io: T, info: Info) -> Result<(), UpgradeError> where T: AsyncWrite + Unpin, { @@ -209,14 +221,14 @@ where .map(|addr| addr.to_vec()) .collect(); - let pubkey_bytes = info.public_key.map(|key| key.encode_protobuf()); + let pubkey_bytes = info.public_key.encode_protobuf(); let message = proto::Identify { - agentVersion: info.agent_version, - protocolVersion: info.protocol_version, - publicKey: pubkey_bytes, + agentVersion: Some(info.agent_version), + protocolVersion: Some(info.protocol_version), + publicKey: Some(pubkey_bytes), listenAddrs: listen_addrs, - observedAddr: info.observed_addr.map(|addr| addr.to_vec()), + observedAddr: Some(info.observed_addr.to_vec()), protocols: info.protocols.into_iter().map(|p| p.to_string()).collect(), }; @@ -231,7 +243,7 @@ where Ok(()) } -async fn recv(socket: T) -> Result +async fn recv(socket: T) -> Result where T: AsyncRead + AsyncWrite + Unpin, { @@ -254,7 +266,7 @@ where Ok(info) } -impl TryFrom for PartialInfo { +impl TryFrom for PushInfo { type Error = UpgradeError; fn try_from(msg: proto::Identify) -> Result { @@ -295,7 +307,7 @@ impl TryFrom for PartialInfo { } }); - let info = PartialInfo { + let info = PushInfo { public_key, protocol_version: msg.protocolVersion, agent_version: msg.agentVersion, @@ -326,6 +338,8 @@ pub enum UpgradeError { Io(#[from] io::Error), #[error("Stream closed")] StreamClosed, + #[error("Missing information received")] + MissingInfo(MissingInfoError), #[error("Failed decoding multiaddr")] Multiaddr(#[from] multiaddr::Error), #[error("Failed decoding public key")] @@ -361,7 +375,7 @@ mod tests { ), }; - let info = PartialInfo::try_from(payload).expect("not to fail"); + let info = PushInfo::try_from(payload).expect("not to fail"); assert_eq!(info.listen_addrs, vec![valid_multiaddr]) }