From 4c761c4c745beabfb16a685c1211c752439732b3 Mon Sep 17 00:00:00 2001 From: Rusty Russell Date: Thu, 26 Aug 2021 11:32:17 +0930 Subject: [PATCH] pyln-client/gossmap: more fixes, make mypy happier. Mainly fixing type annotations, but some real fixes: 1. GossmapHalfchannel.from_str() should be a classmethod. 2. update_channel had weird, unusable default values (fields can't be NULL, since we use it below). Signed-off-by: Rusty Russell --- contrib/pyln-client/pyln/client/gossmap.py | 31 ++++++++++--------- .../pyln-proto/pyln/proto/message/message.py | 2 +- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/contrib/pyln-client/pyln/client/gossmap.py b/contrib/pyln-client/pyln/client/gossmap.py index 3dc412d89a13..abda4c26b1e1 100755 --- a/contrib/pyln-client/pyln/client/gossmap.py +++ b/contrib/pyln-client/pyln/client/gossmap.py @@ -5,7 +5,7 @@ node_announcement, gossip_store_channel_amount) from pyln.proto import ShortChannelId, PublicKey -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union, cast import io import struct @@ -33,7 +33,7 @@ def __init__(self, buf: bytes): class GossmapHalfchannel(object): """One direction of a GossmapChannel.""" - def __init__(self, channel: GossmapChannel, direction: int, + def __init__(self, channel: 'GossmapChannel', direction: int, timestamp: int, cltv_expiry_delta: int, htlc_minimum_msat: int, htlc_maximum_msat: int, fee_base_msat: int, fee_proportional_millionths: int): @@ -71,12 +71,13 @@ def __hash__(self): def __repr__(self): return "GossmapNodeId[{}]".format(self.nodeid.hex()) - def from_str(self, s: str): + @classmethod + def from_str(cls, s: str): if s.startswith('0x'): s = s[2:] if len(s) != 67: raise ValueError(f"{s} is not a valid hexstring of a node_id") - return GossmapNodeId(bytes.fromhex(s)) + return cls(bytes.fromhex(s)) class GossmapChannel(object): @@ -97,14 +98,14 @@ def __init__(self, self.updates_fields: List[Optional[Dict[str, Any]]] = [None, None] self.updates_offset: List[Optional[int]] = [None, None] self.satoshis = None - self.half_channels: List[GossmapHalfchannel] = [None, None] + self.half_channels: List[Optional[GossmapHalfchannel]] = [None, None] def update_channel(self, direction: int, - fields: List[Optional[Dict[str, Any]]] = [None, None], - off: List[Optional[int]] = [None, None]): + fields: Dict[str, Any], + off: int): self.updates_fields[direction] = fields - self.updates_offset = off + self.updates_offset[direction] = off half = GossmapHalfchannel(self, direction, fields['timestamp'], @@ -132,8 +133,8 @@ class GossmapNode(object): """ def __init__(self, node_id: GossmapNodeId): self.announce_fields: Optional[Dict[str, Any]] = None - self.announce_offset = None - self.channels = [] + self.announce_offset: Optional[int] = None + self.channels: List[GossmapChannel] = [] self.node_id = node_id def __repr__(self): @@ -148,10 +149,10 @@ def __init__(self, store_filename: str = "gossip_store"): self.store_buf = bytes() self.nodes: Dict[GossmapNodeId, GossmapNode] = {} self.channels: Dict[ShortChannelId, GossmapChannel] = {} - self._last_scid: str = None + self._last_scid: Optional[str] = None version = self.store_file.read(1) if version[0] != GOSSIP_STORE_VERSION: - raise ValueError("Invalid gossip store version {}".format(version)) + raise ValueError("Invalid gossip store version {}".format(int(version))) self.bytes_read = 1 self.refresh() @@ -205,11 +206,11 @@ def get_channel(self, short_channel_id: ShortChannelId): short_channel_id = ShortChannelId.from_str(short_channel_id) return self.channels.get(short_channel_id) - def get_node(self, node_id: GossmapNodeId): + def get_node(self, node_id: Union[GossmapNodeId, str]): """ Resolves a node by its public key node_id """ - if type(node_id) == str: + if isinstance(node_id, str): node_id = GossmapNodeId.from_str(node_id) - return self.nodes.get(node_id) + return self.nodes.get(cast(GossmapNodeId, node_id)) def update_channel(self, rec: bytes, off: int): fields = channel_update.read(io.BytesIO(rec[2:]), {}) diff --git a/contrib/pyln-proto/pyln/proto/message/message.py b/contrib/pyln-proto/pyln/proto/message/message.py index eb3d7ec556e5..127755a993f2 100644 --- a/contrib/pyln-proto/pyln/proto/message/message.py +++ b/contrib/pyln-proto/pyln/proto/message/message.py @@ -310,7 +310,7 @@ def write(self, io_out: BufferedIOBase, v: Dict[str, Any], otherfields: Dict[str f.fieldtype.write(io_out, val, otherfields) def read(self, io_in: BufferedIOBase, otherfields: Dict[str, Any]) -> Optional[Dict[str, Any]]: - vals = {} + vals: Dict[str, Any] = {} for field in self.fields: val = field.fieldtype.read(io_in, vals) if val is None: