diff --git a/examples/basic.py b/examples/basic.py index aa8f005..157f2dc 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -20,6 +20,24 @@ machine_id=generate_machine_id_from_key("myMachineKeyIdentifier"), context=[{"key": "scenario-version", "value": "1.0.0"}], message="test message to see where it is written", + decisions=[ + { + "origin": "crowdsec", + "duration": "1h", + "scenario": "crowdsec/ssh-bf", + "scope": "ip", + "type": "ban", + "value": "81.81.81.81", + }, + { + "origin": "pysdk", + "duration": "2h", + "scenario": "crowdsec/ssh-bf", + "scope": "ip", + "type": "ban", + "value": "81.81.81.81", + }, + ], ) ] diff --git a/src/cscapi/sql_storage.py b/src/cscapi/sql_storage.py index 66e662d..4bbf47d 100644 --- a/src/cscapi/sql_storage.py +++ b/src/cscapi/sql_storage.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import List +from typing import List, Optional from dacite import from_dict from sqlalchemy import ( @@ -12,6 +12,7 @@ create_engine, delete, update, + event, ) from sqlalchemy.orm import ( DeclarativeBase, @@ -21,9 +22,23 @@ sessionmaker, ) +from sqlalchemy.engine import Engine from cscapi import storage +""" +By default, foreign key constraints are disabled in SQLite. +@see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#foreign-key-support +""" + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + class Base(DeclarativeBase): def to_dict(self): return {c.name: getattr(self, c.name) for c in self.__table__.columns} @@ -33,7 +48,7 @@ class MachineDBModel(Base): __tablename__ = "machine_models" id = Column(Integer, primary_key=True, autoincrement=True) - machine_id = Column(TEXT) + machine_id = Column(TEXT, unique=True) token = Column(TEXT) password = Column(TEXT) scenarios = Column(TEXT) @@ -54,7 +69,7 @@ class DecisionDBModel(Base): type = Column(TEXT) value = Column(TEXT) signal_id: Mapped[int] = mapped_column( - "signal_id", ForeignKey("signal_models.alert_id") + "signal_id", ForeignKey("signal_models.alert_id", ondelete="CASCADE") ) @@ -71,6 +86,9 @@ class SourceDBModel(Base): value = Column(TEXT) as_name = Column(TEXT) longitude = Column(Float) + signal_id = Column( + Integer, ForeignKey("signal_models.alert_id", ondelete="CASCADE") + ) class ContextDBModel(Base): @@ -80,7 +98,7 @@ class ContextDBModel(Base): value = Column(TEXT) key = Column(TEXT) signal_id: Mapped[int] = mapped_column( - "signal_id", ForeignKey("signal_models.alert_id") + "signal_id", ForeignKey("signal_models.alert_id", ondelete="CASCADE") ) @@ -100,8 +118,6 @@ class SignalDBModel(Base): stop_at = Column(TEXT, nullable=True) sent = Column(Boolean, default=False) - source_id = Column(Integer, ForeignKey("source_models.id"), nullable=True) - context: Mapped[List["ContextDBModel"]] = relationship( "ContextDBModel", backref="signal" ) @@ -133,29 +149,29 @@ def get_all_signals(self) -> List[storage.SignalModel]: for res in self.session.query(SignalDBModel).all() ] - def get_machine_by_id(self, machine_id: str) -> storage.MachineModel: - exisiting = ( + def get_machine_by_id(self, machine_id: str) -> Optional[storage.MachineModel]: + existing = ( self.session.query(MachineDBModel) .filter(MachineDBModel.machine_id == machine_id) .first() ) - if not exisiting: - return + if not existing: + return None return storage.MachineModel( - machine_id=exisiting.machine_id, - token=exisiting.token, - password=exisiting.password, - scenarios=exisiting.scenarios, - is_failing=exisiting.is_failing, + machine_id=existing.machine_id, + token=existing.token, + password=existing.password, + scenarios=existing.scenarios, + is_failing=existing.is_failing, ) def update_or_create_machine(self, machine: storage.MachineModel) -> bool: - exisiting = ( + existing = ( self.session.query(MachineDBModel) .filter(MachineDBModel.machine_id == machine.machine_id) .all() ) - if not exisiting: + if not existing: self.session.add(MachineDBModel(**asdict(machine))) self.session.commit() return True @@ -193,19 +209,19 @@ def update_or_create_signal(self, signal: storage.SignalModel) -> bool: for dec in signal.decisions ] - exisiting = ( + existing = ( self.session.query(SignalDBModel) .filter(SignalDBModel.alert_id == signal.alert_id) .first() ) - if not exisiting: + if not existing: self.session.add(to_insert) self.session.commit() return True for c in to_insert.__table__.columns: - setattr(exisiting, c.name, getattr(to_insert, c.name)) + setattr(existing, c.name, getattr(to_insert, c.name)) self.session.commit() return False @@ -215,9 +231,11 @@ def delete_signals(self, signals: List[storage.SignalModel]): SignalDBModel.alert_id.in_((signal.alert_id for signal in signals)) ) self.session.execute(stmt) + self.session.commit() def delete_machines(self, machines: List[storage.MachineModel]): stmt = delete(MachineDBModel).where( - MachineDBModel.machine_id in ([machine.machine_id for machine in machines]) + MachineDBModel.machine_id.in_((machine.machine_id for machine in machines)) ) self.session.execute(stmt) + self.session.commit()