diff --git a/__init__.py b/__init__.py
index 7f573e7..d50998b 100644
--- a/__init__.py
+++ b/__init__.py
@@ -7,7 +7,7 @@
from lnbits.helpers import template_renderer
from lnbits.tasks import catch_everything_and_restart
-from .nostr.client.client import NostrClient as NostrClientLib
+from .nostr.client.client import NostrClient
db = Database("ext_nostrclient")
@@ -22,19 +22,14 @@
scheduled_tasks: List[asyncio.Task] = []
-class NostrClient:
- def __init__(self):
- self.client: NostrClientLib = NostrClientLib(connect=False)
-
-
-nostr = NostrClient()
+nostr_client = NostrClient()
def nostr_renderer():
return template_renderer(["nostrclient/templates"])
-from .tasks import check_relays, init_relays, subscribe_events
+from .tasks import check_relays, init_relays, subscribe_events # noqa
from .views import * # noqa
from .views_api import * # noqa
diff --git a/cbc.py b/cbc.py
deleted file mode 100644
index 0d9e04f..0000000
--- a/cbc.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from Cryptodome.Cipher import AES
-
-BLOCK_SIZE = 16
-
-
-class AESCipher(object):
- """This class is compatible with crypto.createCipheriv('aes-256-cbc')"""
-
- def __init__(self, key=None):
- self.key = key
-
- def pad(self, data):
- length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
- return data + (chr(length) * length).encode()
-
- def unpad(self, data):
- return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
-
- def encrypt(self, plain_text):
- cipher = AES.new(self.key, AES.MODE_CBC)
- b = plain_text.encode("UTF-8")
- return cipher.iv, cipher.encrypt(self.pad(b))
-
- def decrypt(self, iv, enc_text):
- cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
- return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
diff --git a/crud.py b/crud.py
index 780642d..05ca907 100644
--- a/crud.py
+++ b/crud.py
@@ -1,21 +1,17 @@
-from typing import List, Optional, Union
-
-import shortuuid
-
-from lnbits.helpers import urlsafe_short_hash
+from typing import List
from . import db
-from .models import Relay, RelayList
+from .models import Relay
-async def get_relays() -> RelayList:
- row = await db.fetchall("SELECT * FROM nostrclient.relays")
- return RelayList(__root__=row)
+async def get_relays() -> List[Relay]:
+ rows = await db.fetchall("SELECT * FROM nostrclient.relays")
+ return [Relay.from_row(r) for r in rows]
async def add_relay(relay: Relay) -> None:
await db.execute(
- f"""
+ """
INSERT INTO nostrclient.relays (
id,
url,
diff --git a/migrations.py b/migrations.py
index 5a30e45..73b9ed8 100644
--- a/migrations.py
+++ b/migrations.py
@@ -3,7 +3,7 @@ async def m001_initial(db):
Initial nostrclient table.
"""
await db.execute(
- f"""
+ """
CREATE TABLE nostrclient.relays (
id TEXT NOT NULL PRIMARY KEY,
url TEXT NOT NULL,
diff --git a/models.py b/models.py
index 88651fc..e08ade3 100644
--- a/models.py
+++ b/models.py
@@ -1,9 +1,7 @@
-from dataclasses import dataclass
-from typing import Dict, List, Optional
+from sqlite3 import Row
+from typing import List, Optional
-from fastapi import Request
-from fastapi.param_functions import Query
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
from lnbits.helpers import urlsafe_short_hash
@@ -14,7 +12,8 @@ class RelayStatus(BaseModel):
error_counter: Optional[int] = 0
error_list: Optional[List] = []
notice_list: Optional[List] = []
-
+
+
class Relay(BaseModel):
id: Optional[str] = None
url: Optional[str] = None
@@ -28,33 +27,9 @@ def _init__(self):
if not self.id:
self.id = urlsafe_short_hash()
-
-class RelayList(BaseModel):
- __root__: List[Relay]
-
-
-class Event(BaseModel):
- content: str
- pubkey: str
- created_at: Optional[int]
- kind: int
- tags: Optional[List[List[str]]]
- sig: str
-
-
-class Filter(BaseModel):
- ids: Optional[List[str]]
- kinds: Optional[List[int]]
- authors: Optional[List[str]]
- since: Optional[int]
- until: Optional[int]
- e: Optional[List[str]] = Field(alias="#e")
- p: Optional[List[str]] = Field(alias="#p")
- limit: Optional[int]
-
-
-class Filters(BaseModel):
- __root__: List[Filter]
+ @classmethod
+ def from_row(cls, row: Row) -> "Relay":
+ return cls(**dict(row))
class TestMessage(BaseModel):
@@ -62,6 +37,7 @@ class TestMessage(BaseModel):
reciever_public_key: str
message: str
+
class TestMessageResponse(BaseModel):
private_key: str
public_key: str
diff --git a/nostr/__init__.py b/nostr/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/nostr/bech32.py b/nostr/bech32.py
index 61a92c4..0ae6c80 100644
--- a/nostr/bech32.py
+++ b/nostr/bech32.py
@@ -26,19 +26,22 @@
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
+
BECH32 = 1
BECH32M = 2
+
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
-BECH32M_CONST = 0x2bc830a3
+BECH32M_CONST = 0x2BC830A3
+
def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
- generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
+ generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1
for value in values:
top = chk >> 25
- chk = (chk & 0x1ffffff) << 5 ^ value
+ chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk
@@ -58,6 +61,7 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M
return None
+
def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
@@ -69,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec):
def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
- return hrp + '1' + ''.join([CHARSET[d] for d in combined])
+ return hrp + "1" + "".join([CHARSET[d] for d in combined])
+
def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
- if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
- (bech.lower() != bech and bech.upper() != bech)):
+ if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
+ bech.lower() != bech and bech.upper() != bech
+ ):
return (None, None, None)
bech = bech.lower()
- pos = bech.rfind('1')
+ pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
- if not all(x in CHARSET for x in bech[pos+1:]):
+ if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
hrp = bech[:pos]
- data = [CHARSET.find(x) for x in bech[pos+1:]]
+ data = [CHARSET.find(x) for x in bech[pos + 1 :]]
spec = bech32_verify_checksum(hrp, data)
if spec is None:
return (None, None, None)
return (hrp, data[:-6], spec)
+
def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion."""
acc = 0
@@ -124,7 +131,12 @@ def decode(hrp, addr):
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
- if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
+ if (
+ data[0] == 0
+ and spec != Encoding.BECH32
+ or data[0] != 0
+ and spec != Encoding.BECH32M
+ ):
return (None, None)
return (data[0], decoded)
diff --git a/nostr/client/client.py b/nostr/client/client.py
index db07a06..4624ff3 100644
--- a/nostr/client/client.py
+++ b/nostr/client/client.py
@@ -1,25 +1,36 @@
import asyncio
-from typing import List
+
+from loguru import logger
from ..relay_manager import RelayManager
class NostrClient:
- relays = [ ]
relay_manager = RelayManager()
- def __init__(self, relays: List[str] = [], connect=True):
- if len(relays):
- self.relays = relays
- if connect:
- self.connect()
+ def __init__(self):
+ self.running = True
+
+ def connect(self, relays):
+ for relay in relays:
+ try:
+ self.relay_manager.add_relay(relay)
+ except Exception as e:
+ logger.debug(e)
+ self.running = True
- async def connect(self):
- for relay in self.relays:
- self.relay_manager.add_relay(relay)
+ def reconnect(self, relays):
+ self.relay_manager.remove_relays()
+ self.connect(relays)
def close(self):
- self.relay_manager.close_connections()
+ try:
+ self.relay_manager.close_all_subscriptions()
+ self.relay_manager.close_connections()
+
+ self.running = False
+ except Exception as e:
+ logger.error(e)
async def subscribe(
self,
@@ -27,18 +38,36 @@ async def subscribe(
callback_notices_func=None,
callback_eosenotices_func=None,
):
- while True:
+ while self.running:
+ self._check_events(callback_events_func)
+ self._check_notices(callback_notices_func)
+ self._check_eos_notices(callback_eosenotices_func)
+
+ await asyncio.sleep(0.2)
+
+ def _check_events(self, callback_events_func=None):
+ try:
while self.relay_manager.message_pool.has_events():
event_msg = self.relay_manager.message_pool.get_event()
if callback_events_func:
callback_events_func(event_msg)
+ except Exception as e:
+ logger.debug(e)
+
+ def _check_notices(self, callback_notices_func=None):
+ try:
while self.relay_manager.message_pool.has_notices():
event_msg = self.relay_manager.message_pool.get_notice()
if callback_notices_func:
callback_notices_func(event_msg)
+ except Exception as e:
+ logger.debug(e)
+
+ def _check_eos_notices(self, callback_eosenotices_func=None):
+ try:
while self.relay_manager.message_pool.has_eose_notices():
event_msg = self.relay_manager.message_pool.get_eose_notice()
if callback_eosenotices_func:
callback_eosenotices_func(event_msg)
-
- await asyncio.sleep(0.5)
+ except Exception as e:
+ logger.debug(e)
diff --git a/nostr/delegation.py b/nostr/delegation.py
deleted file mode 100644
index 94801f5..0000000
--- a/nostr/delegation.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import time
-from dataclasses import dataclass
-
-
-@dataclass
-class Delegation:
- delegator_pubkey: str
- delegatee_pubkey: str
- event_kind: int
- duration_secs: int = 30*24*60 # default to 30 days
- signature: str = None # set in PrivateKey.sign_delegation
-
- @property
- def expires(self) -> int:
- return int(time.time()) + self.duration_secs
-
- @property
- def conditions(self) -> str:
- return f"kind={self.event_kind}&created_at<{self.expires}"
-
- @property
- def delegation_token(self) -> str:
- return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
-
- def get_tag(self) -> list[str]:
- """ Called by Event """
- return [
- "delegation",
- self.delegator_pubkey,
- self.conditions,
- self.signature,
- ]
diff --git a/nostr/event.py b/nostr/event.py
index 65b187d..a7d4f1d 100644
--- a/nostr/event.py
+++ b/nostr/event.py
@@ -122,6 +122,7 @@ def __post_init__(self):
def id(self) -> str:
if self.content is None:
raise Exception(
- "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
+ "EncryptedDirectMessage `id` is undefined until its"
+ + " message is encrypted and stored in the `content` field"
)
return super().id
diff --git a/nostr/filter.py b/nostr/filter.py
deleted file mode 100644
index f119079..0000000
--- a/nostr/filter.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from collections import UserList
-from typing import List
-
-from .event import Event, EventKind
-
-
-class Filter:
- """
- NIP-01 filtering.
-
- Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
-
- Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
- If a particular single-letter tag gains prominence, explicit support should be
- added. For example:
- # arbitrary tag
- filter.add_arbitrary_tag('t', [hashtags])
-
- # promoted to explicit support
- Filter(hashtag_refs=[hashtags])
- """
-
- def __init__(
- self,
- event_ids: List[str] = None,
- kinds: List[EventKind] = None,
- authors: List[str] = None,
- since: int = None,
- until: int = None,
- event_refs: List[
- str
- ] = None, # the "#e" attr; list of event ids referenced in an "e" tag
- pubkey_refs: List[
- str
- ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
- limit: int = None,
- ) -> None:
- self.event_ids = event_ids
- self.kinds = kinds
- self.authors = authors
- self.since = since
- self.until = until
- self.event_refs = event_refs
- self.pubkey_refs = pubkey_refs
- self.limit = limit
-
- self.tags = {}
- if self.event_refs:
- self.add_arbitrary_tag("e", self.event_refs)
- if self.pubkey_refs:
- self.add_arbitrary_tag("p", self.pubkey_refs)
-
- def add_arbitrary_tag(self, tag: str, values: list):
- """
- Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
- single-letter tags.
- """
- # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
- tag_key = tag if len(tag) > 1 else f"#{tag}"
- self.tags[tag_key] = values
-
- def matches(self, event: Event) -> bool:
- if self.event_ids is not None and event.id not in self.event_ids:
- return False
- if self.kinds is not None and event.kind not in self.kinds:
- return False
- if self.authors is not None and event.public_key not in self.authors:
- return False
- if self.since is not None and event.created_at < self.since:
- return False
- if self.until is not None and event.created_at > self.until:
- return False
- if (self.event_refs is not None or self.pubkey_refs is not None) and len(
- event.tags
- ) == 0:
- return False
-
- if self.tags:
- e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
- for f_tag, f_tag_values in self.tags.items():
- # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
- f_tag = f_tag.replace("#", "")
-
- if f_tag not in e_tag_identifiers:
- # Event is missing a tag type that we're looking for
- return False
-
- # Multiple values within f_tag_values are treated as OR search; an Event
- # needs to match only one.
- # Note: an Event could have multiple entries of the same tag type
- # (e.g. a reply to multiple people) so we have to check all of them.
- match_found = False
- for e_tag in event.tags:
- if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
- match_found = True
- break
- if not match_found:
- return False
-
- return True
-
- def to_json_object(self) -> dict:
- res = {}
- if self.event_ids is not None:
- res["ids"] = self.event_ids
- if self.kinds is not None:
- res["kinds"] = self.kinds
- if self.authors is not None:
- res["authors"] = self.authors
- if self.since is not None:
- res["since"] = self.since
- if self.until is not None:
- res["until"] = self.until
- if self.limit is not None:
- res["limit"] = self.limit
- if self.tags:
- res.update(self.tags)
-
- return res
-
-
-class Filters(UserList):
- def __init__(self, initlist: "list[Filter]" = []) -> None:
- super().__init__(initlist)
- self.data: "list[Filter]"
-
- def match(self, event: Event):
- for filter in self.data:
- if filter.matches(event):
- return True
- return False
-
- def to_json_array(self) -> list:
- return [filter.to_json_object() for filter in self.data]
diff --git a/nostr/key.py b/nostr/key.py
index 8089e11..3803650 100644
--- a/nostr/key.py
+++ b/nostr/key.py
@@ -1,6 +1,5 @@
import base64
import secrets
-from hashlib import sha256
import secp256k1
from cffi import FFI
@@ -8,7 +7,6 @@
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from . import bech32
-from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind
@@ -37,7 +35,7 @@ def from_npub(cls, npub: str):
class PrivateKey:
def __init__(self, raw_secret: bytes = None) -> None:
- if not raw_secret is None:
+ if raw_secret is not None:
self.raw_secret = raw_secret
else:
self.raw_secret = secrets.token_bytes(32)
@@ -79,7 +77,10 @@ def encrypt_message(self, message: str, public_key_hex: str) -> str:
encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
- return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
+ return (
+ f"{base64.b64encode(encrypted_message).decode()}"
+ + f"?iv={base64.b64encode(iv).decode()}"
+ )
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
dm.content = self.encrypt_message(
@@ -116,11 +117,6 @@ def sign_event(self, event: Event) -> None:
event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id))
- def sign_delegation(self, delegation: Delegation) -> None:
- delegation.signature = self.sign_message_hash(
- sha256(delegation.delegation_token.encode()).digest()
- )
-
def __eq__(self, other):
return self.raw_secret == other.raw_secret
diff --git a/nostr/message_pool.py b/nostr/message_pool.py
index 02f7fd4..a3e6c5f 100644
--- a/nostr/message_pool.py
+++ b/nostr/message_pool.py
@@ -2,13 +2,15 @@
from queue import Queue
from threading import Lock
-from .event import Event
from .message_type import RelayMessageType
class EventMessage:
- def __init__(self, event: Event, subscription_id: str, url: str) -> None:
+ def __init__(
+ self, event: str, event_id: str, subscription_id: str, url: str
+ ) -> None:
self.event = event
+ self.event_id = event_id
self.subscription_id = subscription_id
self.url = url
@@ -59,18 +61,16 @@ def _process_message(self, message: str, url: str):
message_type = message_json[0]
if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1]
- e = message_json[2]
- event = Event(
- e["content"],
- e["pubkey"],
- e["created_at"],
- e["kind"],
- e["tags"],
- e["sig"],
- )
+ event = message_json[2]
+ if "id" not in event:
+ return
+ event_id = event["id"]
+
with self.lock:
- if not f"{subscription_id}_{event.id}" in self._unique_events:
- self._accept_event(EventMessage(event, subscription_id, url))
+ if f"{subscription_id}_{event_id}" not in self._unique_events:
+ self._accept_event(
+ EventMessage(json.dumps(event), event_id, subscription_id, url)
+ )
elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
@@ -78,10 +78,12 @@ def _process_message(self, message: str, url: str):
def _accept_event(self, event_message: EventMessage):
"""
- Event uniqueness is considered per `subscription_id`.
- The `subscription_id` is rewritten to be unique and it is the same accross relays.
- The same event can come from different subscriptions (from the same client or from different ones).
- Clients that have joined later should receive older events.
+ Event uniqueness is considered per `subscription_id`. The `subscription_id` is
+ rewritten to be unique and it is the same across relays. The same event can
+ come from different subscriptions (from the same client or from different ones).
+ Clients that have joined later should receive older events.
"""
self.events.put(event_message)
- self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}")
\ No newline at end of file
+ self._unique_events.add(
+ f"{event_message.subscription_id}_{event_message.event_id}"
+ )
diff --git a/nostr/relay.py b/nostr/relay.py
index caacba0..b576cfa 100644
--- a/nostr/relay.py
+++ b/nostr/relay.py
@@ -2,43 +2,23 @@
import json
import time
from queue import Queue
-from threading import Lock
from typing import List
from loguru import logger
from websocket import WebSocketApp
-from .event import Event
-from .filter import Filters
from .message_pool import MessagePool
-from .message_type import RelayMessageType
from .subscription import Subscription
-class RelayPolicy:
- def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
- self.should_read = should_read
- self.should_write = should_write
-
- def to_json_object(self) -> dict[str, bool]:
- return {"read": self.should_read, "write": self.should_write}
-
-
class Relay:
- def __init__(
- self,
- url: str,
- policy: RelayPolicy,
- message_pool: MessagePool,
- subscriptions: dict[str, Subscription] = {},
- ) -> None:
+ def __init__(self, url: str, message_pool: MessagePool) -> None:
self.url = url
- self.policy = policy
self.message_pool = message_pool
- self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
self.shutdown: bool = False
+
self.error_counter: int = 0
self.error_threshold: int = 100
self.error_list: List[str] = []
@@ -47,12 +27,10 @@ def __init__(
self.num_received_events: int = 0
self.num_sent_events: int = 0
self.num_subscriptions: int = 0
- self.ssl_options: dict = {}
- self.proxy: dict = {}
- self.lock = Lock()
+
self.queue = Queue()
- def connect(self, ssl_options: dict = None, proxy: dict = None):
+ def connect(self):
self.ws = WebSocketApp(
self.url,
on_open=self._on_open,
@@ -62,19 +40,14 @@ def connect(self, ssl_options: dict = None, proxy: dict = None):
on_ping=self._on_ping,
on_pong=self._on_pong,
)
- self.ssl_options = ssl_options
- self.proxy = proxy
if not self.connected:
- self.ws.run_forever(
- sslopt=ssl_options,
- http_proxy_host=None if proxy is None else proxy.get("host"),
- http_proxy_port=None if proxy is None else proxy.get("port"),
- proxy_type=None if proxy is None else proxy.get("type"),
- ping_interval=5,
- )
+ self.ws.run_forever(ping_interval=10)
def close(self):
- self.ws.close()
+ try:
+ self.ws.close()
+ except Exception as e:
+ logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
self.connected = False
self.shutdown = True
@@ -90,10 +63,9 @@ def ping(self):
def publish(self, message: str):
self.queue.put(message)
- def publish_subscriptions(self):
- for _, subscription in self.subscriptions.items():
- s = subscription.to_json_object()
- json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
+ def publish_subscriptions(self, subscriptions: List[Subscription] = []):
+ for s in subscriptions:
+ json_str = json.dumps(["REQ", s.id] + s.filters)
self.publish(json_str)
async def queue_worker(self):
@@ -103,55 +75,44 @@ async def queue_worker(self):
message = self.queue.get(timeout=1)
self.num_sent_events += 1
self.ws.send(message)
- except:
+ except Exception as _:
pass
else:
await asyncio.sleep(1)
-
- if self.shutdown:
- logger.warning(f"Closing queue worker for '{self.url}'.")
- break
- def add_subscription(self, id, filters: Filters):
- with self.lock:
- self.subscriptions[id] = Subscription(id, filters)
+ if self.shutdown:
+ logger.warning(f"[Relay: {self.url}] Closing queue worker.")
+ return
def close_subscription(self, id: str) -> None:
- with self.lock:
- self.subscriptions.pop(id)
+ try:
self.publish(json.dumps(["CLOSE", id]))
-
- def to_json_object(self) -> dict:
- return {
- "url": self.url,
- "policy": self.policy.to_json_object(),
- "subscriptions": [
- subscription.to_json_object()
- for subscription in self.subscriptions.values()
- ],
- }
+ except Exception as e:
+ logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")
def add_notice(self, notice: str):
- self.notice_list = ([notice] + self.notice_list)[:20]
+ self.notice_list = [notice] + self.notice_list
def _on_open(self, _):
- logger.info(f"Connected to relay: '{self.url}'.")
+ logger.info(f"[Relay: {self.url}] Connected.")
self.connected = True
-
+ self.shutdown = False
+
def _on_close(self, _, status_code, message):
- logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.")
+ logger.warning(
+ f"[Relay: {self.url}] Connection closed."
+ + f" Status: '{status_code}'. Message: '{message}'."
+ )
self.close()
def _on_message(self, _, message: str):
- if self._is_valid_message(message):
- self.num_received_events += 1
- self.message_pool.add_message(message, self.url)
+ self.num_received_events += 1
+ self.message_pool.add_message(message, self.url)
def _on_error(self, _, error):
- logger.warning(f"Relay error: '{str(error)}'")
+ logger.warning(f"[Relay: {self.url}] Error: '{str(error)}'")
self._append_error_message(str(error))
- self.connected = False
- self.error_counter += 1
+ self.close()
def _on_ping(self, *_):
return
@@ -159,65 +120,7 @@ def _on_ping(self, *_):
def _on_pong(self, *_):
return
- def _is_valid_message(self, message: str) -> bool:
- message = message.strip("\n")
- if not message or message[0] != "[" or message[-1] != "]":
- return False
-
- message_json = json.loads(message)
- message_type = message_json[0]
-
- if not RelayMessageType.is_valid(message_type):
- return False
-
- if message_type == RelayMessageType.EVENT:
- return self._is_valid_event_message(message_json)
-
- if message_type == RelayMessageType.COMMAND_RESULT:
- return self._is_valid_command_result_message(message, message_json)
-
- return True
-
- def _is_valid_event_message(self, message_json):
- if not len(message_json) == 3:
- return False
-
- subscription_id = message_json[1]
- with self.lock:
- if subscription_id not in self.subscriptions:
- return False
-
- e = message_json[2]
- event = Event(
- e["content"],
- e["pubkey"],
- e["created_at"],
- e["kind"],
- e["tags"],
- e["sig"],
- )
- if not event.verify():
- return False
-
- with self.lock:
- subscription = self.subscriptions[subscription_id]
-
- if subscription.filters and not subscription.filters.match(event):
- return False
-
- return True
-
- def _is_valid_command_result_message(self, message, message_json):
- if not len(message_json) < 3:
- return False
-
- if message_json[2] != True:
- logger.warning(f"Relay '{self.url}' negative command result: '{message}'")
- self._append_error_message(message)
- return False
-
- return True
-
def _append_error_message(self, message):
- self.error_list = ([message] + self.error_list)[:20]
- self.last_error_date = int(time.time())
\ No newline at end of file
+ self.error_counter += 1
+ self.error_list = [message] + self.error_list
+ self.last_error_date = int(time.time())
diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py
index f639fb0..ff7ca9c 100644
--- a/nostr/relay_manager.py
+++ b/nostr/relay_manager.py
@@ -1,21 +1,15 @@
-
import asyncio
-import ssl
import threading
import time
+from typing import List
from loguru import logger
-from .filter import Filters
from .message_pool import MessagePool, NoticeMessage
-from .relay import Relay, RelayPolicy
+from .relay import Relay
from .subscription import Subscription
-class RelayException(Exception):
- pass
-
-
class RelayManager:
def __init__(self) -> None:
self.relays: dict[str, Relay] = {}
@@ -25,72 +19,97 @@ def __init__(self) -> None:
self._cached_subscriptions: dict[str, Subscription] = {}
self._subscriptions_lock = threading.Lock()
- def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay:
+ def add_relay(self, url: str) -> Relay:
if url in list(self.relays.keys()):
- return
-
- with self._subscriptions_lock:
- subscriptions = self._cached_subscriptions.copy()
+ logger.debug(f"Relay '{url}' already present.")
+ return self.relays[url]
- policy = RelayPolicy(read, write)
- relay = Relay(url, policy, self.message_pool, subscriptions)
+ relay = Relay(url, self.message_pool)
self.relays[url] = relay
- self._open_connection(
- relay,
- {"cert_reqs": ssl.CERT_NONE}
- ) # NOTE: This disables ssl certificate verification
+ self._open_connection(relay)
- relay.publish_subscriptions()
+ relay.publish_subscriptions(list(self._cached_subscriptions.values()))
return relay
def remove_relay(self, url: str):
- self.relays[url].close()
- self.relays.pop(url)
- self.threads[url].join(timeout=5)
- self.threads.pop(url)
- self.queue_threads[url].join(timeout=5)
- self.queue_threads.pop(url)
-
-
- def add_subscription(self, id: str, filters: Filters):
+ try:
+ self.relays[url].close()
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.relays:
+ self.relays.pop(url)
+
+ try:
+ self.threads[url].join(timeout=5)
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.threads:
+ self.threads.pop(url)
+
+ try:
+ self.queue_threads[url].join(timeout=5)
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.queue_threads:
+ self.queue_threads.pop(url)
+
+ def remove_relays(self):
+ relay_urls = list(self.relays.keys())
+ for url in relay_urls:
+ self.remove_relay(url)
+
+ def add_subscription(self, id: str, filters: List[str]):
+ s = Subscription(id, filters)
with self._subscriptions_lock:
- self._cached_subscriptions[id] = Subscription(id, filters)
+ self._cached_subscriptions[id] = s
for relay in self.relays.values():
- relay.add_subscription(id, filters)
+ relay.publish_subscriptions([s])
def close_subscription(self, id: str):
- with self._subscriptions_lock:
- self._cached_subscriptions.pop(id)
+ try:
+ with self._subscriptions_lock:
+ if id in self._cached_subscriptions:
+ self._cached_subscriptions.pop(id)
- for relay in self.relays.values():
- relay.close_subscription(id)
+ for relay in self.relays.values():
+ relay.close_subscription(id)
+ except Exception as e:
+ logger.debug(e)
+
+ def close_subscriptions(self, subscriptions: List[str]):
+ for id in subscriptions:
+ self.close_subscription(id)
+
+ def close_all_subscriptions(self):
+ all_subscriptions = list(self._cached_subscriptions.keys())
+ self.close_subscriptions(all_subscriptions)
def check_and_restart_relays(self):
stopped_relays = [r for r in self.relays.values() if r.shutdown]
for relay in stopped_relays:
self._restart_relay(relay)
-
def close_connections(self):
for relay in self.relays.values():
relay.close()
def publish_message(self, message: str):
for relay in self.relays.values():
- if relay.policy.should_write:
- relay.publish(message)
+ relay.publish(message)
def handle_notice(self, notice: NoticeMessage):
relay = next((r for r in self.relays.values() if r.url == notice.url))
if relay:
relay.add_notice(notice.content)
- def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None):
+ def _open_connection(self, relay: Relay):
self.threads[relay.url] = threading.Thread(
target=relay.connect,
- args=(ssl_options, proxy),
name=f"{relay.url}-thread",
daemon=True,
)
@@ -98,7 +117,7 @@ def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict =
def wrap_async_queue_worker():
asyncio.run(relay.queue_worker())
-
+
self.queue_threads[relay.url] = threading.Thread(
target=wrap_async_queue_worker,
name=f"{relay.url}-queue",
@@ -108,14 +127,16 @@ def wrap_async_queue_worker():
def _restart_relay(self, relay: Relay):
time_since_last_error = time.time() - relay.last_error_date
-
- min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day
+
+ min_wait_time = min(
+ 60 * relay.error_counter, 60 * 60
+ ) # try at least once an hour
if time_since_last_error < min_wait_time:
return
-
+
logger.info(f"Restarting connection to relay '{relay.url}'")
self.remove_relay(relay.url)
new_relay = self.add_relay(relay.url)
new_relay.error_counter = relay.error_counter
- new_relay.error_list = relay.error_list
\ No newline at end of file
+ new_relay.error_list = relay.error_list
diff --git a/nostr/subscription.py b/nostr/subscription.py
index 76da0af..a75c1a1 100644
--- a/nostr/subscription.py
+++ b/nostr/subscription.py
@@ -1,13 +1,7 @@
-from .filter import Filters
+from typing import List
class Subscription:
- def __init__(self, id: str, filters: Filters=None) -> None:
+ def __init__(self, id: str, filters: List[str] = None) -> None:
self.id = id
self.filters = filters
-
- def to_json_object(self):
- return {
- "id": self.id,
- "filters": self.filters.to_json_array()
- }
diff --git a/router.py b/router.py
index cc0a380..e6ccdef 100644
--- a/router.py
+++ b/router.py
@@ -1,42 +1,61 @@
import asyncio
import json
-from typing import List, Union
+from typing import Dict, List
-from fastapi import WebSocketDisconnect
+from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger
from lnbits.helpers import urlsafe_short_hash
-from . import nostr
-from .models import Event, Filter
-from .nostr.filter import Filter as NostrFilter
-from .nostr.filter import Filters as NostrFilters
-from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage
+from . import nostr_client
+from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
class NostrRouter:
-
- received_subscription_events: dict[str, list[Event]] = {}
+ received_subscription_events: dict[str, List[EventMessage]] = {}
received_subscription_notices: list[NoticeMessage] = []
received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
- def __init__(self, websocket):
- self.subscriptions: List[str] = []
+ def __init__(self, websocket: WebSocket):
self.connected: bool = True
- self.websocket = websocket
+ self.websocket: WebSocket = websocket
self.tasks: List[asyncio.Task] = []
- self.original_subscription_ids = {}
-
- async def client_to_nostr(self):
- """Receives requests / data from the client and forwards it to relays. If the
- request was a subscription/filter, registers it with the nostr client lib.
- Remembers the subscription id so we can send back responses from the relay to this
- client in `nostr_to_client`"""
- while True:
+ self.original_subscription_ids: Dict[str, str] = {}
+
+ @property
+ def subscriptions(self) -> List[str]:
+ return list(self.original_subscription_ids.keys())
+
+ def start(self):
+ self.connected = True
+ self.tasks.append(asyncio.create_task(self._client_to_nostr()))
+ self.tasks.append(asyncio.create_task(self._nostr_to_client()))
+
+ async def stop(self):
+ nostr_client.relay_manager.close_subscriptions(self.subscriptions)
+ self.connected = False
+
+ for t in self.tasks:
+ try:
+ t.cancel()
+ except Exception as _:
+ pass
+
+ try:
+ await self.websocket.close()
+ except Exception as _:
+ pass
+
+ async def _client_to_nostr(self):
+ """
+ Receives requests / data from the client and forwards it to relays.
+ """
+ while self.connected:
try:
json_str = await self.websocket.receive_text()
- except WebSocketDisconnect:
- self.connected = False
+ except WebSocketDisconnect as e:
+ logger.debug(e)
+ await self.stop()
break
try:
@@ -44,15 +63,9 @@ async def client_to_nostr(self):
except Exception as e:
logger.debug(f"Failed to handle client message: '{str(e)}'.")
-
- async def nostr_to_client(self):
- """Sends responses from relays back to the client. Polls the subscriptions of this client
- stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
- is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
- the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
- that we had previously rewritten in order to avoid collisions when multiple clients use the same id.
- """
- while True and self.connected:
+ async def _nostr_to_client(self):
+ """Sends responses from relays back to the client."""
+ while self.connected:
try:
await self._handle_subscriptions()
self._handle_notices()
@@ -61,24 +74,6 @@ async def nostr_to_client(self):
await asyncio.sleep(0.1)
- async def start(self):
- self.tasks.append(asyncio.create_task(self.client_to_nostr()))
- self.tasks.append(asyncio.create_task(self.nostr_to_client()))
-
- async def stop(self):
- for t in self.tasks:
- try:
- t.cancel()
- except:
- pass
-
- for s in self.subscriptions:
- try:
- nostr.client.relay_manager.close_subscription(s)
- except:
- pass
- self.connected = False
-
async def _handle_subscriptions(self):
for s in self.subscriptions:
if s in NostrRouter.received_subscription_events:
@@ -86,8 +81,6 @@ async def _handle_subscriptions(self):
if s in NostrRouter.received_subscription_eosenotices:
await self._handle_received_subscription_eosenotices(s)
-
-
async def _handle_received_subscription_eosenotices(self, s):
try:
if s not in self.original_subscription_ids:
@@ -95,7 +88,7 @@ async def _handle_received_subscription_eosenotices(self, s):
s_original = self.original_subscription_ids[s]
event_to_forward = ["EOSE", s_original]
del NostrRouter.received_subscription_eosenotices[s]
-
+
await self.websocket.send_text(json.dumps(event_to_forward))
except Exception as e:
logger.debug(e)
@@ -104,97 +97,62 @@ async def _handle_received_subscription_events(self, s):
try:
if s not in NostrRouter.received_subscription_events:
return
+
while len(NostrRouter.received_subscription_events[s]):
- my_event = NostrRouter.received_subscription_events[s].pop(0)
- # event.to_message() does not include the subscription ID, we have to add it manually
- event_json = {
- "id": my_event.id,
- "pubkey": my_event.public_key,
- "created_at": my_event.created_at,
- "kind": my_event.kind,
- "tags": my_event.tags,
- "content": my_event.content,
- "sig": my_event.signature,
- }
+ event_message = NostrRouter.received_subscription_events[s].pop(0)
+ event_json = event_message.event
# this reconstructs the original response from the relay
# reconstruct original subscription id
s_original = self.original_subscription_ids[s]
- event_to_forward = ["EVENT", s_original, event_json]
- await self.websocket.send_text(json.dumps(event_to_forward))
+ event_to_forward = f"""["EVENT", "{s_original}", {event_json}]"""
+ await self.websocket.send_text(event_to_forward)
except Exception as e:
- logger.debug(e)
+            logger.debug(e)  # NOTE: fires very often (~2900 seen) — likely send-after-close; investigate
def _handle_notices(self):
while len(NostrRouter.received_subscription_notices):
my_event = NostrRouter.received_subscription_notices.pop(0)
- # note: we don't send it to the user because we don't know who should receive it
- logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']")
- nostr.client.relay_manager.handle_notice(my_event)
-
-
-
- def _marshall_nostr_filters(self, data: Union[dict, list]):
- filters = data if isinstance(data, list) else [data]
- filters = [Filter.parse_obj(f) for f in filters]
- filter_list: list[NostrFilter] = []
- for filter in filters:
- filter_list.append(
- NostrFilter(
- event_ids=filter.ids, # type: ignore
- kinds=filter.kinds, # type: ignore
- authors=filter.authors, # type: ignore
- since=filter.since, # type: ignore
- until=filter.until, # type: ignore
- event_refs=filter.e, # type: ignore
- pubkey_refs=filter.p, # type: ignore
- limit=filter.limit, # type: ignore
- )
- )
- return NostrFilters(filter_list)
+        logger.info(f"[Relay '{my_event.url}'] Notice: '{my_event.content}'")
+ # Note: we don't send it to the user because
+ # we don't know who should receive it
+ nostr_client.relay_manager.handle_notice(my_event)
async def _handle_client_to_nostr(self, json_str):
- """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will
- register the subscription in the nostr client library that we're using so we can
- receive the callbacks on it later. Will rewrite the subscription id since we expect
- multiple clients to use the router and want to avoid subscription id collisions
- """
-
json_data = json.loads(json_str)
- assert len(json_data)
-
+ assert len(json_data), "Bad JSON array"
if json_data[0] == "REQ":
self._handle_client_req(json_data)
return
-
+
if json_data[0] == "CLOSE":
self._handle_client_close(json_data[1])
return
if json_data[0] == "EVENT":
- nostr.client.relay_manager.publish_message(json_str)
+ nostr_client.relay_manager.publish_message(json_str)
return
def _handle_client_req(self, json_data):
subscription_id = json_data[1]
subscription_id_rewritten = urlsafe_short_hash()
self.original_subscription_ids[subscription_id_rewritten] = subscription_id
- fltr = json_data[2:]
- filters = self._marshall_nostr_filters(fltr)
+ filters = json_data[2:]
- nostr.client.relay_manager.add_subscription(
- subscription_id_rewritten, filters
- )
- request_rewritten = json.dumps([json_data[0], subscription_id_rewritten] + fltr)
-
- self.subscriptions.append(subscription_id_rewritten)
- nostr.client.relay_manager.publish_message(request_rewritten)
+ nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters)
def _handle_client_close(self, subscription_id):
- subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None)
+ subscription_id_rewritten = next(
+ (
+ k
+ for k, v in self.original_subscription_ids.items()
+ if v == subscription_id
+ ),
+ None,
+ )
if subscription_id_rewritten:
self.original_subscription_ids.pop(subscription_id_rewritten)
- nostr.client.relay_manager.close_subscription(subscription_id_rewritten)
+ nostr_client.relay_manager.close_subscription(subscription_id_rewritten)
else:
logger.debug(f"Failed to unsubscribe from '{subscription_id}.'")
diff --git a/tasks.py b/tasks.py
index 4c316bc..69aa33c 100644
--- a/tasks.py
+++ b/tasks.py
@@ -3,75 +3,69 @@
from loguru import logger
-from . import nostr
+from . import nostr_client
from .crud import get_relays
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
-from .router import NostrRouter, nostr
+from .router import NostrRouter
async def init_relays():
- # reinitialize the entire client
- nostr.__init__()
# get relays from db
relays = await get_relays()
# set relays and connect to them
- nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url]))
- await nostr.client.connect()
+ valid_relays = list(set([r.url for r in relays if r.url]))
+
+ nostr_client.reconnect(valid_relays)
async def check_relays():
- """ Check relays that have been disconnected """
+ """Check relays that have been disconnected"""
while True:
try:
await asyncio.sleep(20)
- nostr.client.relay_manager.check_and_restart_relays()
+ nostr_client.relay_manager.check_and_restart_relays()
except Exception as e:
logger.warning(f"Cannot restart relays: '{str(e)}'.")
-
+
async def subscribe_events():
- while not any([r.connected for r in nostr.client.relay_manager.relays.values()]):
+ while not any([r.connected for r in nostr_client.relay_manager.relays.values()]):
await asyncio.sleep(2)
def callback_events(eventMessage: EventMessage):
- if eventMessage.subscription_id in NostrRouter.received_subscription_events:
- # do not add duplicate events (by event id)
- if eventMessage.event.id in set(
- [
- e.id
- for e in NostrRouter.received_subscription_events[eventMessage.subscription_id]
- ]
- ):
- return
-
- NostrRouter.received_subscription_events[eventMessage.subscription_id].append(
- eventMessage.event
- )
- else:
- NostrRouter.received_subscription_events[eventMessage.subscription_id] = [
- eventMessage.event
- ]
- return
+ sub_id = eventMessage.subscription_id
+ if sub_id not in NostrRouter.received_subscription_events:
+ NostrRouter.received_subscription_events[sub_id] = [eventMessage]
+ return
+
+ # do not add duplicate events (by event id)
+ ids = set(
+ [e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
+ )
+ if eventMessage.event_id in ids:
+ return
+
+ NostrRouter.received_subscription_events[sub_id].append(eventMessage)
def callback_notices(noticeMessage: NoticeMessage):
if noticeMessage not in NostrRouter.received_subscription_notices:
NostrRouter.received_subscription_notices.append(noticeMessage)
- return
def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
- if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices:
- NostrRouter.received_subscription_eosenotices[
- eventMessage.subscription_id
- ] = eventMessage
+ sub_id = eventMessage.subscription_id
+ if sub_id in NostrRouter.received_subscription_eosenotices:
+ return
- return
+ NostrRouter.received_subscription_eosenotices[sub_id] = eventMessage
def wrap_async_subscribe():
- asyncio.run(nostr.client.subscribe(
- callback_events,
- callback_notices,
- callback_eose_notices,
- ))
+ asyncio.run(
+ nostr_client.subscribe(
+ callback_events,
+ callback_notices,
+ callback_eose_notices,
+ )
+ )
t = threading.Thread(
target=wrap_async_subscribe,
diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html
index a0c5999..db0f98e 100644
--- a/templates/nostrclient/index.html
+++ b/templates/nostrclient/index.html
@@ -6,13 +6,30 @@
Nostrclient
Nostrclient
Nostrclient
Nostrclient
Sender Private Key:
Nostrclient
Nostrclient
Sender Public Key:
Nostrclient
Test Message:
Nostrclient
Receiver Public Key:
Nostrclient
Sent Data:
Nostrclient
Received Data:
Nostrclient Extension
-