diff --git a/sample.env b/.env.example similarity index 52% rename from sample.env rename to .env.example index 43136b0..5f2dd68 100644 --- a/sample.env +++ b/.env.example @@ -2,6 +2,9 @@ HOST=misskey.example.com SECRET_TOKEN=misskey_token CONFIG_DIR=./config +DB_TYPE=redis +DB_URL=redis://localhost:6379 + RUN_SERVER=false -DB_TYPE=pickle -DB_URL=ssss +SERVER_HOST=0.0.0.0 +SERVER_PORT=8000 diff --git a/.gitignore b/.gitignore index 7e5f432..f577d01 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,17 @@ +# venv env/ + +# cache __pycache__ +.pytest_cache +# private *.env -.pytest_cache -!sample.env data/users.pickle +ngWords.txt +# vscode +.vscode/* !.vscode/settings.json -.vscode/ -ngWords.txt +*.rdb diff --git a/.vscode/settings.json b/.vscode/settings.json index a5c31cd..98af5f9 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,6 +1,11 @@ { "cSpell.words": [ - "dotenv" + "coloredlogs", + "dotenv", + "levelname", + "misskey", + "renote", + "websockets" ], "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..81021ad --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +pythonpath = "src" +testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt index d697330..c36f3ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,39 @@ +annotated-types==0.7.0 blinker==1.8.2 certifi==2024.6.2 +cffi==1.16.0 charset-normalizer==3.3.2 click==8.1.7 colorama==0.4.6 coloredlogs==15.0.1 +cryptography==42.0.8 Flask==3.0.3 +hiredis==2.3.2 humanfriendly==10.0 idna==3.7 iniconfig==2.0.0 itsdangerous==2.2.0 Jinja2==3.1.4 MarkupSafe==2.1.5 +more-itertools==10.3.0 packaging==24.0 pluggy==1.5.0 +pycparser==2.22 +pydantic==2.7.3 +pydantic-settings==2.3.1 +pydantic_core==2.18.4 pyreadline3==3.4.1 pytest==8.2.2 python-dotenv==1.0.1 +python-ulid==1.1.0 redis==5.0.5 +redis-om==0.3.1 requests==2.32.3 +types-cffi==1.16.0.20240331 +types-pyOpenSSL==24.1.0.20240425 +types-redis==4.6.0.20240425 +types-setuptools==70.0.0.20240524 +typing_extensions==4.12.2 urllib3==2.2.1 -websocket-client==1.8.0 +websockets==12.0 Werkzeug==3.0.3 diff --git a/src/emojis.py b/src/emojis.py index 75fac96..5e2f560 100644 --- a/src/emojis.py +++ b/src/emojis.py @@ -1,6 +1,6 @@ -import json import random -from typing import Any + +from utils import load_from_json_path class ConfigJsonError(Exception): @@ -9,22 +9,22 @@ class ConfigJsonError(Exception): class EmojiSet: def __init__(self, data: str | dict) -> None: - if isinstance(data, str): - with open(data) as f: - loaded = json.load(f) - else: - loaded = data + loaded = load_from_json_path(data, dict) self._check_format(loaded) self.response_emojis = loaded["triggers"] self.others = loaded["others"] - def _check_format(self, json: Any) -> None: + def _check_format(self, json: dict) -> None: if not isinstance(json, dict) or sorted(json.keys()) != ["others", "triggers"]: - raise ConfigJsonError("response.jsonは{'triggers': [], 'others': []}の形にしてください。") - - if any([tuple(i.keys()) != ("keywords", "emoji") for i in json["triggers"]]): - raise ConfigJsonError("response.jsonのトリガーのキーはkeywordsとemojiにしてください。") + raise ConfigJsonError( + "response.jsonは{'triggers': [], 'others': []}の形にしてください。" + ) + + if any([sorted(i.keys()) != ["emoji", "keywords"] for i in json["triggers"]]): + raise ConfigJsonError( + "response.jsonのトリガーのキーはkeywordsとemojiにしてください。" + ) def get_response_emoji(self, text: str) -> str: for i in self.response_emojis: diff --git a/src/environs.py b/src/environs.py new file mode 100644 index 0000000..84b31e5 --- /dev/null +++ b/src/environs.py @@ -0,0 +1,33 @@ +from ipaddress import IPv4Address +from pathlib import Path +from typing import Literal, Optional + +from pydantic import DirectoryPath, Field, model_validator +from pydantic.networks import IPvAnyAddress, RedisDsn +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class DotenvSettings(BaseSettings): + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") + + +class Settings(DotenvSettings): + host: str = Field(default=...) + secret_token: str = Field(default=...) + + db_type: Literal["redis", "pickle"] = "redis" + db_url: Optional[RedisDsn] = None + + config_dir: DirectoryPath = Path("./config") + run_server: bool = False + server_host: IPvAnyAddress = IPv4Address("0.0.0.0") # type: ignore + server_port: int = 8000 + + @model_validator(mode="after") + def validate_environ(self): + if self.db_url is None and self.db_type == "redis": + raise ValueError("DB_URL is required when DB_TYPE is redis") + + if not (self.server_port > 0 and self.server_port < 65536): + raise ValueError("SERVER_PORT must be between 1 and 65535") + return self diff --git a/src/keep_alive.py b/src/keep_alive.py index 6b2f113..92ce11f 100644 --- a/src/keep_alive.py +++ b/src/keep_alive.py @@ -1,20 +1,15 @@ -import logging +import asyncio from multiprocessing import Process -from os import getenv -import coloredlogs -from dotenv import load_dotenv from flask import Flask, jsonify +import environs import logging_styles import mainbot - app = Flask("app") -logger = logging.getLogger(__name__) -logging_styles.set_default() -coloredlogs.install(logger=logger) +logger = logging_styles.getLogger(__name__) @app.get("/") @@ -22,13 +17,13 @@ def pong(): return jsonify({"message": "Pong!"}) -def run_server(): - app.run(host="0.0.0.0", port=8080) +def run_server(host: str, port: int): + app.run(host=host, port=port) if __name__ == "__main__": - load_dotenv() - if getenv("RUN_SERVER", False): - Process(target=run_server).start() + config = environs.Settings() + if config.run_server: + Process(target=run_server, args=(str(config.server_host), config.server_port)).start() logger.info("Web server started!") - mainbot.Bot().start_bot() + asyncio.run(mainbot.Bot(config).start_bot()) diff --git a/src/logging_styles.py b/src/logging_styles.py index 4ff1ca3..93ab711 100644 --- a/src/logging_styles.py +++ b/src/logging_styles.py @@ -19,10 +19,17 @@ "critical": {"color": "red"}, } +formatter = coloredlogs.ColoredFormatter( + fmt=DEFAULT_LOG_FORMAT, + datefmt=DEFAULT_DATE_FORMAT, + field_styles=DEFAULT_FIELD_STYLES, + level_styles=DEFAULT_LEVEL_STYLES, +) -def set_default(): - coloredlogs.DEFAULT_LOG_LEVEL = DEFAULT_LOG_LEVEL - coloredlogs.DEFAULT_LOG_FORMAT = DEFAULT_LOG_FORMAT - coloredlogs.DEFAULT_DATE_FORMAT = DEFAULT_DATE_FORMAT - coloredlogs.DEFAULT_FIELD_STYLES = DEFAULT_FIELD_STYLES - coloredlogs.DEFAULT_LEVEL_STYLES = DEFAULT_LEVEL_STYLES + +def getLogger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(DEFAULT_LOG_LEVEL) + logger.addHandler(logging.StreamHandler()) + logger.handlers[0].setFormatter(formatter) + return logger diff --git a/src/mainbot.py b/src/mainbot.py index 98ed5c5..4e506e2 100644 --- a/src/mainbot.py +++ b/src/mainbot.py @@ -1,14 +1,14 @@ +import asyncio import json -import logging -import os from threading import Thread -import coloredlogs -import websocket +import websockets import utils import logging_styles import misskey_api as misskey +from environs import Settings +from userdb import UserDB from ngwords import NGWords from emojis import EmojiSet @@ -16,21 +16,19 @@ class Bot: counter = utils.Counter(100, lambda: None) - def __init__(self, restart: bool = True) -> None: - logger = logging.getLogger(__name__) - logging_styles.set_default() - coloredlogs.install(logger=logger) - self.logger = logger + def __init__(self, settings: Settings, restart: bool = True) -> None: + self.logger = logging_styles.getLogger(__name__) + self.config = settings self._restart = restart - self.config_dir = utils.config_dir() + self.config_dir = self.config.config_dir - logger.info("Loading response.json...") - self.emojis = EmojiSet(os.path.join(self.config_dir, "response.json")) - logger.info("Loading ngwords.txt...") - self.ngw = NGWords(os.path.join(self.config_dir, "ngwords.txt")) + self.logger.info("Loading response.json...") + self.emojis = EmojiSet(str(self.config_dir.joinpath("response.json"))) + self.logger.info("Loading ngwords.txt...") + self.ngw = NGWords(str(self.config_dir.joinpath("ngwords.txt"))) - self.db = utils.get_db() + self.db = UserDB(str(self.config.db_url)) # TODO: redis以外への対応 # TODO: なんか良い名前に変えたい def send_welcome(self, note_id: str, note_text: str) -> None: @@ -50,10 +48,11 @@ def need(self) -> bool: return True return False - def on_message(self, ws, message: str) -> None: + async def on_message(self, ws, message: str) -> None: note_body = json.loads(message)["body"]["body"] note_id = note_body["id"] note_text = note_body["text"] + user_id = note_body["userId"] if note_text is None: note_text = "" @@ -64,13 +63,15 @@ def on_message(self, ws, message: str) -> None: # Renote不可ならreturn return_flg = True if self.ngw.match(note_text): - self.logger.info(f"Detected NG word. | noteId: {note_id}, \ - word: {self.ngw.why(note_text)}") + self.logger.info( + f"Detected NG word. | noteId: {note_id}, \ + word: {self.ngw.why(note_text)}" + ) elif misskey.can_reply(note_body): Thread(target=misskey.reply, args=(note_id, "Pong!")).start() elif not misskey.can_renote(note_body): pass - elif note_body["userId"] in set(self.db): + elif await self.db.get_user_by_id(user_id): self.logger.debug("Skiped api request because it was registered in DB.") else: return_flg = False @@ -81,40 +82,53 @@ def on_message(self, ws, message: str) -> None: self.logger.debug( f"Notes not registered in database. | body: {note_text} , id: {note_id}" ) - user_info = misskey.get_user_info(user_id=note_body["userId"]) + user_info = misskey.get_user_info(user_id=user_id) if (notes_count := user_info["notesCount"]) == 1: self.send_welcome(note_id, note_text) elif notes_count <= 10: # ノート数が10以下ならRenote出来る可能性 - notes = misskey.get_user_notes(note_body["userId"], note_id, 10) + notes = misskey.get_user_notes(user_id, note_id, 10) if all([not misskey.can_renote(note) for note in notes]): self.send_welcome(note_id, note_text) return None if notes_count > 5: - self.db.append(note_body["userId"]) - if (count := len(self.db)) % 100 == 0: - utils.update_db("have_note_user_ids", self.db, False) - self.logger.info(f"DataBase Updated. | length: {count}") + await self.db.add_user(user_id, note_body["user"]["username"]) + self.logger.info("DataBase Updated.") - def on_error(self, ws, error) -> None: + async def on_error(self, ws, error) -> None: self.logger.warning(str(error)) - def on_close(self, ws, status_code, msg) -> None: + async def on_close(self, ws, status_code, msg) -> bool: self.logger.error(f"WebSocket closed. | code:{status_code} msg: {msg}") - if self._restart: - self.start_bot() + return self._restart - def start_bot(self): + async def start_bot(self): streaming_api = f"wss://{misskey.HOST}/streaming?i={misskey.TOKEN}" USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36" # NOQA - MESSAGE = {"type": "connect", "body": {"channel": "hybridTimeline", "id": "1"}} - # WebSocketの接続 - ws = websocket.WebSocketApp( - streaming_api, - on_message=self.on_message, on_error=self.on_error, on_close=self.on_close, - header={"User-Agent": USER_AGENT} - ) - ws.on_open = lambda ws: ws.send(json.dumps(MESSAGE)) - self.logger.info("Bot was started!") - ws.run_forever() + CONNECTMSG = { + "type": "connect", + "body": {"channel": "hybridTimeline", "id": "1"}, + } + + pong = await self.db.ping() + if not pong: + raise Exception("DB connection failed.") + + while True: + async with websockets.connect(streaming_api, user_agent_header=USER_AGENT) as ws: + # self.on_open(ws) + self.logger.info("Bot was started!") + await ws.send(json.dumps(CONNECTMSG)) + while True: + try: + msg = await ws.recv() + await self.on_message(ws, str(msg)) + except websockets.ConnectionClosed: + await self.on_close(ws, ws.close_code, ws.close_reason) + if not self._restart: + return + break + except Exception as e: + await self.on_error(ws, e) + await asyncio.sleep(5) diff --git a/src/misskey_api.py b/src/misskey_api.py index e676a66..1f64c70 100644 --- a/src/misskey_api.py +++ b/src/misskey_api.py @@ -1,10 +1,6 @@ -import pickle -import logging from os import getenv -from collections import deque import requests -import coloredlogs from dotenv import load_dotenv from requests import Timeout @@ -17,9 +13,7 @@ TOKEN = getenv("SECRET_TOKEN") USERNAME = requests.post(f"https://{HOST}/api/i", json={"i": TOKEN}).json()["username"] -logger = logging.getLogger(__name__) -logging_styles.set_default() -coloredlogs.install(logger=logger) +logger = logging_styles.getLogger(__name__) limiter = RateLimiter(0.5) limiter2 = RateLimiter(0.5) diff --git a/src/ngwords.py b/src/ngwords.py index b23e5b2..7fd4647 100644 --- a/src/ngwords.py +++ b/src/ngwords.py @@ -1,4 +1,6 @@ -import pathlib +import os + +from utils import load_from_path class NGWords: @@ -7,11 +9,11 @@ class NGWords: initにngワードのテキストファイルのパスを渡してください。 """ - def __init__(self, path) -> None: - self._path = pathlib.Path(path) + def __init__(self, path: str | os.PathLike) -> None: + self.raw = load_from_path(path) self._load() - def __getitem__(self, key) -> dict: + def __getitem__(self, key) -> set: if (key := key.lower()) == "ng": return self._ng elif key == "excluded": @@ -20,11 +22,18 @@ def __getitem__(self, key) -> dict: raise KeyError('"ng"か"excluded"を指定してください') def _load(self) -> None: - with self._path.open() as f: - data = f.read().split("\n") - data = [j.lower() for j in data if j != ""] - self._ng = {j for j in data if (j[0] != "-") and (j[0] != "#")} - self._allow = {j[1:] for j in data if j[0] == "-"} + data = self.raw.split("\n") + data = [i.lower() for i in data if i != ""] + + ng = set() + allow = set() + for i in data: + if i[0] == "-": + allow.add(i[1:].lstrip(" ")) + elif i[0] != "#": + ng.add(i) + self._ng = ng + self._allow = allow def match(self, text) -> bool: text = text.lower() @@ -45,8 +54,3 @@ def all_ng_words(self) -> set: @property def all_excluded_words(self) -> set: return self._allow - - -if __name__ == "__main__": - print(NGWords(r"ng_words/ngWords.txt").match("r-18")) - print(NGWords(r"ngWords_Hiraassssssss.txt").all_ng_words) # FileNotFound diff --git a/src/userdb.py b/src/userdb.py new file mode 100644 index 0000000..b2706c9 --- /dev/null +++ b/src/userdb.py @@ -0,0 +1,68 @@ +import asyncio +from datetime import datetime + +from pydantic import PastDatetime +from aredis_om import ( + HashModel, + Migrator, + get_redis_connection, + Field, + NotFoundError, +) + + +class UserInfo(HashModel): + user_id: str = Field(primary_key=True) + user_name: str = Field(index=True) + last_received_date: PastDatetime + + class Meta: + global_key_prefix = "MisskeyWelcomeBot" + model_key_prefix = "PostedUsers" + + +class UserDB: + def __init__(self, redis_url: str) -> None: + self._db_url = redis_url + UserInfo.Meta.database = get_redis_connection(url=self._db_url) # type: ignore + + async def ping(self): + return await UserInfo.db().ping() + + async def get_all_users(self) -> list[UserInfo]: + all_pks = await UserInfo.all_pks() + return await asyncio.gather(*[UserInfo.get(i) async for i in all_pks]) + + async def get_user_by_id(self, id: str) -> UserInfo | None: + try: + return await UserInfo.get(id) + except NotFoundError: + return None + + async def get_user_by_name(self, name: str) -> UserInfo | None: + await self._migrate() + found = UserInfo.find(UserInfo.user_name == name) + try: + return await found.first() # type: ignore + except NotFoundError: + return None + + async def add_user(self, user_id: str, username: str) -> None: + await UserInfo( + user_id=user_id, + user_name=username, + last_received_date=datetime.now() + ).save() + + async def _migrate(self) -> None: + await Migrator().run() + + +if __name__ == "__main__": + import asyncio, environs + + db_url = environs.Settings().db_url + if db_url is None: + exit(print("db_type is not redis")) + db = UserDB(str(db_url)) + print(asyncio.run(db.get_user_by_name("ffdi"))) diff --git a/src/utils.py b/src/utils.py index 9b6f5f1..e6e2810 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,10 +1,11 @@ import time -import os -import pickle -from typing import Any -from collections import deque +from typing import Any, TypeVar, Type -import dotenv +import json +from os import PathLike + + +T = TypeVar("T") class RateLimiter: @@ -49,48 +50,17 @@ def wrapper(*args, **kwargs): return wrapper -def config_dir(): - dotenv.load_dotenv() - dir = os.getenv("CONFIG_DIR", "./config") - if not os.path.exists(dir): - raise FileNotFoundError("Config directory not found.") - return dir - - -def db_type(): - dotenv.load_dotenv() - return os.getenv("DB_TYPE") - - -def update_db(key: str, value, allow_duplicates: bool = True) -> None: - if not allow_duplicates: - value = set(value) - - if db_type() == "redis": - import redis - dotenv.load_dotenv() - - r = redis.from_url(os.getenv("DB_URL")) - p = r.pipeline() - for i in value: - p.sadd("have_note_user_ids", i) - p.execute() - elif db_type() == "pickle": - with open("./data/users.pickle", "wb") as f: - pickle.dump(deque(value), f) +def load_from_path(path: str | PathLike | T, extend: Type[T | None] = type(None)) -> str | T: + if not isinstance(path, (str, PathLike, extend)): + raise TypeError(f"Invalid type for path: {type(path)}. Expected str, PathLike, or {extend.__name__}.") + if extend is not None and isinstance(path, extend): + return path + with open(path, "r", encoding="utf-8") as f: + return f.read() -def get_db(): - if db_type() == "redis": - import redis - dotenv.load_dotenv() - r = redis.from_url(os.getenv("DB_URL")) - return deque(map(lambda x: x.decode(), r.smembers("have_note_user_ids"))) - elif db_type() == "pickle": - try: - with open('./data/users.pickle', "rb") as f: - have_note_user_ids = pickle.load(f) - except FileNotFoundError: - have_note_user_ids = deque() - return have_note_user_ids +def load_from_json_path(path: str | PathLike | T, extend: Type[T | None] = type(None)) -> dict | T: + if isinstance(path, extend): + return path + return json.loads(load_from_path(path)) diff --git a/tests/test_emoji.py b/tests/test_emoji.py index 7636060..2e9b1df 100644 --- a/tests/test_emoji.py +++ b/tests/test_emoji.py @@ -1,20 +1,18 @@ import pytest -from src import emojis +import emojis +inputs = [ + ["aaa", "[Errno 2] No such file or directory: 'aaa'", FileNotFoundError], + [[], "Invalid type for path: . Expected str, PathLike, or NoneType.", TypeError], + [{"a": "b"}, "response.jsonは{'triggers': [], 'others': []}の形にしてください。", emojis.ConfigJsonError], + [{"triggers": [{}], "others": []}, "response.jsonのトリガーのキーはkeywordsとemojiにしてください。", emojis.ConfigJsonError] +] -messages = ["response.jsonは{'triggers': [], 'others': []}の形にしてください。", - "response.jsonのトリガーのキーはkeywordsとemojiにしてください。", - "[Errno 2] No such file or directory: 'aaaaaaaa'"] - -@pytest.mark.parametrize("input,msg", - (["aaaaaaaa", messages[2]], - [[], messages[0]], - [{"d": "a"}, messages[0]], - [{"triggers": [{}], "others": []}, messages[1]])) -def test_emojiset_error(input, msg): - error_cls = FileNotFoundError if isinstance(input, str) else emojis.ConfigJsonError - with pytest.raises(error_cls) as e: +@pytest.mark.parametrize("input,msg,exception", inputs) +def test_emojiset_error(input, msg, exception): + with pytest.raises(exception) as e: emojis.EmojiSet(input) + print(msg) assert str(e.value) == msg diff --git a/tests/test_misskey.py b/tests/test_misskey.py index 46abd3c..af9368c 100644 --- a/tests/test_misskey.py +++ b/tests/test_misskey.py @@ -1,8 +1,6 @@ import json -import sys -sys.path.append("./src") -from src import misskey_api as misskey # NOQA +import misskey_api as misskey # NOQA def test_can_renote(): diff --git a/tests/test_ngwords.py b/tests/test_ngwords.py index 79d43e5..4f2d801 100644 --- a/tests/test_ngwords.py +++ b/tests/test_ngwords.py @@ -1,16 +1,15 @@ -from src import ngwords - +import ngwords ng = {"r-18", "荒らす", "twitter", "ワード"} allow = {"除外ワード", "除外ワード2"} def test_ngwords(): - _ng = ngwords.NGWords("tests/test_ngwords.txt") - assert _ng.all_ng_words == ng - assert _ng.all_excluded_words == allow + words = ngwords.NGWords("tests/test_ngwords.txt") + assert words.all_ng_words == ng + assert words.all_excluded_words == allow for word in ng: - assert _ng.match(word) + assert words.match(word) for word in allow: - assert not _ng.match(word) - assert _ng.why("くぁwせdR-18rftgyふじこ") == "r-18" + assert not words.match(word) + assert words.why("くぁwせdR-18rftgyふじこ") == "r-18"