diff --git a/__init__.py b/__init__.py index 7f573e7..d50998b 100644 --- a/__init__.py +++ b/__init__.py @@ -7,7 +7,7 @@ from lnbits.helpers import template_renderer from lnbits.tasks import catch_everything_and_restart -from .nostr.client.client import NostrClient as NostrClientLib +from .nostr.client.client import NostrClient db = Database("ext_nostrclient") @@ -22,19 +22,14 @@ scheduled_tasks: List[asyncio.Task] = [] -class NostrClient: - def __init__(self): - self.client: NostrClientLib = NostrClientLib(connect=False) - - -nostr = NostrClient() +nostr_client = NostrClient() def nostr_renderer(): return template_renderer(["nostrclient/templates"]) -from .tasks import check_relays, init_relays, subscribe_events +from .tasks import check_relays, init_relays, subscribe_events # noqa from .views import * # noqa from .views_api import * # noqa diff --git a/cbc.py b/cbc.py deleted file mode 100644 index 0d9e04f..0000000 --- a/cbc.py +++ /dev/null @@ -1,26 +0,0 @@ -from Cryptodome.Cipher import AES - -BLOCK_SIZE = 16 - - -class AESCipher(object): - """This class is compatible with crypto.createCipheriv('aes-256-cbc')""" - - def __init__(self, key=None): - self.key = key - - def pad(self, data): - length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) - return data + (chr(length) * length).encode() - - def unpad(self, data): - return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] - - def encrypt(self, plain_text): - cipher = AES.new(self.key, AES.MODE_CBC) - b = plain_text.encode("UTF-8") - return cipher.iv, cipher.encrypt(self.pad(b)) - - def decrypt(self, iv, enc_text): - cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) - return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) diff --git a/crud.py b/crud.py index 780642d..05ca907 100644 --- a/crud.py +++ b/crud.py @@ -1,21 +1,17 @@ -from typing import List, Optional, Union - -import shortuuid - -from lnbits.helpers import urlsafe_short_hash +from typing import List from . import db -from .models import Relay, RelayList +from .models import Relay -async def get_relays() -> RelayList: - row = await db.fetchall("SELECT * FROM nostrclient.relays") - return RelayList(__root__=row) +async def get_relays() -> List[Relay]: + rows = await db.fetchall("SELECT * FROM nostrclient.relays") + return [Relay.from_row(r) for r in rows] async def add_relay(relay: Relay) -> None: await db.execute( - f""" + """ INSERT INTO nostrclient.relays ( id, url, diff --git a/migrations.py b/migrations.py index 5a30e45..73b9ed8 100644 --- a/migrations.py +++ b/migrations.py @@ -3,7 +3,7 @@ async def m001_initial(db): Initial nostrclient table. """ await db.execute( - f""" + """ CREATE TABLE nostrclient.relays ( id TEXT NOT NULL PRIMARY KEY, url TEXT NOT NULL, diff --git a/models.py b/models.py index 88651fc..e08ade3 100644 --- a/models.py +++ b/models.py @@ -1,9 +1,7 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional +from sqlite3 import Row +from typing import List, Optional -from fastapi import Request -from fastapi.param_functions import Query -from pydantic import BaseModel, Field +from pydantic import BaseModel from lnbits.helpers import urlsafe_short_hash @@ -14,7 +12,8 @@ class RelayStatus(BaseModel): error_counter: Optional[int] = 0 error_list: Optional[List] = [] notice_list: Optional[List] = [] - + + class Relay(BaseModel): id: Optional[str] = None url: Optional[str] = None @@ -28,33 +27,9 @@ def _init__(self): if not self.id: self.id = urlsafe_short_hash() - -class RelayList(BaseModel): - __root__: List[Relay] - - -class Event(BaseModel): - content: str - pubkey: str - created_at: Optional[int] - kind: int - tags: Optional[List[List[str]]] - sig: str - - -class Filter(BaseModel): - ids: Optional[List[str]] - kinds: Optional[List[int]] - authors: Optional[List[str]] - since: Optional[int] - until: Optional[int] - e: Optional[List[str]] = Field(alias="#e") - p: Optional[List[str]] = Field(alias="#p") - limit: Optional[int] - - -class Filters(BaseModel): - __root__: List[Filter] + @classmethod + def from_row(cls, row: Row) -> "Relay": + return cls(**dict(row)) class TestMessage(BaseModel): @@ -62,6 +37,7 @@ class TestMessage(BaseModel): reciever_public_key: str message: str + class TestMessageResponse(BaseModel): private_key: str public_key: str diff --git a/nostr/__init__.py b/nostr/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/nostr/bech32.py b/nostr/bech32.py index 61a92c4..0ae6c80 100644 --- a/nostr/bech32.py +++ b/nostr/bech32.py @@ -26,19 +26,22 @@ class Encoding(Enum): """Enumeration type to list the various supported encodings.""" + BECH32 = 1 BECH32M = 2 + CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" -BECH32M_CONST = 0x2bc830a3 +BECH32M_CONST = 0x2BC830A3 + def bech32_polymod(values): """Internal function that computes the Bech32 checksum.""" - generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3] + generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] chk = 1 for value in values: top = chk >> 25 - chk = (chk & 0x1ffffff) << 5 ^ value + chk = (chk & 0x1FFFFFF) << 5 ^ value for i in range(5): chk ^= generator[i] if ((top >> i) & 1) else 0 return chk @@ -58,6 +61,7 @@ def bech32_verify_checksum(hrp, data): return Encoding.BECH32M return None + def bech32_create_checksum(hrp, data, spec): """Compute the checksum values given HRP and data.""" values = bech32_hrp_expand(hrp) + data @@ -69,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec): def bech32_encode(hrp, data, spec): """Compute a Bech32 string given HRP and data values.""" combined = data + bech32_create_checksum(hrp, data, spec) - return hrp + '1' + ''.join([CHARSET[d] for d in combined]) + return hrp + "1" + "".join([CHARSET[d] for d in combined]) + def bech32_decode(bech): """Validate a Bech32/Bech32m string, and determine HRP and data.""" - if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or - (bech.lower() != bech and bech.upper() != bech)): + if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( + bech.lower() != bech and bech.upper() != bech + ): return (None, None, None) bech = bech.lower() - pos = bech.rfind('1') + pos = bech.rfind("1") if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: return (None, None, None) - if not all(x in CHARSET for x in bech[pos+1:]): + if not all(x in CHARSET for x in bech[pos + 1 :]): return (None, None, None) hrp = bech[:pos] - data = [CHARSET.find(x) for x in bech[pos+1:]] + data = [CHARSET.find(x) for x in bech[pos + 1 :]] spec = bech32_verify_checksum(hrp, data) if spec is None: return (None, None, None) return (hrp, data[:-6], spec) + def convertbits(data, frombits, tobits, pad=True): """General power-of-2 base conversion.""" acc = 0 @@ -124,7 +131,12 @@ def decode(hrp, addr): return (None, None) if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: return (None, None) - if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M: + if ( + data[0] == 0 + and spec != Encoding.BECH32 + or data[0] != 0 + and spec != Encoding.BECH32M + ): return (None, None) return (data[0], decoded) diff --git a/nostr/client/client.py b/nostr/client/client.py index db07a06..4624ff3 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -1,25 +1,36 @@ import asyncio -from typing import List + +from loguru import logger from ..relay_manager import RelayManager class NostrClient: - relays = [ ] relay_manager = RelayManager() - def __init__(self, relays: List[str] = [], connect=True): - if len(relays): - self.relays = relays - if connect: - self.connect() + def __init__(self): + self.running = True + + def connect(self, relays): + for relay in relays: + try: + self.relay_manager.add_relay(relay) + except Exception as e: + logger.debug(e) + self.running = True - async def connect(self): - for relay in self.relays: - self.relay_manager.add_relay(relay) + def reconnect(self, relays): + self.relay_manager.remove_relays() + self.connect(relays) def close(self): - self.relay_manager.close_connections() + try: + self.relay_manager.close_all_subscriptions() + self.relay_manager.close_connections() + + self.running = False + except Exception as e: + logger.error(e) async def subscribe( self, @@ -27,18 +38,36 @@ async def subscribe( callback_notices_func=None, callback_eosenotices_func=None, ): - while True: + while self.running: + self._check_events(callback_events_func) + self._check_notices(callback_notices_func) + self._check_eos_notices(callback_eosenotices_func) + + await asyncio.sleep(0.2) + + def _check_events(self, callback_events_func=None): + try: while self.relay_manager.message_pool.has_events(): event_msg = self.relay_manager.message_pool.get_event() if callback_events_func: callback_events_func(event_msg) + except Exception as e: + logger.debug(e) + + def _check_notices(self, callback_notices_func=None): + try: while self.relay_manager.message_pool.has_notices(): event_msg = self.relay_manager.message_pool.get_notice() if callback_notices_func: callback_notices_func(event_msg) + except Exception as e: + logger.debug(e) + + def _check_eos_notices(self, callback_eosenotices_func=None): + try: while self.relay_manager.message_pool.has_eose_notices(): event_msg = self.relay_manager.message_pool.get_eose_notice() if callback_eosenotices_func: callback_eosenotices_func(event_msg) - - await asyncio.sleep(0.5) + except Exception as e: + logger.debug(e) diff --git a/nostr/delegation.py b/nostr/delegation.py deleted file mode 100644 index 94801f5..0000000 --- a/nostr/delegation.py +++ /dev/null @@ -1,32 +0,0 @@ -import time -from dataclasses import dataclass - - -@dataclass -class Delegation: - delegator_pubkey: str - delegatee_pubkey: str - event_kind: int - duration_secs: int = 30*24*60 # default to 30 days - signature: str = None # set in PrivateKey.sign_delegation - - @property - def expires(self) -> int: - return int(time.time()) + self.duration_secs - - @property - def conditions(self) -> str: - return f"kind={self.event_kind}&created_at<{self.expires}" - - @property - def delegation_token(self) -> str: - return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}" - - def get_tag(self) -> list[str]: - """ Called by Event """ - return [ - "delegation", - self.delegator_pubkey, - self.conditions, - self.signature, - ] diff --git a/nostr/event.py b/nostr/event.py index 65b187d..a7d4f1d 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -122,6 +122,7 @@ def __post_init__(self): def id(self) -> str: if self.content is None: raise Exception( - "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field" + "EncryptedDirectMessage `id` is undefined until its" + + " message is encrypted and stored in the `content` field" ) return super().id diff --git a/nostr/filter.py b/nostr/filter.py deleted file mode 100644 index f119079..0000000 --- a/nostr/filter.py +++ /dev/null @@ -1,134 +0,0 @@ -from collections import UserList -from typing import List - -from .event import Event, EventKind - - -class Filter: - """ - NIP-01 filtering. - - Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`. - - Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`. - If a particular single-letter tag gains prominence, explicit support should be - added. For example: - # arbitrary tag - filter.add_arbitrary_tag('t', [hashtags]) - - # promoted to explicit support - Filter(hashtag_refs=[hashtags]) - """ - - def __init__( - self, - event_ids: List[str] = None, - kinds: List[EventKind] = None, - authors: List[str] = None, - since: int = None, - until: int = None, - event_refs: List[ - str - ] = None, # the "#e" attr; list of event ids referenced in an "e" tag - pubkey_refs: List[ - str - ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag - limit: int = None, - ) -> None: - self.event_ids = event_ids - self.kinds = kinds - self.authors = authors - self.since = since - self.until = until - self.event_refs = event_refs - self.pubkey_refs = pubkey_refs - self.limit = limit - - self.tags = {} - if self.event_refs: - self.add_arbitrary_tag("e", self.event_refs) - if self.pubkey_refs: - self.add_arbitrary_tag("p", self.pubkey_refs) - - def add_arbitrary_tag(self, tag: str, values: list): - """ - Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12 - single-letter tags. - """ - # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#" - tag_key = tag if len(tag) > 1 else f"#{tag}" - self.tags[tag_key] = values - - def matches(self, event: Event) -> bool: - if self.event_ids is not None and event.id not in self.event_ids: - return False - if self.kinds is not None and event.kind not in self.kinds: - return False - if self.authors is not None and event.public_key not in self.authors: - return False - if self.since is not None and event.created_at < self.since: - return False - if self.until is not None and event.created_at > self.until: - return False - if (self.event_refs is not None or self.pubkey_refs is not None) and len( - event.tags - ) == 0: - return False - - if self.tags: - e_tag_identifiers = set([e_tag[0] for e_tag in event.tags]) - for f_tag, f_tag_values in self.tags.items(): - # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags - f_tag = f_tag.replace("#", "") - - if f_tag not in e_tag_identifiers: - # Event is missing a tag type that we're looking for - return False - - # Multiple values within f_tag_values are treated as OR search; an Event - # needs to match only one. - # Note: an Event could have multiple entries of the same tag type - # (e.g. a reply to multiple people) so we have to check all of them. - match_found = False - for e_tag in event.tags: - if e_tag[0] == f_tag and e_tag[1] in f_tag_values: - match_found = True - break - if not match_found: - return False - - return True - - def to_json_object(self) -> dict: - res = {} - if self.event_ids is not None: - res["ids"] = self.event_ids - if self.kinds is not None: - res["kinds"] = self.kinds - if self.authors is not None: - res["authors"] = self.authors - if self.since is not None: - res["since"] = self.since - if self.until is not None: - res["until"] = self.until - if self.limit is not None: - res["limit"] = self.limit - if self.tags: - res.update(self.tags) - - return res - - -class Filters(UserList): - def __init__(self, initlist: "list[Filter]" = []) -> None: - super().__init__(initlist) - self.data: "list[Filter]" - - def match(self, event: Event): - for filter in self.data: - if filter.matches(event): - return True - return False - - def to_json_array(self) -> list: - return [filter.to_json_object() for filter in self.data] diff --git a/nostr/key.py b/nostr/key.py index 8089e11..3803650 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -1,6 +1,5 @@ import base64 import secrets -from hashlib import sha256 import secp256k1 from cffi import FFI @@ -8,7 +7,6 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from . import bech32 -from .delegation import Delegation from .event import EncryptedDirectMessage, Event, EventKind @@ -37,7 +35,7 @@ def from_npub(cls, npub: str): class PrivateKey: def __init__(self, raw_secret: bytes = None) -> None: - if not raw_secret is None: + if raw_secret is not None: self.raw_secret = raw_secret else: self.raw_secret = secrets.token_bytes(32) @@ -79,7 +77,10 @@ def encrypt_message(self, message: str, public_key_hex: str) -> str: encryptor = cipher.encryptor() encrypted_message = encryptor.update(padded_data) + encryptor.finalize() - return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" + return ( + f"{base64.b64encode(encrypted_message).decode()}" + + f"?iv={base64.b64encode(iv).decode()}" + ) def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: dm.content = self.encrypt_message( @@ -116,11 +117,6 @@ def sign_event(self, event: Event) -> None: event.public_key = self.public_key.hex() event.signature = self.sign_message_hash(bytes.fromhex(event.id)) - def sign_delegation(self, delegation: Delegation) -> None: - delegation.signature = self.sign_message_hash( - sha256(delegation.delegation_token.encode()).digest() - ) - def __eq__(self, other): return self.raw_secret == other.raw_secret diff --git a/nostr/message_pool.py b/nostr/message_pool.py index 02f7fd4..a3e6c5f 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -2,13 +2,15 @@ from queue import Queue from threading import Lock -from .event import Event from .message_type import RelayMessageType class EventMessage: - def __init__(self, event: Event, subscription_id: str, url: str) -> None: + def __init__( + self, event: str, event_id: str, subscription_id: str, url: str + ) -> None: self.event = event + self.event_id = event_id self.subscription_id = subscription_id self.url = url @@ -59,18 +61,16 @@ def _process_message(self, message: str, url: str): message_type = message_json[0] if message_type == RelayMessageType.EVENT: subscription_id = message_json[1] - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) + event = message_json[2] + if "id" not in event: + return + event_id = event["id"] + with self.lock: - if not f"{subscription_id}_{event.id}" in self._unique_events: - self._accept_event(EventMessage(event, subscription_id, url)) + if f"{subscription_id}_{event_id}" not in self._unique_events: + self._accept_event( + EventMessage(json.dumps(event), event_id, subscription_id, url) + ) elif message_type == RelayMessageType.NOTICE: self.notices.put(NoticeMessage(message_json[1], url)) elif message_type == RelayMessageType.END_OF_STORED_EVENTS: @@ -78,10 +78,12 @@ def _process_message(self, message: str, url: str): def _accept_event(self, event_message: EventMessage): """ - Event uniqueness is considered per `subscription_id`. - The `subscription_id` is rewritten to be unique and it is the same accross relays. - The same event can come from different subscriptions (from the same client or from different ones). - Clients that have joined later should receive older events. + Event uniqueness is considered per `subscription_id`. The `subscription_id` is + rewritten to be unique and it is the same accross relays. The same event can + come from different subscriptions (from the same client or from different ones). + Clients that have joined later should receive older events. """ self.events.put(event_message) - self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}") \ No newline at end of file + self._unique_events.add( + f"{event_message.subscription_id}_{event_message.event_id}" + ) diff --git a/nostr/relay.py b/nostr/relay.py index caacba0..b576cfa 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -2,43 +2,23 @@ import json import time from queue import Queue -from threading import Lock from typing import List from loguru import logger from websocket import WebSocketApp -from .event import Event -from .filter import Filters from .message_pool import MessagePool -from .message_type import RelayMessageType from .subscription import Subscription -class RelayPolicy: - def __init__(self, should_read: bool = True, should_write: bool = True) -> None: - self.should_read = should_read - self.should_write = should_write - - def to_json_object(self) -> dict[str, bool]: - return {"read": self.should_read, "write": self.should_write} - - class Relay: - def __init__( - self, - url: str, - policy: RelayPolicy, - message_pool: MessagePool, - subscriptions: dict[str, Subscription] = {}, - ) -> None: + def __init__(self, url: str, message_pool: MessagePool) -> None: self.url = url - self.policy = policy self.message_pool = message_pool - self.subscriptions = subscriptions self.connected: bool = False self.reconnect: bool = True self.shutdown: bool = False + self.error_counter: int = 0 self.error_threshold: int = 100 self.error_list: List[str] = [] @@ -47,12 +27,10 @@ def __init__( self.num_received_events: int = 0 self.num_sent_events: int = 0 self.num_subscriptions: int = 0 - self.ssl_options: dict = {} - self.proxy: dict = {} - self.lock = Lock() + self.queue = Queue() - def connect(self, ssl_options: dict = None, proxy: dict = None): + def connect(self): self.ws = WebSocketApp( self.url, on_open=self._on_open, @@ -62,19 +40,14 @@ def connect(self, ssl_options: dict = None, proxy: dict = None): on_ping=self._on_ping, on_pong=self._on_pong, ) - self.ssl_options = ssl_options - self.proxy = proxy if not self.connected: - self.ws.run_forever( - sslopt=ssl_options, - http_proxy_host=None if proxy is None else proxy.get("host"), - http_proxy_port=None if proxy is None else proxy.get("port"), - proxy_type=None if proxy is None else proxy.get("type"), - ping_interval=5, - ) + self.ws.run_forever(ping_interval=10) def close(self): - self.ws.close() + try: + self.ws.close() + except Exception as e: + logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}") self.connected = False self.shutdown = True @@ -90,10 +63,9 @@ def ping(self): def publish(self, message: str): self.queue.put(message) - def publish_subscriptions(self): - for _, subscription in self.subscriptions.items(): - s = subscription.to_json_object() - json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) + def publish_subscriptions(self, subscriptions: List[Subscription] = []): + for s in subscriptions: + json_str = json.dumps(["REQ", s.id] + s.filters) self.publish(json_str) async def queue_worker(self): @@ -103,55 +75,44 @@ async def queue_worker(self): message = self.queue.get(timeout=1) self.num_sent_events += 1 self.ws.send(message) - except: + except Exception as _: pass else: await asyncio.sleep(1) - - if self.shutdown: - logger.warning(f"Closing queue worker for '{self.url}'.") - break - def add_subscription(self, id, filters: Filters): - with self.lock: - self.subscriptions[id] = Subscription(id, filters) + if self.shutdown: + logger.warning(f"[Relay: {self.url}] Closing queue worker.") + return def close_subscription(self, id: str) -> None: - with self.lock: - self.subscriptions.pop(id) + try: self.publish(json.dumps(["CLOSE", id])) - - def to_json_object(self) -> dict: - return { - "url": self.url, - "policy": self.policy.to_json_object(), - "subscriptions": [ - subscription.to_json_object() - for subscription in self.subscriptions.values() - ], - } + except Exception as e: + logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}") def add_notice(self, notice: str): - self.notice_list = ([notice] + self.notice_list)[:20] + self.notice_list = [notice] + self.notice_list def _on_open(self, _): - logger.info(f"Connected to relay: '{self.url}'.") + logger.info(f"[Relay: {self.url}] Connected.") self.connected = True - + self.shutdown = False + def _on_close(self, _, status_code, message): - logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.") + logger.warning( + f"[Relay: {self.url}] Connection closed." + + f" Status: '{status_code}'. Message: '{message}'." + ) self.close() def _on_message(self, _, message: str): - if self._is_valid_message(message): - self.num_received_events += 1 - self.message_pool.add_message(message, self.url) + self.num_received_events += 1 + self.message_pool.add_message(message, self.url) def _on_error(self, _, error): - logger.warning(f"Relay error: '{str(error)}'") + logger.warning(f"[Relay: {self.url}] Error: '{str(error)}'") self._append_error_message(str(error)) - self.connected = False - self.error_counter += 1 + self.close() def _on_ping(self, *_): return @@ -159,65 +120,7 @@ def _on_ping(self, *_): def _on_pong(self, *_): return - def _is_valid_message(self, message: str) -> bool: - message = message.strip("\n") - if not message or message[0] != "[" or message[-1] != "]": - return False - - message_json = json.loads(message) - message_type = message_json[0] - - if not RelayMessageType.is_valid(message_type): - return False - - if message_type == RelayMessageType.EVENT: - return self._is_valid_event_message(message_json) - - if message_type == RelayMessageType.COMMAND_RESULT: - return self._is_valid_command_result_message(message, message_json) - - return True - - def _is_valid_event_message(self, message_json): - if not len(message_json) == 3: - return False - - subscription_id = message_json[1] - with self.lock: - if subscription_id not in self.subscriptions: - return False - - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) - if not event.verify(): - return False - - with self.lock: - subscription = self.subscriptions[subscription_id] - - if subscription.filters and not subscription.filters.match(event): - return False - - return True - - def _is_valid_command_result_message(self, message, message_json): - if not len(message_json) < 3: - return False - - if message_json[2] != True: - logger.warning(f"Relay '{self.url}' negative command result: '{message}'") - self._append_error_message(message) - return False - - return True - def _append_error_message(self, message): - self.error_list = ([message] + self.error_list)[:20] - self.last_error_date = int(time.time()) \ No newline at end of file + self.error_counter += 1 + self.error_list = [message] + self.error_list + self.last_error_date = int(time.time()) diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index f639fb0..ff7ca9c 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,21 +1,15 @@ - import asyncio -import ssl import threading import time +from typing import List from loguru import logger -from .filter import Filters from .message_pool import MessagePool, NoticeMessage -from .relay import Relay, RelayPolicy +from .relay import Relay from .subscription import Subscription -class RelayException(Exception): - pass - - class RelayManager: def __init__(self) -> None: self.relays: dict[str, Relay] = {} @@ -25,72 +19,97 @@ def __init__(self) -> None: self._cached_subscriptions: dict[str, Subscription] = {} self._subscriptions_lock = threading.Lock() - def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: + def add_relay(self, url: str) -> Relay: if url in list(self.relays.keys()): - return - - with self._subscriptions_lock: - subscriptions = self._cached_subscriptions.copy() + logger.debug(f"Relay '{url}' already present.") + return self.relays[url] - policy = RelayPolicy(read, write) - relay = Relay(url, policy, self.message_pool, subscriptions) + relay = Relay(url, self.message_pool) self.relays[url] = relay - self._open_connection( - relay, - {"cert_reqs": ssl.CERT_NONE} - ) # NOTE: This disables ssl certificate verification + self._open_connection(relay) - relay.publish_subscriptions() + relay.publish_subscriptions(list(self._cached_subscriptions.values())) return relay def remove_relay(self, url: str): - self.relays[url].close() - self.relays.pop(url) - self.threads[url].join(timeout=5) - self.threads.pop(url) - self.queue_threads[url].join(timeout=5) - self.queue_threads.pop(url) - - - def add_subscription(self, id: str, filters: Filters): + try: + self.relays[url].close() + except Exception as e: + logger.debug(e) + + if url in self.relays: + self.relays.pop(url) + + try: + self.threads[url].join(timeout=5) + except Exception as e: + logger.debug(e) + + if url in self.threads: + self.threads.pop(url) + + try: + self.queue_threads[url].join(timeout=5) + except Exception as e: + logger.debug(e) + + if url in self.queue_threads: + self.queue_threads.pop(url) + + def remove_relays(self): + relay_urls = list(self.relays.keys()) + for url in relay_urls: + self.remove_relay(url) + + def add_subscription(self, id: str, filters: List[str]): + s = Subscription(id, filters) with self._subscriptions_lock: - self._cached_subscriptions[id] = Subscription(id, filters) + self._cached_subscriptions[id] = s for relay in self.relays.values(): - relay.add_subscription(id, filters) + relay.publish_subscriptions([s]) def close_subscription(self, id: str): - with self._subscriptions_lock: - self._cached_subscriptions.pop(id) + try: + with self._subscriptions_lock: + if id in self._cached_subscriptions: + self._cached_subscriptions.pop(id) - for relay in self.relays.values(): - relay.close_subscription(id) + for relay in self.relays.values(): + relay.close_subscription(id) + except Exception as e: + logger.debug(e) + + def close_subscriptions(self, subscriptions: List[str]): + for id in subscriptions: + self.close_subscription(id) + + def close_all_subscriptions(self): + all_subscriptions = list(self._cached_subscriptions.keys()) + self.close_subscriptions(all_subscriptions) def check_and_restart_relays(self): stopped_relays = [r for r in self.relays.values() if r.shutdown] for relay in stopped_relays: self._restart_relay(relay) - def close_connections(self): for relay in self.relays.values(): relay.close() def publish_message(self, message: str): for relay in self.relays.values(): - if relay.policy.should_write: - relay.publish(message) + relay.publish(message) def handle_notice(self, notice: NoticeMessage): relay = next((r for r in self.relays.values() if r.url == notice.url)) if relay: relay.add_notice(notice.content) - def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): + def _open_connection(self, relay: Relay): self.threads[relay.url] = threading.Thread( target=relay.connect, - args=(ssl_options, proxy), name=f"{relay.url}-thread", daemon=True, ) @@ -98,7 +117,7 @@ def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = def wrap_async_queue_worker(): asyncio.run(relay.queue_worker()) - + self.queue_threads[relay.url] = threading.Thread( target=wrap_async_queue_worker, name=f"{relay.url}-queue", @@ -108,14 +127,16 @@ def wrap_async_queue_worker(): def _restart_relay(self, relay: Relay): time_since_last_error = time.time() - relay.last_error_date - - min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day + + min_wait_time = min( + 60 * relay.error_counter, 60 * 60 + ) # try at least once an hour if time_since_last_error < min_wait_time: return - + logger.info(f"Restarting connection to relay '{relay.url}'") self.remove_relay(relay.url) new_relay = self.add_relay(relay.url) new_relay.error_counter = relay.error_counter - new_relay.error_list = relay.error_list \ No newline at end of file + new_relay.error_list = relay.error_list diff --git a/nostr/subscription.py b/nostr/subscription.py index 76da0af..a75c1a1 100644 --- a/nostr/subscription.py +++ b/nostr/subscription.py @@ -1,13 +1,7 @@ -from .filter import Filters +from typing import List class Subscription: - def __init__(self, id: str, filters: Filters=None) -> None: + def __init__(self, id: str, filters: List[str] = None) -> None: self.id = id self.filters = filters - - def to_json_object(self): - return { - "id": self.id, - "filters": self.filters.to_json_array() - } diff --git a/router.py b/router.py index cc0a380..e6ccdef 100644 --- a/router.py +++ b/router.py @@ -1,42 +1,61 @@ import asyncio import json -from typing import List, Union +from typing import Dict, List -from fastapi import WebSocketDisconnect +from fastapi import WebSocket, WebSocketDisconnect from loguru import logger from lnbits.helpers import urlsafe_short_hash -from . import nostr -from .models import Event, Filter -from .nostr.filter import Filter as NostrFilter -from .nostr.filter import Filters as NostrFilters -from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage +from . import nostr_client +from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage class NostrRouter: - - received_subscription_events: dict[str, list[Event]] = {} + received_subscription_events: dict[str, List[EventMessage]] = {} received_subscription_notices: list[NoticeMessage] = [] received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} - def __init__(self, websocket): - self.subscriptions: List[str] = [] + def __init__(self, websocket: WebSocket): self.connected: bool = True - self.websocket = websocket + self.websocket: WebSocket = websocket self.tasks: List[asyncio.Task] = [] - self.original_subscription_ids = {} - - async def client_to_nostr(self): - """Receives requests / data from the client and forwards it to relays. If the - request was a subscription/filter, registers it with the nostr client lib. - Remembers the subscription id so we can send back responses from the relay to this - client in `nostr_to_client`""" - while True: + self.original_subscription_ids: Dict[str, str] = {} + + @property + def subscriptions(self) -> List[str]: + return list(self.original_subscription_ids.keys()) + + def start(self): + self.connected = True + self.tasks.append(asyncio.create_task(self._client_to_nostr())) + self.tasks.append(asyncio.create_task(self._nostr_to_client())) + + async def stop(self): + nostr_client.relay_manager.close_subscriptions(self.subscriptions) + self.connected = False + + for t in self.tasks: + try: + t.cancel() + except Exception as _: + pass + + try: + await self.websocket.close() + except Exception as _: + pass + + async def _client_to_nostr(self): + """ + Receives requests / data from the client and forwards it to relays. + """ + while self.connected: try: json_str = await self.websocket.receive_text() - except WebSocketDisconnect: - self.connected = False + except WebSocketDisconnect as e: + logger.debug(e) + await self.stop() break try: @@ -44,15 +63,9 @@ async def client_to_nostr(self): except Exception as e: logger.debug(f"Failed to handle client message: '{str(e)}'.") - - async def nostr_to_client(self): - """Sends responses from relays back to the client. Polls the subscriptions of this client - stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which - is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs - the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id - that we had previously rewritten in order to avoid collisions when multiple clients use the same id. - """ - while True and self.connected: + async def _nostr_to_client(self): + """Sends responses from relays back to the client.""" + while self.connected: try: await self._handle_subscriptions() self._handle_notices() @@ -61,24 +74,6 @@ async def nostr_to_client(self): await asyncio.sleep(0.1) - async def start(self): - self.tasks.append(asyncio.create_task(self.client_to_nostr())) - self.tasks.append(asyncio.create_task(self.nostr_to_client())) - - async def stop(self): - for t in self.tasks: - try: - t.cancel() - except: - pass - - for s in self.subscriptions: - try: - nostr.client.relay_manager.close_subscription(s) - except: - pass - self.connected = False - async def _handle_subscriptions(self): for s in self.subscriptions: if s in NostrRouter.received_subscription_events: @@ -86,8 +81,6 @@ async def _handle_subscriptions(self): if s in NostrRouter.received_subscription_eosenotices: await self._handle_received_subscription_eosenotices(s) - - async def _handle_received_subscription_eosenotices(self, s): try: if s not in self.original_subscription_ids: @@ -95,7 +88,7 @@ async def _handle_received_subscription_eosenotices(self, s): s_original = self.original_subscription_ids[s] event_to_forward = ["EOSE", s_original] del NostrRouter.received_subscription_eosenotices[s] - + await self.websocket.send_text(json.dumps(event_to_forward)) except Exception as e: logger.debug(e) @@ -104,97 +97,62 @@ async def _handle_received_subscription_events(self, s): try: if s not in NostrRouter.received_subscription_events: return + while len(NostrRouter.received_subscription_events[s]): - my_event = NostrRouter.received_subscription_events[s].pop(0) - # event.to_message() does not include the subscription ID, we have to add it manually - event_json = { - "id": my_event.id, - "pubkey": my_event.public_key, - "created_at": my_event.created_at, - "kind": my_event.kind, - "tags": my_event.tags, - "content": my_event.content, - "sig": my_event.signature, - } + event_message = NostrRouter.received_subscription_events[s].pop(0) + event_json = event_message.event # this reconstructs the original response from the relay # reconstruct original subscription id s_original = self.original_subscription_ids[s] - event_to_forward = ["EVENT", s_original, event_json] - await self.websocket.send_text(json.dumps(event_to_forward)) + event_to_forward = f"""["EVENT", "{s_original}", {event_json}]""" + await self.websocket.send_text(event_to_forward) except Exception as e: - logger.debug(e) + logger.debug(e) # there are 2900 errors here def _handle_notices(self): while len(NostrRouter.received_subscription_notices): my_event = NostrRouter.received_subscription_notices.pop(0) - # note: we don't send it to the user because we don't know who should receive it - logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']") - nostr.client.relay_manager.handle_notice(my_event) - - - - def _marshall_nostr_filters(self, data: Union[dict, list]): - filters = data if isinstance(data, list) else [data] - filters = [Filter.parse_obj(f) for f in filters] - filter_list: list[NostrFilter] = [] - for filter in filters: - filter_list.append( - NostrFilter( - event_ids=filter.ids, # type: ignore - kinds=filter.kinds, # type: ignore - authors=filter.authors, # type: ignore - since=filter.since, # type: ignore - until=filter.until, # type: ignore - event_refs=filter.e, # type: ignore - pubkey_refs=filter.p, # type: ignore - limit=filter.limit, # type: ignore - ) - ) - return NostrFilters(filter_list) + logger.info(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']") + # Note: we don't send it to the user because + # we don't know who should receive it + nostr_client.relay_manager.handle_notice(my_event) async def _handle_client_to_nostr(self, json_str): - """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will - register the subscription in the nostr client library that we're using so we can - receive the callbacks on it later. Will rewrite the subscription id since we expect - multiple clients to use the router and want to avoid subscription id collisions - """ - json_data = json.loads(json_str) - assert len(json_data) - + assert len(json_data), "Bad JSON array" if json_data[0] == "REQ": self._handle_client_req(json_data) return - + if json_data[0] == "CLOSE": self._handle_client_close(json_data[1]) return if json_data[0] == "EVENT": - nostr.client.relay_manager.publish_message(json_str) + nostr_client.relay_manager.publish_message(json_str) return def _handle_client_req(self, json_data): subscription_id = json_data[1] subscription_id_rewritten = urlsafe_short_hash() self.original_subscription_ids[subscription_id_rewritten] = subscription_id - fltr = json_data[2:] - filters = self._marshall_nostr_filters(fltr) + filters = json_data[2:] - nostr.client.relay_manager.add_subscription( - subscription_id_rewritten, filters - ) - request_rewritten = json.dumps([json_data[0], subscription_id_rewritten] + fltr) - - self.subscriptions.append(subscription_id_rewritten) - nostr.client.relay_manager.publish_message(request_rewritten) + nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters) def _handle_client_close(self, subscription_id): - subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None) + subscription_id_rewritten = next( + ( + k + for k, v in self.original_subscription_ids.items() + if v == subscription_id + ), + None, + ) if subscription_id_rewritten: self.original_subscription_ids.pop(subscription_id_rewritten) - nostr.client.relay_manager.close_subscription(subscription_id_rewritten) + nostr_client.relay_manager.close_subscription(subscription_id_rewritten) else: logger.debug(f"Failed to unsubscribe from '{subscription_id}.'") diff --git a/tasks.py b/tasks.py index 4c316bc..69aa33c 100644 --- a/tasks.py +++ b/tasks.py @@ -3,75 +3,69 @@ from loguru import logger -from . import nostr +from . import nostr_client from .crud import get_relays from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage -from .router import NostrRouter, nostr +from .router import NostrRouter async def init_relays(): - # reinitialize the entire client - nostr.__init__() # get relays from db relays = await get_relays() # set relays and connect to them - nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) - await nostr.client.connect() + valid_relays = list(set([r.url for r in relays if r.url])) + + nostr_client.reconnect(valid_relays) async def check_relays(): - """ Check relays that have been disconnected """ + """Check relays that have been disconnected""" while True: try: await asyncio.sleep(20) - nostr.client.relay_manager.check_and_restart_relays() + nostr_client.relay_manager.check_and_restart_relays() except Exception as e: logger.warning(f"Cannot restart relays: '{str(e)}'.") - + async def subscribe_events(): - while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): + while not any([r.connected for r in nostr_client.relay_manager.relays.values()]): await asyncio.sleep(2) def callback_events(eventMessage: EventMessage): - if eventMessage.subscription_id in NostrRouter.received_subscription_events: - # do not add duplicate events (by event id) - if eventMessage.event.id in set( - [ - e.id - for e in NostrRouter.received_subscription_events[eventMessage.subscription_id] - ] - ): - return - - NostrRouter.received_subscription_events[eventMessage.subscription_id].append( - eventMessage.event - ) - else: - NostrRouter.received_subscription_events[eventMessage.subscription_id] = [ - eventMessage.event - ] - return + sub_id = eventMessage.subscription_id + if sub_id not in NostrRouter.received_subscription_events: + NostrRouter.received_subscription_events[sub_id] = [eventMessage] + return + + # do not add duplicate events (by event id) + ids = set( + [e.event_id for e in NostrRouter.received_subscription_events[sub_id]] + ) + if eventMessage.event_id in ids: + return + + NostrRouter.received_subscription_events[sub_id].append(eventMessage) def callback_notices(noticeMessage: NoticeMessage): if noticeMessage not in NostrRouter.received_subscription_notices: NostrRouter.received_subscription_notices.append(noticeMessage) - return def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): - if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices: - NostrRouter.received_subscription_eosenotices[ - eventMessage.subscription_id - ] = eventMessage + sub_id = eventMessage.subscription_id + if sub_id in NostrRouter.received_subscription_eosenotices: + return - return + NostrRouter.received_subscription_eosenotices[sub_id] = eventMessage def wrap_async_subscribe(): - asyncio.run(nostr.client.subscribe( - callback_events, - callback_notices, - callback_eose_notices, - )) + asyncio.run( + nostr_client.subscribe( + callback_events, + callback_notices, + callback_eose_notices, + ) + ) t = threading.Thread( target=wrap_async_subscribe, diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html index a0c5999..db0f98e 100644 --- a/templates/nostrclient/index.html +++ b/templates/nostrclient/index.html @@ -6,13 +6,30 @@
- +
- - - + + @@ -29,18 +46,36 @@
Nostrclient
- +
- +