Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(identify): handle partial push messages #4495

Merged
merged 13 commits into from
Sep 24, 2023
Merged
64 changes: 49 additions & 15 deletions protocols/identify/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use crate::protocol::{Identify, InboundPush, Info, OutboundPush, Push, UpgradeError};
use crate::protocol::{Identify, InboundPush, OutboundPush, Push, UpgradeError};
use crate::protocol::{Info, PartialInfo};
use either::Either;
use futures::future::BoxFuture;
use futures::prelude::*;
Expand Down Expand Up @@ -48,7 +49,8 @@ use std::{io, task::Context, task::Poll, time::Duration};
/// permitting the underlying connection to be closed.
pub struct Handler {
remote_peer_id: PeerId,
inbound_identify_push: Option<BoxFuture<'static, Result<Info, UpgradeError>>>,
remote_public_key: Option<PublicKey>,
inbound_identify_push: Option<BoxFuture<'static, Result<PartialInfo, UpgradeError>>>,
/// Pending events to yield.
events: SmallVec<
[ConnectionHandlerEvent<Either<Identify, Push<OutboundPush>>, (), Event, io::Error>; 4],
Expand Down Expand Up @@ -121,6 +123,7 @@ impl Handler {
) -> Self {
Self {
remote_peer_id,
remote_public_key: Default::default(),
inbound_identify_push: Default::default(),
events: SmallVec::new(),
pending_replies: FuturesUnordered::new(),
Expand Down Expand Up @@ -151,7 +154,7 @@ impl Handler {
let info = self.build_info();

self.pending_replies
.push(crate::protocol::send(substream, info).boxed());
.push(crate::protocol::send(substream, info.into()).boxed());
}
future::Either::Right(fut) => {
if self.inbound_identify_push.replace(fut).is_some() {
Expand All @@ -177,10 +180,19 @@ impl Handler {
match output {
future::Either::Left(remote_info) => {
self.update_supported_protocols_for_remote(&remote_info);
self.events
.push(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
remote_info,
)));

match self.update_remote_public_key(remote_info).try_into() {
Ok(info) => {
self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify push stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
}
future::Either::Right(()) => self.events.push(ConnectionHandlerEvent::NotifyBehaviour(
Event::IdentificationPushed,
Expand Down Expand Up @@ -213,7 +225,17 @@ impl Handler {
}
}

fn update_supported_protocols_for_remote(&mut self, remote_info: &Info) {
fn update_remote_public_key(&mut self, mut remote_info: PartialInfo) -> PartialInfo {
if let Some(key) = &remote_info.public_key {
self.remote_public_key.replace(key.clone());
remote_info
} else {
remote_info.public_key = self.remote_public_key.clone();
remote_info
}
}

fn update_supported_protocols_for_remote(&mut self, remote_info: &PartialInfo) {
let new_remote_protocols = HashSet::from_iter(remote_info.protocols.clone());

let remote_added_protocols = new_remote_protocols
Expand Down Expand Up @@ -274,7 +296,10 @@ impl ConnectionHandler for Handler {
let info = self.build_info();
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(Either::Right(Push::outbound(info)), ()),
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info.into())),
(),
),
});
}
}
Expand Down Expand Up @@ -318,11 +343,20 @@ impl ConnectionHandler for Handler {
{
self.inbound_identify_push.take();

if let Ok(info) = res {
self.update_supported_protocols_for_remote(&info);
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(Event::Identified(
info,
)));
if let Ok(remote_info) = res {
self.update_supported_protocols_for_remote(&remote_info);
match self.update_remote_public_key(remote_info).try_into() {
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
Ok(info) => {
return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(
Event::Identified(info),
));
}
Err(error) => {
warn!(
"Failed to build remote info from inbound identify stream from {:?}: {:?}",
self.remote_peer_id, error)
}
}
}
}

Expand Down Expand Up @@ -380,7 +414,7 @@ impl ConnectionHandler for Handler {
self.events
.push(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(
Either::Right(Push::outbound(info)),
Either::Right(Push::outbound(info.into())),
(),
),
});
Expand Down
110 changes: 84 additions & 26 deletions protocols/identify/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ pub struct Identify;
#[derive(Debug, Clone)]
pub struct Push<T>(T);
pub struct InboundPush();
pub struct OutboundPush(Info);
pub struct OutboundPush(PartialInfo);

impl Push<InboundPush> {
pub fn inbound() -> Self {
Expand All @@ -58,12 +58,12 @@ impl Push<InboundPush> {
}

impl Push<OutboundPush> {
pub fn outbound(info: Info) -> Self {
pub fn outbound(info: PartialInfo) -> Self {
Push(OutboundPush(info))
}
}

/// Information of a peer sent in protocol messages.
/// Identify information of a peer sent in protocol messages.
#[derive(Debug, Clone)]
pub struct Info {
/// The public key of the local peer.
Expand All @@ -82,6 +82,53 @@ pub struct Info {
pub observed_addr: Multiaddr,
}

impl From<Info> for PartialInfo {
fn from(val: Info) -> Self {
PartialInfo {
public_key: Some(val.public_key),
protocol_version: Some(val.protocol_version),
agent_version: Some(val.agent_version),
listen_addrs: val.listen_addrs,
protocols: val.protocols,
observed_addr: Some(val.observed_addr),
}
}
}

/// Partial identify information of a peer sent in protocol messages.
/// Note that missing fields should be ignored, as peers may choose to send partial updates containing only the fields whose values have changed.
#[derive(Debug, Clone)]
pub struct PartialInfo {
pub public_key: Option<PublicKey>,
pub protocol_version: Option<String>,
pub agent_version: Option<String>,
pub listen_addrs: Vec<Multiaddr>,
pub protocols: Vec<StreamProtocol>,
pub observed_addr: Option<Multiaddr>,
}

#[derive(Debug)]
pub enum IdentifyError {
MissingRemotePublicKey,
}

impl TryFrom<PartialInfo> for Info {
type Error = IdentifyError;

fn try_from(info: PartialInfo) -> Result<Self, Self::Error> {
Ok(Info {
public_key: info
.public_key
.ok_or(IdentifyError::MissingRemotePublicKey)?,
protocol_version: info.protocol_version.unwrap_or_default(),
agent_version: info.agent_version.unwrap_or_default(),
listen_addrs: info.listen_addrs,
protocols: info.protocols,
observed_addr: info.observed_addr.unwrap_or_else(Multiaddr::empty),
})
}
}

impl UpgradeInfo for Identify {
type Info = StreamProtocol;
type InfoIter = iter::Once<Self::Info>;
Expand All @@ -105,7 +152,7 @@ impl<C> OutboundUpgrade<C> for Identify
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = Info;
type Output = PartialInfo;
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
type Error = UpgradeError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Output, Self::Error>> + Send>>;

Expand All @@ -127,7 +174,7 @@ impl<C> InboundUpgrade<C> for Push<InboundPush>
where
C: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Output = BoxFuture<'static, Result<Info, UpgradeError>>;
type Output = BoxFuture<'static, Result<PartialInfo, UpgradeError>>; // @todo -- Different future result?
type Error = Void;
type Future = future::Ready<Result<Self::Output, Self::Error>>;

Expand All @@ -150,7 +197,7 @@ where
}
}

pub(crate) async fn send<T>(io: T, info: Info) -> Result<(), UpgradeError>
pub(crate) async fn send<T>(io: T, info: PartialInfo) -> Result<(), UpgradeError>
thomaseizinger marked this conversation as resolved.
Show resolved Hide resolved
where
T: AsyncWrite + Unpin,
{
Expand All @@ -162,14 +209,14 @@ where
.map(|addr| addr.to_vec())
.collect();

let pubkey_bytes = info.public_key.encode_protobuf();
let pubkey_bytes = info.public_key.map(|key| key.encode_protobuf());

let message = proto::Identify {
agentVersion: Some(info.agent_version),
protocolVersion: Some(info.protocol_version),
publicKey: Some(pubkey_bytes),
agentVersion: info.agent_version,
protocolVersion: info.protocol_version,
publicKey: pubkey_bytes,
listenAddrs: listen_addrs,
observedAddr: Some(info.observed_addr.to_vec()),
observedAddr: info.observed_addr.map(|addr| addr.to_vec()),
protocols: info.protocols.into_iter().map(|p| p.to_string()).collect(),
};

Expand All @@ -184,7 +231,7 @@ where
Ok(())
}

async fn recv<T>(socket: T) -> Result<Info, UpgradeError>
async fn recv<T>(socket: T) -> Result<PartialInfo, UpgradeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Expand All @@ -207,7 +254,7 @@ where
Ok(info)
}

impl TryFrom<proto::Identify> for Info {
impl TryFrom<proto::Identify> for PartialInfo {
type Error = UpgradeError;

fn try_from(msg: proto::Identify) -> Result<Self, Self::Error> {
Expand All @@ -221,26 +268,37 @@ impl TryFrom<proto::Identify> for Info {
match parse_multiaddr(addr) {
Ok(a) => addrs.push(a),
Err(e) => {
debug!("Unable to parse multiaddr: {e:?}");
debug!("Unable to parse listen multiaddr: {e:?}");
}
}
}
addrs
};

let public_key = PublicKey::try_decode_protobuf(&msg.publicKey.unwrap_or_default())?;
let public_key = msg
.publicKey
.and_then(|key| match PublicKey::try_decode_protobuf(&key) {
Ok(k) => Some(k),
Err(e) => {
debug!("Unable to decode public key: {e:?}");
None
}
});

let observed_addr = msg
.observedAddr
.and_then(|bytes| match parse_multiaddr(bytes) {
Ok(a) => Some(a),
Err(e) => {
debug!("Unable to parse observed multiaddr: {e:?}");
None
}
});

let observed_addr = match parse_multiaddr(msg.observedAddr.unwrap_or_default()) {
Ok(a) => a,
Err(e) => {
debug!("Unable to parse multiaddr: {e:?}");
Multiaddr::empty()
}
};
let info = Info {
let info = PartialInfo {
public_key,
protocol_version: msg.protocolVersion.unwrap_or_default(),
agent_version: msg.agentVersion.unwrap_or_default(),
protocol_version: msg.protocolVersion,
agent_version: msg.agentVersion,
listen_addrs,
protocols: msg
.protocols
Expand Down Expand Up @@ -303,7 +361,7 @@ mod tests {
),
};

let info = Info::try_from(payload).expect("not to fail");
let info = PartialInfo::try_from(payload).expect("not to fail");

assert_eq!(info.listen_addrs, vec![valid_multiaddr])
}
Expand Down