diff --git a/backend/lib/config/getters.py b/backend/lib/config/getters.py index 9b4f0c5e..7d487a04 100644 --- a/backend/lib/config/getters.py +++ b/backend/lib/config/getters.py @@ -1,32 +1,19 @@ import os from lib import storage +from . import models -def get_web_credentials() -> dict: - return { - 'username': os.environ['ADMIN_USERNAME'], - 'password': os.environ['ADMIN_PASSWORD'], - } +def get_web_credentials() -> models.WebCredentials: + return models.WebCredentials() -def get_redis_config() -> dict: - return { - 'host': os.environ['REDIS_HOST'], - 'port': os.environ['REDIS_PORT'], - 'password': os.environ['REDIS_PASSWORD'], - 'db': 0, - } +def get_redis_config() -> models.Redis: + return models.Redis() -def get_db_config() -> dict: - return { - 'host': os.environ['POSTGRES_HOST'], - 'port': os.environ['POSTGRES_PORT'], - 'user': os.environ['POSTGRES_USER'], - 'password': os.environ['POSTGRES_PASSWORD'], - 'dbname': os.environ['POSTGRES_DB'], - } +def get_db_config() -> models.Database: + return models.Database() def get_broker_url() -> str: @@ -41,29 +28,11 @@ def get_broker_url() -> str: return broker_url -def get_celery_config() -> dict: - game_config = storage.game.get_current_game_config() - - host = os.environ['REDIS_HOST'] - port = os.environ['REDIS_PORT'] - password = os.environ['REDIS_PASSWORD'] - db = 1 - - result_backend = f'redis://:{password}@{host}:{port}/{db}' - - broker_url = get_broker_url() - - conf = { - 'accept_content': ['pickle'], - 'broker_url': broker_url, - 'result_backend': result_backend, - 'result_serializer': 'pickle', - 'task_serializer': 'pickle', - 'timezone': game_config.timezone, - 'worker_prefetch_multiplier': 1, - 'redis_socket_timeout': 10, - 'redis_socket_keepalive': True, - 'redis_retry_on_timeout': True, - } - - return conf +def get_celery_config() -> models.Celery: + redis_config = get_redis_config() + redis_config.db = 1 + return models.Celery( + broker_url=get_broker_url(), + result_backend=redis_config.url, + timezone=storage.game.get_current_game_config().timezone, + ) diff --git a/backend/lib/config/models.py b/backend/lib/config/models.py new file mode 100644 index 00000000..bde0384f --- /dev/null +++ b/backend/lib/config/models.py @@ -0,0 +1,48 @@ +import os +from typing import List + +from pydantic import BaseModel, Field + + +def env_field(key: str) -> Field: + return Field(default_factory=lambda: os.environ[key]) + + +class Redis(BaseModel): + host: str = env_field('REDIS_HOST') + port: int = env_field('REDIS_PORT') + password: str = env_field('REDIS_PASSWORD') + db: int = 0 + + @property + def url(self) -> str: + return f'redis://:{self.password}@{self.host}:{self.port}/{self.db}' + + +class WebCredentials(BaseModel): + username: str = env_field('ADMIN_USERNAME') + password: str = env_field('ADMIN_PASSWORD') + + +class Database(BaseModel): + host: str = env_field('POSTGRES_HOST') + port: int = env_field('POSTGRES_PORT') + user: str = env_field('POSTGRES_USER') + password: str = env_field('POSTGRES_PASSWORD') + dbname: str = env_field('POSTGRES_DB') + + +class Celery(BaseModel): + broker_url: str + result_backend: str + timezone: str + + worker_prefetch_multiplier: int = 1 + + redis_socket_timeout: int = 10 + redis_socket_keepalive: bool = True + redis_retry_on_timeout: bool = True + + accept_content: List[str] = ['pickle'] + result_serializer: str = 'pickle' + task_serializer: str = 'pickle' diff --git a/backend/lib/storage/attacks.py b/backend/lib/storage/attacks.py index 01f8bebf..0de1a91b 100644 --- a/backend/lib/storage/attacks.py +++ b/backend/lib/storage/attacks.py @@ -1,7 +1,7 @@ from lib import models, storage from lib.helpers import exceptions from lib.helpers.exceptions import FlagExceptionEnum -from lib.storage import utils +from lib.storage import utils, game from lib.storage.keys import CacheKeys @@ -40,14 +40,28 @@ def handle_attack( ) if flag is None: raise FlagExceptionEnum.FLAG_INVALID + if flag.team_id == attacker_id: + raise FlagExceptionEnum.FLAG_YOUR_OWN + + game_config = game.get_current_game_config() + if current_round - flag.round > game_config.flag_lifetime: + raise FlagExceptionEnum.FLAG_TOO_OLD result.victim_id = flag.team_id result.task_id = flag.task_id - storage.flags.try_add_stolen_flag( + success = storage.flags.try_add_stolen_flag( flag=flag, attacker=attacker_id, current_round=current_round, ) + if not success: + raise FlagExceptionEnum.FLAG_YOUR_OWN + + except exceptions.FlagSubmitException as e: + result.submit_ok = False + result.message = str(e) + + else: result.submit_ok = True with utils.db_cursor() as (conn, curs): @@ -67,7 +81,4 @@ def handle_attack( result.victim_delta = victim_delta result.message = f'Flag accepted! Earned {attacker_delta} flag points!' - except exceptions.FlagSubmitException as e: - result.message = str(e) - return result diff --git a/backend/lib/storage/flags.py b/backend/lib/storage/flags.py index 5cd25091..92428a78 100644 --- a/backend/lib/storage/flags.py +++ b/backend/lib/storage/flags.py @@ -3,7 +3,6 @@ from lib import models from lib.helpers.cache import cache_helper -from lib.helpers.exceptions import FlagExceptionEnum from lib.storage import caching, game, utils from lib.storage.keys import CacheKeys @@ -21,7 +20,7 @@ """ -def try_add_stolen_flag(flag: models.Flag, attacker: int, current_round: int) -> None: +def try_add_stolen_flag(flag: models.Flag, attacker: int, current_round: int) -> bool: """ Flag validation function. @@ -31,15 +30,7 @@ def try_add_stolen_flag(flag: models.Flag, attacker: int, current_round: int) -> :param flag: Flag model instance :param attacker: attacker team id :param current_round: current round - - :raises FlagSubmitException: on validation error """ - game_config = game.get_current_game_config() - if current_round - flag.round > game_config.flag_lifetime: - raise FlagExceptionEnum.FLAG_TOO_OLD - if flag.team_id == attacker: - raise FlagExceptionEnum.FLAG_YOUR_OWN - stolen_key = CacheKeys.team_stolen_flags(attacker) with utils.redis_pipeline(transaction=True) as pipe: # optimization of redis request count @@ -54,9 +45,7 @@ def try_add_stolen_flag(flag: models.Flag, attacker: int, current_round: int) -> ) is_new, = pipe.sadd(stolen_key, flag.id).execute() - - if not is_new: - raise FlagExceptionEnum.FLAG_ALREADY_STOLEN + return bool(is_new) def add_flag(flag: models.Flag) -> models.Flag: diff --git a/backend/lib/storage/utils.py b/backend/lib/storage/utils.py index a203bece..0a18c545 100644 --- a/backend/lib/storage/utils.py +++ b/backend/lib/storage/utils.py @@ -17,7 +17,7 @@ def create() -> pool.SimpleConnectionPool: return pool.SimpleConnectionPool( minconn=5, maxconn=20, - **database_config, + **database_config.dict(), ) @@ -26,8 +26,7 @@ class RedisStorage(Singleton[redis.Redis]): @staticmethod def create() -> redis.Redis: redis_config = config.get_redis_config() - redis_config['decode_responses'] = True - return redis.Redis(**redis_config) + return redis.Redis(decode_responses=True, **redis_config.dict()) class SIOManager(Singleton[socketio.KombuManager]): diff --git a/backend/requirements.txt b/backend/requirements.txt index a6db009a..2aceb048 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,6 +10,7 @@ hiredis==1.1.0 librabbitmq==2.0.0 kombu==5.0.2 prometheus-client==0.8.0 +pydantic==1.7.3 python-dateutil==2.8.1 python-socketio==5.0.4 pytz==2021.1 diff --git a/backend/services/admin/viewsets/authentication.py b/backend/services/admin/viewsets/authentication.py index d25a4946..ccca4eab 100644 --- a/backend/services/admin/viewsets/authentication.py +++ b/backend/services/admin/viewsets/authentication.py @@ -16,7 +16,7 @@ def check_session(): creds = config.get_web_credentials() - if data != creds['username']: + if data != creds.username: abort_with_error('Invalid session', 403) return True @@ -32,7 +32,7 @@ def login(): password = request.json.get('password') creds = config.get_web_credentials() - if username != creds['username'] or password != creds['password']: + if username != creds.username or password != creds.password: abort_with_error('Invalid credentials', 403) session = secrets.token_hex(32) diff --git a/backend/services/tasks/celery_factory.py b/backend/services/tasks/celery_factory.py index 804bcc63..d38b9297 100644 --- a/backend/services/tasks/celery_factory.py +++ b/backend/services/tasks/celery_factory.py @@ -13,5 +13,5 @@ def get_celery_app(): ], ) - app.conf.update(celery_config) + app.conf.update(celery_config.dict()) return app diff --git a/tests/test_flags.py b/tests/test_flags.py index bd384df1..d2196413 100644 --- a/tests/test_flags.py +++ b/tests/test_flags.py @@ -34,11 +34,11 @@ def setUp(self) -> None: self.unreachable_token = token database_config = config.get_db_config() - database_config['host'] = '127.0.0.1' + database_config.host = '127.0.0.1' self.db_pool = pool.SimpleConnectionPool( minconn=1, maxconn=20, - **database_config, + **database_config.dict(), ) def get_last_flags_from_db(self, team_token):