Skip to content

Commit

Permalink
Merge branch 'main' into feat/logger-config
Browse files Browse the repository at this point in the history
  • Loading branch information
julienloizelet committed Feb 16, 2024
2 parents 14c154b + b2ce126 commit adfe86d
Show file tree
Hide file tree
Showing 9 changed files with 210 additions and 67 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ functions provided by the `src/cscapi` folder.

---

## [0.3.0](https://github.com/crowdsecurity/python-capi-sdk/releases/tag/v0.3.0) - 2024-02-15
[_Compare with previous release_](https://github.com/crowdsecurity/python-capi-sdk/compare/v0.2.1...v0.3.0)


### Changed

- Use context manager for Sql session ([#20](https://github.com/crowdsecurity/python-capi-sdk/pull/20))
- **Breaking change**: The `session` attribute of `SQLStorage` is now an instance of the [sessionmaker](https://docs.sqlalchemy.org/en/20/orm/session_api.html#sqlalchemy.orm.sessionmaker) class and should be used as such.

---



## [0.2.1](https://github.com/crowdsecurity/python-capi-sdk/releases/tag/v0.2.1) - 2024-02-09
[_Compare with previous release_](https://github.com/crowdsecurity/python-capi-sdk/compare/v0.2.0...v0.2.1)

Expand Down
137 changes: 137 additions & 0 deletions examples/shell_scripts/add_signal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
This script will add a simple signal in database.
"""

import argparse
import json
import sys
from cscapi.client import CAPIClient, CAPIClientConfig
from cscapi.sql_storage import SQLStorage
from cscapi.utils import create_signal
from cscapi.utils import generate_machine_id_from_key


class CustomHelpFormatter(argparse.HelpFormatter):
def __init__(self, prog, indent_increment=2, max_help_position=48, width=None):
super().__init__(prog, indent_increment, max_help_position, width)


parser = argparse.ArgumentParser(
description="Script to add a simple signal.",
formatter_class=CustomHelpFormatter,
)

try:
parser.add_argument("--prod", action="store_true", help="Use production mode")
parser.add_argument(
"--human_machine_id",
type=str,
help="Human readable machine identifier. Will be converted in CrowdSec ID. Example: 'myMachineId'",
required=True,
)
parser.add_argument("--ip", type=str, help="Attacker IP", required=True)
parser.add_argument(
"--created_at",
type=str,
help="Signal's creation date. Example:'2024-01-26 10:20:46+0000'",
default="2024-01-26 10:20:46+0000",
)
parser.add_argument(
"--scenario",
type=str,
help="Signal's scenario. Example: 'crowdsecurity/ssh-bf'",
required=True,
)
parser.add_argument(
"--machine_scenarios",
type=str,
help='Json encoded list of scenarios. Example:"[\\"crowdsecurity/ssh-bf\\", \\"acme/http-bf\\"]"',
default='["crowdsecurity/ssh-bf", "acme/http-bf"]',
)
parser.add_argument(
"--user_agent_prefix", type=str, help="User agent prefix", default=None
)
parser.add_argument(
"--database",
type=str,
help="Local database name. Example: cscapi.db",
default=None,
)
parser.add_argument(
"--context",
type=str,
help='Json encoded context. Example:"[{\\"key\\":\\"key1\\", '
'\\"value\\":\\"value1\\"}, {\\"key\\":\\"key2\\", \\"value\\":\\"value2\\"}]"',
default=None,
)
args = parser.parse_args()
except argparse.ArgumentError as e:
print(e)
parser.print_usage()
sys.exit(2)
machine_id = generate_machine_id_from_key(args.human_machine_id)
machine_id_message = f"machine ID: '{machine_id}'"
ip_message = f"\tAttacker IP: '{args.ip}'\n"
created_at_message = f"\tCreated at: '{args.created_at}'\n"
scenario_message = f"\tScenario: '{args.scenario}'\n"
context_message = f"\tContext:{args.context}\n" if args.context else ""
machine_scenarios = (
json.loads(args.machine_scenarios) if args.machine_scenarios else None
)
context = json.loads(args.context) if args.context else None
user_agent_message = (
f"\tUser agent prefix:'{args.user_agent_prefix}'\n"
if args.user_agent_prefix
else ""
)
machine_scenarios_message = (
f"\tMachine's scenarios:{args.machine_scenarios}\n" if machine_scenarios else ""
)
env_message = "\tEnv: production\n" if args.prod else "\tEnv: development\n"

database = (
args.database
if args.database
else "cscapi_examples_prod.db" if args.prod else "cscapi_examples_dev.db"
)
database_message = f"\tLocal storage database: {database}\n"

print(
f"\nSending signal for {machine_id_message}\n\n"
f"Details:\n"
f"{env_message}"
f"{ip_message}"
f"{scenario_message}"
f"{created_at_message}"
f"{context_message}"
f"{machine_scenarios_message}"
f"{database_message}"
f"{user_agent_message}"
f"\n\n"
)

confirmation = input("Do you want to proceed? (Y/n): ")
if confirmation.lower() == "n":
print("Operation cancelled by the user.")
sys.exit()

client = CAPIClient(
storage=SQLStorage(connection_string=f"sqlite:///{database}"),
config=CAPIClientConfig(
scenarios=machine_scenarios,
prod=args.prod,
user_agent_prefix=args.user_agent_prefix,
),
)

signals = [
create_signal(
attacker_ip=args.ip,
scenario=args.scenario,
created_at=args.created_at,
machine_id=machine_id,
context=context if context else [],
)
]

client.add_signals(signals)
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
sqlalchemy
sqlalchemy>=1.4
python-dateutil
httpx
dacite
Expand Down
106 changes: 52 additions & 54 deletions src/cscapi/sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,49 +143,49 @@ class SQLStorage(storage.StorageInterface):
def __init__(self, connection_string="sqlite:///cscapi.db") -> None:
engine = create_engine(connection_string, echo=False)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
self.session = Session()
self.session = sessionmaker(bind=engine)

def get_all_signals(self) -> List[storage.SignalModel]:
return [
from_dict(storage.SignalModel, res.to_dict())
for res in self.session.query(SignalDBModel).all()
]
with self.session.begin() as session:
return [
from_dict(storage.SignalModel, res.to_dict())
for res in session.query(SignalDBModel).all()
]

def get_machine_by_id(self, machine_id: str) -> Optional[storage.MachineModel]:
existing = (
self.session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine_id)
.first()
)
if not existing:
return None
return storage.MachineModel(
machine_id=existing.machine_id,
token=existing.token,
password=existing.password,
scenarios=existing.scenarios,
is_failing=existing.is_failing,
)
with self.session.begin() as session:
existing = (
session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine_id)
.first()
)
if not existing:
return None
return storage.MachineModel(
machine_id=existing.machine_id,
token=existing.token,
password=existing.password,
scenarios=existing.scenarios,
is_failing=existing.is_failing,
)

def update_or_create_machine(self, machine: storage.MachineModel) -> bool:
existing = (
self.session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine.machine_id)
.all()
)
if not existing:
self.session.add(MachineDBModel(**asdict(machine)))
self.session.commit()
return True

update_stmt = (
update(MachineDBModel)
.where(MachineDBModel.machine_id == machine.machine_id)
.values(**asdict(machine))
)
self.session.execute(update_stmt)
self.session.commit()
with self.session.begin() as session:
existing = (
session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine.machine_id)
.all()
)
if not existing:
session.add(MachineDBModel(**asdict(machine)))
return True

update_stmt = (
update(MachineDBModel)
.where(MachineDBModel.machine_id == machine.machine_id)
.values(**asdict(machine))
)
session.execute(update_stmt)
return False

def update_or_create_signal(self, signal: storage.SignalModel) -> bool:
Expand All @@ -211,34 +211,32 @@ def update_or_create_signal(self, signal: storage.SignalModel) -> bool:
DecisionDBModel(**{"signal_id": to_insert.alert_id} | asdict(dec))
for dec in signal.decisions
]
with self.session.begin() as session:
existing = (
session.query(SignalDBModel)
.filter(SignalDBModel.alert_id == signal.alert_id)
.first()
)

existing = (
self.session.query(SignalDBModel)
.filter(SignalDBModel.alert_id == signal.alert_id)
.first()
)

if not existing:
self.session.add(to_insert)
self.session.commit()
return True
if not existing:
session.add(to_insert)
return True

for c in to_insert.__table__.columns:
setattr(existing, c.name, getattr(to_insert, c.name))
for c in to_insert.__table__.columns:
setattr(existing, c.name, getattr(to_insert, c.name))

self.session.commit()
return False

def delete_signals(self, signals: List[storage.SignalModel]):
stmt = delete(SignalDBModel).where(
SignalDBModel.alert_id.in_((signal.alert_id for signal in signals))
)
self.session.execute(stmt)
self.session.commit()
with self.session.begin() as session:
session.execute(stmt)

def delete_machines(self, machines: List[storage.MachineModel]):
stmt = delete(MachineDBModel).where(
MachineDBModel.machine_id.in_((machine.machine_id for machine in machines))
)
self.session.execute(stmt)
self.session.commit()
with self.session.begin() as session:
session.execute(stmt)
1 change: 0 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def storage():
db_name = f"{time.time()}.db"
storage = SQLStorage(f"sqlite:///{db_name}")
yield storage
storage.session.close()
try:
os.remove(db_name)
except:
Expand Down
18 changes: 7 additions & 11 deletions tests/test_sql_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,11 @@ def setUp(self) -> None:
def tearDown(self) -> None:
# postgresql, mysql, mariadb
if database_exists(self.db_uri):
if self.storage.session:
self.storage.session.close()
engine = create_engine(self.db_uri, poolclass=NullPool)
conn = engine.connect()
try:
drop_database(self.db_uri)
except Exception as e:
print(f"Error occurred while dropping the database: {e}")

conn.close()
engine.dispose()

# sqlite
try:
os.remove(self.db_path)
Expand Down Expand Up @@ -105,7 +98,8 @@ def test_update_machine(self):
machine_id="1", token="2", password="2", scenarios="crowdsecurity/http-bf"
)
self.storage.update_or_create_machine(m2)
self.assertEqual(1, self.storage.session.query(MachineDBModel).count())
with self.storage.session.begin() as session:
self.assertEqual(1, session.query(MachineDBModel).count())

retrieved = self.storage.get_machine_by_id("1")

Expand All @@ -124,14 +118,16 @@ def test_create_signal(self):
assert signal.alert_id is not None
assert signal.sent == False

assert self.storage.session.query(ContextDBModel).count() == 4
with self.storage.session.begin() as session:
assert session.query(SignalDBModel).count() == 1
assert session.query(ContextDBModel).count() == 4
assert session.query(DecisionDBModel).count() == 1
assert session.query(SourceDBModel).count() == 1
assert len(signal.context) == 4

assert self.storage.session.query(DecisionDBModel).count() == 1
assert len(signal.decisions) == 1

assert isinstance(signal.source, SourceModel)
assert self.storage.session.query(SourceDBModel).count() == 1

def test_update_signal(self):
assert self.storage.get_all_signals() == []
Expand Down

0 comments on commit adfe86d

Please sign in to comment.