Skip to content

Commit

Permalink
refactor: store all previous identify info
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcel-G committed Sep 22, 2023
1 parent ee5c7fd commit 88709d6
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 80 deletions.
86 changes: 40 additions & 46 deletions protocols/identify/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
// DEALINGS IN THE SOFTWARE.

use crate::protocol::{Identify, InboundPush, OutboundPush, Push, UpgradeError};
use crate::protocol::{Info, PartialInfo};
use crate::protocol::{Info, PushInfo};
use either::Either;
use futures::future::BoxFuture;
use futures::prelude::*;
Expand Down Expand Up @@ -49,8 +49,7 @@ use std::{io, task::Context, task::Poll, time::Duration};
/// permitting the underlying connection to be closed.
pub struct Handler {
remote_peer_id: PeerId,
remote_public_key: Option<PublicKey>,
inbound_identify_push: Option<BoxFuture<'static, Result<PartialInfo, UpgradeError>>>,
inbound_identify_push: Option<BoxFuture<'static, Result<PushInfo, UpgradeError>>>,
/// Pending events to yield.
events: SmallVec<
[ConnectionHandlerEvent<Either<Identify, Push<OutboundPush>>, (), Event, io::Error>; 4],
Expand Down Expand Up @@ -82,6 +81,9 @@ pub struct Handler {
/// Address observed by or for the remote.
observed_addr: Multiaddr,

/// Identify information about the remote peer.
remote_info: Option<Info>,

local_supported_protocols: SupportedProtocols,
remote_supported_protocols: HashSet<StreamProtocol>,
external_addresses: HashSet<Multiaddr>,
Expand Down Expand Up @@ -123,7 +125,6 @@ impl Handler {
) -> Self {
Self {
remote_peer_id,
remote_public_key: Default::default(),
inbound_identify_push: Default::default(),
events: SmallVec::new(),
pending_replies: FuturesUnordered::new(),
Expand All @@ -136,6 +137,7 @@ impl Handler {
observed_addr,
local_supported_protocols: SupportedProtocols::default(),
remote_supported_protocols: HashSet::default(),
remote_info: Default::default(),
external_addresses,
}
}
Expand All @@ -154,7 +156,7 @@ impl Handler {
let info = self.build_info();

self.pending_replies
.push(crate::protocol::send(substream, info.into()).boxed());
.push(crate::protocol::send(substream, info).boxed());
}
future::Either::Right(fut) => {
if self.inbound_identify_push.replace(fut).is_some() {
Expand All @@ -178,21 +180,8 @@ impl Handler {
>,
) {
match output {
future::Either::Left(remote_info) => {
self.update_supported_protocols_for_remote(&remote_info);

match self.update_remote_public_key(remote_info).try_into() {
Ok(info) => {
self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify push stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
future::Either::Left(info) => {
self.handle_incoming_info(info);
}
future::Either::Right(()) => self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::IdentificationPushed,
Expand Down Expand Up @@ -225,18 +214,38 @@ impl Handler {
}
}

fn update_remote_public_key(&mut self, mut remote_info: PartialInfo) -> PartialInfo {
if let Some(key) = &remote_info.public_key {
self.remote_public_key.replace(key.clone());
remote_info
fn handle_incoming_info(&mut self, info: Info) {
self.remote_info.replace(info.clone());

self.update_supported_protocols_for_remote(&info);

self.events
.push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
info,
)));
}

fn handle_incoming_push_info(&mut self, push_info: PushInfo) {
if let Some(mut info) = self.remote_info.take() {
info.merge(push_info);
self.remote_info.replace(info.clone());

self.update_supported_protocols_for_remote(&info);

self.events
.push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
info,
)));
} else {
remote_info.public_key = self.remote_public_key.clone();
remote_info
warn!(
"Failed to process push from {:?} because no identify info was received before",
self.remote_peer_id
)
}
}

fn update_supported_protocols_for_remote(&mut self, remote_info: &PartialInfo) {
let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone());
fn update_supported_protocols_for_remote(&mut self, info: &Info) {
let new_remote_protocols = HashSet::from_iter(info.protocols.clone());

let remote_added_protocols = new_remote_protocols
.difference(&self.remote_supported_protocols)
Expand Down Expand Up @@ -296,10 +305,7 @@ impl ConnectionHandler for Handler {
let info = self.build_info();
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info.into())),
(),
),
protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()),
});
}
}
Expand Down Expand Up @@ -344,19 +350,7 @@ impl ConnectionHandler for Handler {
self.inbound_identify_push.take();

if let Ok(remote_info) = res {
self.update_supported_protocols_for_remote(&remote_info);
match self.update_remote_public_key(remote_info).try_into() {
Ok(info) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
self.handle_incoming_push_info(remote_info);
}
}

Expand Down Expand Up @@ -414,7 +408,7 @@ impl ConnectionHandler for Handler {
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info.into())),
Either::Right(Push::outbound(info)),
(),
),
});
Expand Down
85 changes: 51 additions & 34 deletions protocols/identify/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct Identify;
#[derive(Debug, Clone)]
pub struct Push<T>(T);
pub struct InboundPush();
pub struct OutboundPush(PartialInfo);
pub struct OutboundPush(Info);

impl Push<InboundPush> {
pub fn inbound() -> Self {
Expand All @@ -58,7 +58,7 @@ impl Push<InboundPush> {
}

impl Push<OutboundPush> {
pub fn outbound(info: PartialInfo) -> Self {
pub fn outbound(info: Info) -> Self {
Push(OutboundPush(info))
}
}
Expand All @@ -82,23 +82,33 @@ pub struct Info {
pub observed_addr: Multiaddr,
}

impl From<Info> for PartialInfo {
fn from(val: Info) -> Self {
PartialInfo {
public_key: Some(val.public_key),
protocol_version: Some(val.protocol_version),
agent_version: Some(val.agent_version),
listen_addrs: val.listen_addrs,
protocols: val.protocols,
observed_addr: Some(val.observed_addr),
impl Info {
pub fn merge(&mut self, info: PushInfo) {
if let Some(public_key) = info.public_key {
self.public_key = public_key;
}
if let Some(protocol_version) = info.protocol_version {
self.protocol_version = protocol_version;
}
if let Some(agent_version) = info.agent_version {
self.agent_version = agent_version;
}
if !info.listen_addrs.is_empty() {
self.listen_addrs = info.listen_addrs;
}
if !info.protocols.is_empty() {
self.protocols = info.protocols;
}
if let Some(observed_addr) = info.observed_addr {
self.observed_addr = observed_addr;
}
}
}

/// Partial identify information of a peer sent in protocol messages.
/// Identify push information of a peer sent in protocol messages.
/// Note that missing fields should be ignored, as peers may choose to send partial updates containing only the fields whose values have changed.
#[derive(Debug, Clone)]
pub struct PartialInfo {
pub struct PushInfo {
pub public_key: Option<PublicKey>,
pub protocol_version: Option<String>,
pub agent_version: Option<String>,
Expand All @@ -108,18 +118,16 @@ pub struct PartialInfo {
}

#[derive(Debug)]
pub enum IdentifyError {
MissingRemotePublicKey,
pub enum MissingInfoError {
PublicKey,
}

impl TryFrom<PartialInfo> for Info {
type Error = IdentifyError;
impl TryFrom<PushInfo> for Info {
type Error = MissingInfoError;

fn try_from(info: PartialInfo) -> Result<Self, Self::Error> {
fn try_from(info: PushInfo) -> Result<Self, Self::Error> {
Ok(Info {
public_key: info
.public_key
.ok_or(IdentifyError::MissingRemotePublicKey)?,
public_key: info.public_key.ok_or(MissingInfoError::PublicKey)?,
protocol_version: info.protocol_version.unwrap_or_default(),
agent_version: info.agent_version.unwrap_or_default(),
listen_addrs: info.listen_addrs,
Expand Down Expand Up @@ -152,12 +160,16 @@ impl<C> OutboundUpgrade<C> for Identify
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = PartialInfo;
type Output = Info;
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
recv(socket).boxed()
recv(socket)
.map(|result| {
result.and_then(|info| info.try_into().map_err(UpgradeError::MissingInfo))
})
.boxed()
}
}

Expand All @@ -174,7 +186,7 @@ impl<C> InboundUpgrade<C> for Push<InboundPush>
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = BoxFuture<'static, Result<PartialInfo, UpgradeError>>; // @todo -- Different future result?
type Output = BoxFuture<'static, Result<PushInfo, UpgradeError>>;
type Error = Void;
type Future = future::Ready<Result<Self::Output, Self::Error>>;

Expand All @@ -197,7 +209,7 @@ where
}
}

pub(crate) async fn send<T>(io: T, info: PartialInfo) -> Result<(), UpgradeError>
pub(crate) async fn send<T>(io: T, info: Info) -> Result<(), UpgradeError>
where
T: AsyncWrite + Unpin,
{
Expand All @@ -209,14 +221,14 @@ where
.map(|addr| addr.to_vec())
.collect();

let pubkey_bytes = info.public_key.map(|key| key.encode_protobuf());
let pubkey_bytes = info.public_key.encode_protobuf();

let message = proto::Identify {
agentVersion: info.agent_version,
protocolVersion: info.protocol_version,
publicKey: pubkey_bytes,
agentVersion: Some(info.agent_version),
protocolVersion: Some(info.protocol_version),
publicKey: Some(pubkey_bytes),
listenAddrs: listen_addrs,
observedAddr: info.observed_addr.map(|addr| addr.to_vec()),
observedAddr: Some(info.observed_addr.to_vec()),
protocols: info.protocols.into_iter().map(|p| p.to_string()).collect(),
};

Expand All @@ -231,7 +243,7 @@ where
Ok(())
}

async fn recv<T>(socket: T) -> Result<PartialInfo, UpgradeError>
async fn recv<T>(socket: T) -> Result<PushInfo, UpgradeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -254,7 +266,10 @@ where
Ok(info)
}

impl TryFrom<proto::Identify> for PartialInfo {
impl TryFrom<proto::Identify> for Info {
}

impl TryFrom<proto::Identify> for PushInfo {
type Error = UpgradeError;

fn try_from(msg: proto::Identify) -> Result<Self, Self::Error> {
Expand Down Expand Up @@ -295,7 +310,7 @@ impl TryFrom<proto::Identify> for PartialInfo {
}
});

let info = PartialInfo {
let info = PushInfo {
public_key,
protocol_version: msg.protocolVersion,
agent_version: msg.agentVersion,
Expand Down Expand Up @@ -326,6 +341,8 @@ pub enum UpgradeError {
Io(#[from] io::Error),
#[error("Stream closed")]
StreamClosed,
#[error("Missing information received")]
MissingInfo(MissingInfoError),
#[error("Failed decoding multiaddr")]
Multiaddr(#[from] multiaddr::Error),
#[error("Failed decoding public key")]
Expand Down Expand Up @@ -361,7 +378,7 @@ mod tests {
),
};

let info = PartialInfo::try_from(payload).expect("not to fail");
let info = PushInfo::try_from(payload).expect("not to fail");

assert_eq!(info.listen_addrs, vec![valid_multiaddr])
}
Expand Down

0 comments on commit 88709d6

Please sign in to comment.