From 12cd95ae9f65ab635c582f4eb6f0fd8547ab8c7b Mon Sep 17 00:00:00 2001 From: Michael Butt <869064+m3mike@users.noreply.github.com> Date: Mon, 7 Feb 2022 20:26:43 -0500 Subject: [PATCH] chore(*): reformat codebase using pre-commit (#140) * chore(*): reformat codebase using pre-commit * chore(ci): fix flake8 exclusion after reformat, add sonar exclusion for tests * fix(ci): fix nosec issue caused by reformat --- .pre-commit-config.yaml | 2 + .sonarcloud.properties | 2 + LICENSE.txt | 2 +- NOTICE.txt | 2 +- docker/README.md | 2 +- pyproject.toml | 2 + src/archive/pipeline/regex.yml | 2 +- src/scripts/reformat_training_data.py | 363 +++++++++--------- src/tram/manage.py | 4 +- src/tram/tram/admin.py | 19 +- src/tram/tram/asgi.py | 2 +- .../tram/management/commands/attackdata.py | 92 +++-- src/tram/tram/management/commands/pipeline.py | 80 ++-- src/tram/tram/ml/base.py | 193 ++++++---- src/tram/tram/models.py | 89 +++-- src/tram/tram/serializers.py | 149 ++++--- src/tram/tram/templates/analyze.html | 2 +- src/tram/tram/templates/index.html | 2 +- src/tram/tram/templates/ml_home.html | 2 +- src/tram/tram/templates/model_detail.html | 2 +- .../tram/templates/registration/login.html | 2 +- .../tram/templates/technique_sentences.html | 2 +- .../tram/templates/tram_documentation.html | 2 +- src/tram/tram/urls.py | 37 +- src/tram/tram/views.py | 86 +++-- src/tram/tram/wsgi.py | 2 +- tests/conftest.py | 46 +-- tests/tram/test_base.py | 116 +++--- tests/tram/test_commands.py | 41 +- tests/tram/test_models.py | 14 +- tests/tram/test_views.py | 141 +++---- 31 files changed, 836 insertions(+), 666 deletions(-) create mode 100644 .sonarcloud.properties diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2176289a9..a02a4961cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +exclude: '^.*\b(migrations)\b.*$|\.min\.js$|\.min\.css$|\.min\.*\b|\.svg|^tests/data|^data/|static/js' + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 diff --git a/.sonarcloud.properties b/.sonarcloud.properties new file mode 100644 index 0000000000..3ed4eb003d --- /dev/null +++ b/.sonarcloud.properties @@ -0,0 +1,2 @@ +# sonar settings +sonar.exclusions=tests/ diff --git a/LICENSE.txt b/LICENSE.txt index f49a4e16e6..261eeb9e9f 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/NOTICE.txt b/NOTICE.txt index b4a6e43e24..5379ff1a95 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -16,4 +16,4 @@ limitations under the License. This project makes use of ATT&CK® -ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/ \ No newline at end of file +ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/ diff --git a/docker/README.md b/docker/README.md index ec5b2503a7..64c56f63e1 100644 --- a/docker/README.md +++ b/docker/README.md @@ -24,7 +24,7 @@ services: - "8000:8000" environment: - DATA_DIRECTORY=/data - - ALLOWED_HOSTS=["example_host1", "localhost"] + - ALLOWED_HOSTS=["example_host1", "localhost"] - SECRET_KEY=Ij0WGee73k9OESwqddmSKCx6SY9aJ_7bDojs485Z6ec # your secret key here - DEBUG=True - DJANGO_SUPERUSER_USERNAME=djangoSuperuser diff --git a/pyproject.toml b/pyproject.toml index 57d0ea5efe..88c5c0ba79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ exclude = ''' | venv | migrations | node_modules + | .tox + | .git )/ ''' diff --git a/src/archive/pipeline/regex.yml b/src/archive/pipeline/regex.yml index 17cf147d06..81a500ce93 100644 --- a/src/archive/pipeline/regex.yml +++ b/src/archive/pipeline/regex.yml @@ -67,4 +67,4 @@ - name: sha512 code: sha512_hash regex: | - \b[a-fA-f0-9]{128}\b \ No newline at end of file + \b[a-fA-f0-9]{128}\b diff --git a/src/scripts/reformat_training_data.py b/src/scripts/reformat_training_data.py index 6bf74e2884..9a95972dce 100644 --- a/src/scripts/reformat_training_data.py +++ b/src/scripts/reformat_training_data.py @@ -28,173 +28,173 @@ The target format is defined by tram.serializers.ReportExportSerializer """ -from datetime import datetime -from functools import partial import json import os import sys +from datetime import datetime +from functools import partial import django -sys.path.append('src/tram/') -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') +sys.path.append("src/tram/") +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") django.setup() from tram.serializers import ReportExportSerializer # noqa: E402 -outfile = 'data/training/bootstrap-training-data.json' +outfile = "data/training/bootstrap-training-data.json" ATTACK_LOOKUP = { # A mapping of attack descriptions to technique IDs - 'drive-by compromise': 'T1189', - 'system information discovery': 'T1082', - 'new service': 'T1543', - 'service execution': 'T1569.002', - 'command-line interface': 'T1059', # Maps to: T1059 - Command and Scripting Interpreter - 'obfuscated files or information': 'T1027', - 'custom cryptographic protocol': 'T1573', # Maps to: T1573 - Encrypted Channel - 'system network configuration discovery': 'T1016', - 'web shell': 'T1505.003', - 'application window discovery': 'T1010', - 'file deletion': 'T1070.004', # Technique that became a subtechnique - 'standard application layer protocol': 'T1071', - 'web service': 'T1102', - 'exfiltration over command and control channel': 'T1041', - 'fallback channels': 'T1008', - 'bypass user account control': 'T1548.002', # Technique that became a subtechnique - 'system time discovery': 'T1124', - 'deobfuscate/decode files or information': 'T1140', - 'disabling security tools': 'T1562.001', # Maps to: T1562.001 - Impair Defenses: Disable or Modify Tools - 'registry run keys / startup folder': 'T1547.001', - 'remote file copy': 'T1105', # Maps to: T1105 - Ingress Tool Transfer - 'dll search order hijacking': 'T1574.001', - 'screen capture': 'T1113', - 'file and directory discovery': 'T1083', - 'tor': 'S0183', # Software?? - 'shortcut modification': 'T1547.009', - 'remote services': 'T1021', - 'connection proxy': 'T1090', - 'data encoding': 'T1132', - 'spearphishing link': 'T1566.002', - 'spearphishing attachment': 'T1566.001', - 'arp': 'S0099', - 'user execution': 'T1204', - 'process hollowing': 'T1055.012', - 'execution through api': 'T1106', # Maps to T1106 - Native API - 'masquerading': 'T1036', - 'code signing': 'T1553.002', - 'standard cryptographic protocol': 'T1521', - 'scripting': 'T1059', - 'remote system discovery': 'T1018', - 'credential dumping': 'T1003', - 'exploitation for client execution': 'T1203', - 'exploitation for privilege escalation': 'T1068', - 'security software discovery': 'T1518.001', - 'data from local system': 'T1533', - 'remote desktop protocol': 'T1021.001', - 'data compressed': 'T1560', # Maps to T1560 - Archive Collected Data - 'software packing': 'T1027.002', - 'ping': 'S0097', - 'brute force': 'T1110', - 'commonly used port': 'T1571', - 'modify registry': 'T1112', - 'uncommonly used port': 'T1571', - 'process injection': 'T1055', - 'timestomp': 'T1070.006', - 'windows management instrumentation': 'T1047', - 'data staged': 'T1074', - 'rundll32': 'T1218.011', - 'regsvr32': 'T1218.010', - 'account discovery': 'T1087', - 'process discovery': 'T1057', - 'clipboard data': 'T1115', - 'binary padding': 'T1027.001', - 'pass the hash': 'T1550.002', - 'network service scanning': 'T1046', - 'system service discovery': 'T1007', - 'data encrypted': 'T1486', - 'system network connections discovery': 'T1049', - 'windows admin shares': 'T1021.002', - 'system owner/user discovery': 'T1033', - 'launch agent': 'T1543.001', - 'permission groups discovery': 'T1069', - 'indicator removal on host': 'T1070', - 'input capture': 'T1056', - 'virtualization/sandbox evasion': 'T1497.001', - 'dll side-loading': 'T1574.002', - 'scheduled task': 'T1053', - 'access token manipulation': 'T1134', - 'powershell': 'T1546.013', - 'exfiltration over alternative protocol': 'T1048', - 'hidden files and directories': 'T1564.001', - 'network share discovery': 'T1135', - 'query registry': 'T1012', - 'credentials in files': 'T1552.001', - 'audio capture': 'T1123', - 'video capture': 'T1125', - 'peripheral device discovery': 'T1120', - 'spearphishing via service': 'T1566.003', - 'data encrypted for impact': 'T1486', - 'data destruction': 'T1485', - 'template injection': 'T1221', - 'inhibit system recovery': 'T1490', - 'create account': 'T1136', - 'exploitation of remote services': 'T1210', - 'valid accounts': 'T1078', - 'dynamic data exchange': 'T1559.002', - 'office application startup': 'T1137', - 'data obfuscation': 'T1001', - 'domain trust discovery': 'T1482', - 'email collection': 'T1114', - 'man in the browser': 'T1185', - 'data from removable media': 'T1025', - 'bootkit': 'T1542.003', - 'logon scripts': 'T1037', - 'execution through module load': 'T1129', - 'llmnr/nbt-ns poisoning and relay': 'T1557.001', - 'external remote services': 'T1133', - 'domain fronting': 'T1090.004', - 'sid-history injection': 'T1134.005', - 'service stop': 'T1489', - 'disk structure wipe': 'T1561.002', - 'credentials in registry': 'T1552.002', - 'appinit dlls': 'T1546.010', - 'exploit public-facing application': 'T1190', - 'remote access tools': 'T1219', - 'signed binary proxy execution': 'T1218', - 'appcert dlls': 'T1546.009', - 'winlogon helper dll': 'T1547.004', - 'file permissions modification': 'T1222', - 'hooking': 'T1056.004', - 'system firmware': 'T1542.001', - 'lsass driver': 'T1547.008', - 'distributed component object model': 'T1021.003', - 'cmstp': 'T1218.003', - 'execution guardrails': 'T1480', - 'component object model hijacking': 'T1546.015', - 'accessibility features': 'T1546.008', # TODO: Help wanted - 'keychain': 'T1555.001', - 'mshta': 'T1218.005', - 'pass the ticket': 'T1550.003', - 'kerberoasting': 'T1558.003', - 'password policy discovery': 'T1201', - 'local job scheduling': 'T1053.001', - 'windows remote management': 'T1021.006', - 'bits jobs': 'T1197', - 'data from information repositories': 'T1213', - 'lc_load_dylib addition': 'T1546.006', - 'histcontrol': 'T1562.003', - 'file system logical offsets': 'T1006', - 'regsvcs/regasm': 'T1218.009', - 'exploitation for credential access': 'T1212', - 'sudo': 'T1548.003', - 'installutil': 'T1218.004', - 'query registry ': 'T1012', - 'launchctl': 'T1569.001', - '.bash_profile and .bashrc': 'T1546.004', - 'applescript': 'T1059.002', - 'emond': 'T1546.014', - 'control panel items': 'T1218.002', - 'application shimming': 'T1546.011', + "drive-by compromise": "T1189", + "system information discovery": "T1082", + "new service": "T1543", + "service execution": "T1569.002", + "command-line interface": "T1059", # Maps to: T1059 - Command and Scripting Interpreter + "obfuscated files or information": "T1027", + "custom cryptographic protocol": "T1573", # Maps to: T1573 - Encrypted Channel + "system network configuration discovery": "T1016", + "web shell": "T1505.003", + "application window discovery": "T1010", + "file deletion": "T1070.004", # Technique that became a subtechnique + "standard application layer protocol": "T1071", + "web service": "T1102", + "exfiltration over command and control channel": "T1041", + "fallback channels": "T1008", + "bypass user account control": "T1548.002", # Technique that became a subtechnique + "system time discovery": "T1124", + "deobfuscate/decode files or information": "T1140", + "disabling security tools": "T1562.001", # Maps to: T1562.001 - Impair Defenses: Disable or Modify Tools + "registry run keys / startup folder": "T1547.001", + "remote file copy": "T1105", # Maps to: T1105 - Ingress Tool Transfer + "dll search order hijacking": "T1574.001", + "screen capture": "T1113", + "file and directory discovery": "T1083", + "tor": "S0183", # Software?? + "shortcut modification": "T1547.009", + "remote services": "T1021", + "connection proxy": "T1090", + "data encoding": "T1132", + "spearphishing link": "T1566.002", + "spearphishing attachment": "T1566.001", + "arp": "S0099", + "user execution": "T1204", + "process hollowing": "T1055.012", + "execution through api": "T1106", # Maps to T1106 - Native API + "masquerading": "T1036", + "code signing": "T1553.002", + "standard cryptographic protocol": "T1521", + "scripting": "T1059", + "remote system discovery": "T1018", + "credential dumping": "T1003", + "exploitation for client execution": "T1203", + "exploitation for privilege escalation": "T1068", + "security software discovery": "T1518.001", + "data from local system": "T1533", + "remote desktop protocol": "T1021.001", + "data compressed": "T1560", # Maps to T1560 - Archive Collected Data + "software packing": "T1027.002", + "ping": "S0097", + "brute force": "T1110", + "commonly used port": "T1571", + "modify registry": "T1112", + "uncommonly used port": "T1571", + "process injection": "T1055", + "timestomp": "T1070.006", + "windows management instrumentation": "T1047", + "data staged": "T1074", + "rundll32": "T1218.011", + "regsvr32": "T1218.010", + "account discovery": "T1087", + "process discovery": "T1057", + "clipboard data": "T1115", + "binary padding": "T1027.001", + "pass the hash": "T1550.002", + "network service scanning": "T1046", + "system service discovery": "T1007", + "data encrypted": "T1486", + "system network connections discovery": "T1049", + "windows admin shares": "T1021.002", + "system owner/user discovery": "T1033", + "launch agent": "T1543.001", + "permission groups discovery": "T1069", + "indicator removal on host": "T1070", + "input capture": "T1056", + "virtualization/sandbox evasion": "T1497.001", + "dll side-loading": "T1574.002", + "scheduled task": "T1053", + "access token manipulation": "T1134", + "powershell": "T1546.013", + "exfiltration over alternative protocol": "T1048", + "hidden files and directories": "T1564.001", + "network share discovery": "T1135", + "query registry": "T1012", + "credentials in files": "T1552.001", + "audio capture": "T1123", + "video capture": "T1125", + "peripheral device discovery": "T1120", + "spearphishing via service": "T1566.003", + "data encrypted for impact": "T1486", + "data destruction": "T1485", + "template injection": "T1221", + "inhibit system recovery": "T1490", + "create account": "T1136", + "exploitation of remote services": "T1210", + "valid accounts": "T1078", + "dynamic data exchange": "T1559.002", + "office application startup": "T1137", + "data obfuscation": "T1001", + "domain trust discovery": "T1482", + "email collection": "T1114", + "man in the browser": "T1185", + "data from removable media": "T1025", + "bootkit": "T1542.003", + "logon scripts": "T1037", + "execution through module load": "T1129", + "llmnr/nbt-ns poisoning and relay": "T1557.001", + "external remote services": "T1133", + "domain fronting": "T1090.004", + "sid-history injection": "T1134.005", + "service stop": "T1489", + "disk structure wipe": "T1561.002", + "credentials in registry": "T1552.002", + "appinit dlls": "T1546.010", + "exploit public-facing application": "T1190", + "remote access tools": "T1219", + "signed binary proxy execution": "T1218", + "appcert dlls": "T1546.009", + "winlogon helper dll": "T1547.004", + "file permissions modification": "T1222", + "hooking": "T1056.004", + "system firmware": "T1542.001", + "lsass driver": "T1547.008", + "distributed component object model": "T1021.003", + "cmstp": "T1218.003", + "execution guardrails": "T1480", + "component object model hijacking": "T1546.015", + "accessibility features": "T1546.008", # TODO: Help wanted + "keychain": "T1555.001", + "mshta": "T1218.005", + "pass the ticket": "T1550.003", + "kerberoasting": "T1558.003", + "password policy discovery": "T1201", + "local job scheduling": "T1053.001", + "windows remote management": "T1021.006", + "bits jobs": "T1197", + "data from information repositories": "T1213", + "lc_load_dylib addition": "T1546.006", + "histcontrol": "T1562.003", + "file system logical offsets": "T1006", + "regsvcs/regasm": "T1218.009", + "exploitation for credential access": "T1212", + "sudo": "T1548.003", + "installutil": "T1218.004", + "query registry ": "T1012", + "launchctl": "T1569.001", + ".bash_profile and .bashrc": "T1546.004", + "applescript": "T1059.002", + "emond": "T1546.014", + "control panel items": "T1218.002", + "application shimming": "T1546.011", } @@ -215,14 +215,14 @@ def to_report_export_serializer_json(self): """Creates a dict that can be used to create a serializers.ReportExportSerializer instance """ - utc_now = datetime.utcnow().isoformat() + 'Z' + utc_now = datetime.utcnow().isoformat() + "Z" res_json = { - 'name': 'Bootstrap Training Data', - 'text': 'There is no text for this report. These sentences were mapped by human analysts.', - 'ml_model': 'humans', - 'created_on': utc_now, - 'updated_on': utc_now, - 'sentences': [] + "name": "Bootstrap Training Data", + "text": "There is no text for this report. These sentences were mapped by human analysts.", + "ml_model": "humans", + "created_on": utc_now, + "updated_on": utc_now, + "sentences": [], } order = 0 @@ -231,21 +231,18 @@ def to_report_export_serializer_json(self): continue sentence = { - 'text': sentence_text, - 'order': order, - 'disposition': 'accept', - 'mappings': [], + "text": sentence_text, + "order": order, + "disposition": "accept", + "mappings": [], } order += 1 for mapping in mappings: - mapping = { - 'attack_id': mapping, - 'confidence': '100.0' - } - sentence['mappings'].append(mapping) - res_json['sentences'].append(sentence) + mapping = {"attack_id": mapping, "confidence": "100.0"} + sentence["mappings"].append(mapping) + res_json["sentences"].append(sentence) return res_json @@ -258,20 +255,24 @@ def get_attack_id(description): def main(): - with open('data/training/archive/all_analyzed_reports.json') as f: + with open("data/training/archive/all_analyzed_reports.json") as f: all_analyzed_reports = json.load(f) - with open('data/training/archive/negative_data.json') as f: + with open("data/training/archive/negative_data.json") as f: negative_data = json.load(f) training_data = TrainingData() # Add the positives for key, value in all_analyzed_reports.items(): - if key.endswith('-multi'): # It's a multi-mapping, value is a dictionary - for sentence in value['sentances']: # Sentences is misspelled in the source data - map(partial(training_data.add_mapping, sentence), - [ATTACK_LOOKUP[name.lower()] for name in value['technique_names']]) + if key.endswith("-multi"): # It's a multi-mapping, value is a dictionary + for sentence in value[ + "sentances" + ]: # Sentences is misspelled in the source data + map( + partial(training_data.add_mapping, sentence), + [ATTACK_LOOKUP[name.lower()] for name in value["technique_names"]], + ) else: # It's a single-mapping, value is a list of sentences technique_id = get_attack_id(key) for sentence in value: @@ -285,11 +286,11 @@ def main(): res = ReportExportSerializer(data=res_json) res.is_valid(raise_exception=True) - with open(outfile, 'w') as f: + with open(outfile, "w") as f: json.dump(res.initial_data, f, indent=4) - print('Wrote data to %s' % outfile) + print("Wrote data to %s" % outfile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/tram/manage.py b/src/tram/manage.py index ee76876f16..a83fedd617 100755 --- a/src/tram/manage.py +++ b/src/tram/manage.py @@ -6,7 +6,7 @@ def main(): """Run administrative tasks.""" - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") try: from django.core.management import execute_from_command_line except ImportError as exc: @@ -18,5 +18,5 @@ def main(): execute_from_command_line(sys.argv) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/tram/tram/admin.py b/src/tram/tram/admin.py index b39459fcdc..359142be98 100644 --- a/src/tram/tram/admin.py +++ b/src/tram/tram/admin.py @@ -1,7 +1,14 @@ from django.contrib import admin -from tram.models import AttackObject, Document, DocumentProcessingJob, \ - Indicator, Mapping, Report, Sentence +from tram.models import ( + AttackObject, + Document, + DocumentProcessingJob, + Indicator, + Mapping, + Report, + Sentence, +) class IndicatorInline(admin.TabularInline): @@ -17,11 +24,11 @@ class MappingInline(admin.TabularInline): class SentenceInline(admin.TabularInline): extra = 0 model = Sentence - readonly_fields = ('text', 'document', 'order') + readonly_fields = ("text", "document", "order") class AttackObjectAdmin(admin.ModelAdmin): - readonly_fields = ('name', 'stix_id', 'attack_id', 'attack_url', 'matrix') + readonly_fields = ("name", "stix_id", "attack_id", "attack_url", "matrix") class DocumentAdmin(admin.ModelAdmin): @@ -30,11 +37,11 @@ class DocumentAdmin(admin.ModelAdmin): class ReportAdmin(admin.ModelAdmin): inlines = [IndicatorInline, MappingInline] - readonly_fields = ('document', 'text', 'ml_model') + readonly_fields = ("document", "text", "ml_model") class SentenceAdmin(admin.ModelAdmin): - readonly_fields = ('text', 'document', 'order') + readonly_fields = ("text", "document", "order") admin.site.register(AttackObject, AttackObjectAdmin) diff --git a/src/tram/tram/asgi.py b/src/tram/tram/asgi.py index f0128d4eab..a559a941e9 100644 --- a/src/tram/tram/asgi.py +++ b/src/tram/tram/asgi.py @@ -11,6 +11,6 @@ from django.core.asgi import get_asgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") application = get_asgi_application() diff --git a/src/tram/tram/management/commands/attackdata.py b/src/tram/tram/management/commands/attackdata.py index b793128c98..934b0d21d0 100644 --- a/src/tram/tram/management/commands/attackdata.py +++ b/src/tram/tram/management/commands/attackdata.py @@ -5,58 +5,66 @@ from tram.models import AttackObject -LOAD = 'load' -CLEAR = 'clear' +LOAD = "load" +CLEAR = "clear" STIX_TYPE_TO_ATTACK_TYPE = { - 'attack-pattern': 'technique', - 'course-of-action': 'mitigation', - 'intrusion-set': 'group', - 'malware': 'software', - 'tool': 'software', - 'x-mitre-tactic': 'tactic', + "attack-pattern": "technique", + "course-of-action": "mitigation", + "intrusion-set": "group", + "malware": "software", + "tool": "software", + "x-mitre-tactic": "tactic", } class Command(BaseCommand): - help = 'Machine learning pipeline commands' + help = "Machine learning pipeline commands" def add_arguments(self, parser): - sp = parser.add_subparsers(title='subcommands', - dest='subcommand', - required=True) - sp_load = sp.add_parser(LOAD, help='Load ATT&CK Data into the Database') # noqa: F841 - sp_clear = sp.add_parser(CLEAR, help='Clear ATT&CK Data from the Database') # noqa: F841 + sp = parser.add_subparsers( + title="subcommands", dest="subcommand", required=True + ) + sp_load = sp.add_parser( # noqa: F841 + LOAD, help="Load ATT&CK Data into the Database" + ) + sp_clear = sp.add_parser( # noqa: F841 + CLEAR, help="Clear ATT&CK Data from the Database" + ) def clear_attack_data(self): deleted = AttackObject.objects.all().delete() - print(f'Deleted {deleted[0]} Attack objects') + print(f"Deleted {deleted[0]} Attack objects") def create_attack_object(self, obj): - for external_reference in obj['external_references']: - if external_reference['source_name'] not in ('mitre-attack', 'mitre-pre-attack', 'mitre-mobile-attack'): + for external_reference in obj["external_references"]: + if external_reference["source_name"] not in ( + "mitre-attack", + "mitre-pre-attack", + "mitre-mobile-attack", + ): continue - attack_id = external_reference['external_id'] - attack_url = external_reference['url'] - matrix = external_reference['source_name'] + attack_id = external_reference["external_id"] + attack_url = external_reference["url"] + matrix = external_reference["source_name"] assert attack_id is not None assert attack_url is not None assert matrix is not None - stix_type = obj['type'] + stix_type = obj["type"] attack_type = STIX_TYPE_TO_ATTACK_TYPE[stix_type] obj, created = AttackObject.objects.get_or_create( - name=obj['name'], - stix_id=obj['id'], + name=obj["name"], + stix_id=obj["id"], stix_type=stix_type, attack_id=attack_id, attack_type=attack_type, attack_url=attack_url, - matrix=matrix + matrix=matrix, ) return obj, created @@ -65,21 +73,27 @@ def load_attack_data(self, filepath): created_stats = {} skipped_stats = {} - with open(filepath, 'r') as f: + with open(filepath, "r") as f: attack_json = json.load(f) - assert attack_json['spec_version'] == '2.0' - assert attack_json['type'] == 'bundle' + assert attack_json["spec_version"] == "2.0" + assert attack_json["type"] == "bundle" - for obj in attack_json['objects']: - obj_type = obj['type'] + for obj in attack_json["objects"]: + obj_type = obj["type"] # TODO: Skip deprecated objects - if obj.get('revoked', False): # Skip revoked objects + if obj.get("revoked", False): # Skip revoked objects skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 continue - if obj_type in ('relationship', 'course-of-action', 'identity', 'x-mitre-matrix', 'marking-definition'): + if obj_type in ( + "relationship", + "course-of-action", + "identity", + "x-mitre-matrix", + "marking-definition", + ): skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 continue @@ -92,21 +106,23 @@ def load_attack_data(self, filepath): except ValueError: # Value error means unsupported object type skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 - print(f'Load stats for {filepath}:') + print(f"Load stats for {filepath}:") for k, v in created_stats.items(): - print(f'\tCreated {v} {k} objects') + print(f"\tCreated {v} {k} objects") for k, v in skipped_stats.items(): - print(f'\tSkipped {v} {k} objects') + print(f"\tSkipped {v} {k} objects") def handle(self, *args, **options): - subcommand = options['subcommand'] + subcommand = options["subcommand"] if subcommand == LOAD: # Note - as of ATT&CK v8.2 # Techniques are unique among files, but # Groups are not unique among files - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/enterprise-attack.json') - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/mobile-attack.json') - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/pre-attack.json') + self.load_attack_data( + settings.DATA_DIRECTORY / "attack/enterprise-attack.json" + ) + self.load_attack_data(settings.DATA_DIRECTORY / "attack/mobile-attack.json") + self.load_attack_data(settings.DATA_DIRECTORY / "attack/pre-attack.json") elif subcommand == CLEAR: self.clear_attack_data() diff --git a/src/tram/tram/management/commands/pipeline.py b/src/tram/tram/management/commands/pipeline.py index 9bebb61da8..2cffc3b019 100644 --- a/src/tram/tram/management/commands/pipeline.py +++ b/src/tram/tram/management/commands/pipeline.py @@ -4,68 +4,80 @@ from django.core.files import File from django.core.management.base import BaseCommand - -from tram.ml import base import tram.models as db_models from tram import serializers +from tram.ml import base - -ADD = 'add' -RUN = 'run' -TRAIN = 'train' -LOAD_TRAINING_DATA = 'load-training-data' +ADD = "add" +RUN = "run" +TRAIN = "train" +LOAD_TRAINING_DATA = "load-training-data" class Command(BaseCommand): - help = 'Machine learning pipeline commands' + help = "Machine learning pipeline commands" def add_arguments(self, parser): - sp = parser.add_subparsers(title='subcommands', - dest='subcommand', - required=True) - sp_run = sp.add_parser(RUN, help='Run the ML Pipeline') - sp_run.add_argument('--model', default='logreg', help='Select the ML model.') - sp_run.add_argument('--run-forever', default=False, action='store_true', - help='Specify whether to run forever, or quit when there are no more jobs to process') - sp_train = sp.add_parser(TRAIN, help='Train the ML Pipeline') # noqa: F841 - sp_train.add_argument('--model', default='logreg', help='Select the ML model.') - sp_add = sp.add_parser(ADD, help='Add a document for processing by the ML pipeline') - sp_add.add_argument('--file', required=True, help='Specify the file to be added') - sp_load = sp.add_parser(LOAD_TRAINING_DATA, help='Load training data. Must be formatted as a Report Export.') - sp_load.add_argument('--file', default='data/training/bootstrap-training-data.json', - help='Training data file to load. Defaults: data/training/bootstrap-training-data.json') + sp = parser.add_subparsers( + title="subcommands", dest="subcommand", required=True + ) + sp_run = sp.add_parser(RUN, help="Run the ML Pipeline") + sp_run.add_argument("--model", default="logreg", help="Select the ML model.") + sp_run.add_argument( + "--run-forever", + default=False, + action="store_true", + help="Specify whether to run forever, or quit when there are no more jobs to process", + ) + sp_train = sp.add_parser(TRAIN, help="Train the ML Pipeline") # noqa: F841 + sp_train.add_argument("--model", default="logreg", help="Select the ML model.") + sp_add = sp.add_parser( + ADD, help="Add a document for processing by the ML pipeline" + ) + sp_add.add_argument( + "--file", required=True, help="Specify the file to be added" + ) + sp_load = sp.add_parser( + LOAD_TRAINING_DATA, + help="Load training data. Must be formatted as a Report Export.", + ) + sp_load.add_argument( + "--file", + default="data/training/bootstrap-training-data.json", + help="Training data file to load. Defaults: data/training/bootstrap-training-data.json", + ) def handle(self, *args, **options): - subcommand = options['subcommand'] + subcommand = options["subcommand"] if subcommand == ADD: - filepath = options['file'] - with open(filepath, 'rb') as f: + filepath = options["file"] + with open(filepath, "rb") as f: django_file = File(f) db_models.DocumentProcessingJob.create_from_file(django_file) - self.stdout.write(f'Added file to ML Pipeline: {filepath}') + self.stdout.write(f"Added file to ML Pipeline: {filepath}") return if subcommand == LOAD_TRAINING_DATA: - filepath = options['file'] - self.stdout.write(f'Loading training data from {filepath}') - with open(filepath, 'r') as f: + filepath = options["file"] + self.stdout.write(f"Loading training data from {filepath}") + with open(filepath, "r") as f: res = serializers.ReportExportSerializer(data=json.load(f)) res.is_valid(raise_exception=True) res.save() return - model = options['model'] + model = options["model"] model_manager = base.ModelManager(model) if subcommand == RUN: - self.stdout.write(f'Running ML Pipeline with Model: {model}') - return model_manager.run_model(options['run_forever']) + self.stdout.write(f"Running ML Pipeline with Model: {model}") + return model_manager.run_model(options["run_forever"]) elif subcommand == TRAIN: - self.stdout.write(f'Training ML Model: {model}') + self.stdout.write(f"Training ML Model: {model}") start = time.time() return_value = model_manager.train_model() end = time.time() elapsed = end - start - self.stdout.write(f'Trained ML model in {elapsed} seconds') + self.stdout.write(f"Trained ML model in {elapsed} seconds") return return_value diff --git a/src/tram/tram/ml/base.py b/src/tram/tram/ml/base.py index 9fa9ccc02f..6a7b9acbe0 100644 --- a/src/tram/tram/ml/base.py +++ b/src/tram/tram/ml/base.py @@ -1,20 +1,20 @@ -from abc import ABC, abstractmethod -from datetime import datetime, timezone -from io import BytesIO -from os import path import pathlib import pickle +import re import time import traceback +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from io import BytesIO +from os import path -from bs4 import BeautifulSoup -from constance import config -from django.db import transaction -from django.conf import settings import docx import nltk import pdfplumber -import re +from bs4 import BeautifulSoup +from constance import config +from django.conf import settings +from django.db import transaction from sklearn.dummy import DummyClassifier from sklearn.feature_extraction.text import CountVectorizer from sklearn.linear_model import LogisticRegression @@ -40,7 +40,7 @@ def __init__(self, confidence=0.0, attack_id=None): self.attack_id = attack_id def __repr__(self): - return 'Confidence=%f; Attack ID=%s' % (self.confidence, self.attack_id) + return "Confidence=%f; Attack ID=%s" % (self.confidence, self.attack_id) class Report(object): @@ -55,6 +55,7 @@ class SKLearnModel(ABC): TODO: 1. Move text extraction and tokenization out of the SKLearnModel """ + def __init__(self): self.techniques_model = self.get_model() self.last_trained = None @@ -62,12 +63,13 @@ def __init__(self): self.detailed_f1_score = None if not isinstance(self.techniques_model, Pipeline): - raise TypeError('get_model() must return an sklearn.pipeline.Pipeline instance') + raise TypeError( + "get_model() must return an sklearn.pipeline.Pipeline instance" + ) @abstractmethod def get_model(self): - """Returns an sklearn.Pipeline that has fit() and predict() methods - """ + """Returns an sklearn.Pipeline that has fit() and predict() methods""" def train(self): """ @@ -86,8 +88,9 @@ def test(self): X, y = self.get_training_data() # Create training set and test set - X_train, X_test, y_train, y_test = \ - train_test_split(X, y, test_size=0.2, shuffle=True, random_state=0, stratify=y) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, shuffle=True, random_state=0, stratify=y + ) # Train model test_model = self.get_model() @@ -100,26 +103,28 @@ def test(self): # Calculate an f1 score for each technique labels = sorted(list(set(y))) scores = f1_score(y_test, y_predicted, labels=list(set(y)), average=None) - self.detailed_f1_score = sorted(zip(labels, scores), key=lambda t: t[1], reverse=True) + self.detailed_f1_score = sorted( + zip(labels, scores), key=lambda t: t[1], reverse=True + ) # Average F1 score across techniques, weighted by the # of training examples per technique - weighted_f1 = f1_score(y_test, y_predicted, average='weighted') + weighted_f1 = f1_score(y_test, y_predicted, average="weighted") self.average_f1_score = weighted_f1 def _get_report_name(self, job): name = pathlib.Path(job.document.docfile.path).name - return 'Report for %s' % name + return "Report for %s" % name def _extract_text(self, document): suffix = pathlib.Path(document.docfile.path).suffix - if suffix == '.pdf': + if suffix == ".pdf": text = self._extract_pdf_text(document) - elif suffix == '.docx': + elif suffix == ".docx": text = self._extract_docx_text(document) - elif suffix == '.html': + elif suffix == ".html": text = self._extract_html_text(document) else: - raise ValueError('Unknown file suffix: %s' % suffix) + raise ValueError("Unknown file suffix: %s" % suffix) return text @@ -132,8 +137,12 @@ def lemmatize(self, sentence): lemma = nltk.stem.WordNetLemmatizer() # Lemmatize each word in sentence - lemmatized_sentence = ' '.join([lemma.lemmatize(w) for w in sentence.rstrip().split()]) - lemmatized_sentence = re.sub(r'\d+', '', lemmatized_sentence) # Remove digits with regex + lemmatized_sentence = " ".join( + [lemma.lemmatize(w) for w in sentence.rstrip().split()] + ) + lemmatized_sentence = re.sub( + r"\d+", "", lemmatized_sentence + ) # Remove digits with regex return lemmatized_sentence @@ -153,9 +162,14 @@ def get_training_data(self): return X, y def get_attack_object_ids(self): - objects = [obj.attack_id for obj in db_models.AttackObject.objects.all().order_by('attack_id')] + objects = [ + obj.attack_id + for obj in db_models.AttackObject.objects.all().order_by("attack_id") + ] if len(objects) == 0: - raise ValueError('Zero techniques found. Maybe run `python manage.py attackdata load` ?') + raise ValueError( + "Zero techniques found. Maybe run `python manage.py attackdata load` ?" + ) return objects def get_mappings(self, sentence): @@ -165,7 +179,9 @@ def get_mappings(self, sentence): mappings = [] techniques = self.techniques_model.classes_ - probs = self.techniques_model.predict_proba([sentence])[0] # Probability is a range between 0-1 + probs = self.techniques_model.predict_proba([sentence])[ + 0 + ] # Probability is a range between 0-1 # Create a list of tuples of (confidence, technique) confidences_and_techniques = zip(probs, techniques) @@ -185,7 +201,7 @@ def _sentence_tokenize(self, text): def _extract_pdf_text(self, document): with pdfplumber.open(BytesIO(document.docfile.read())) as pdf: - text = ''.join(page.extract_text() for page in pdf.pages) + text = "".join(page.extract_text() for page in pdf.pages) return text @@ -197,7 +213,7 @@ def _extract_html_text(self, document): def _extract_docx_text(self, document): parsed_docx = docx.Document(BytesIO(document.docfile.read())) - text = ' '.join([paragraph.text for paragraph in parsed_docx.paragraphs]) + text = " ".join([paragraph.text for paragraph in parsed_docx.paragraphs]) return text def process_job(self, job): @@ -217,24 +233,29 @@ def process_job(self, job): return report def save_to_file(self, filepath): - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: pickle.dump(self, f) @classmethod def load_from_file(cls, filepath): - with open(filepath, 'rb') as f: - model = pickle.load(f) # nosec - Accept the risk until a better design is implemented - + with open(filepath, "rb") as f: + model = pickle.load(f) # nosec + # accept risk until better design implemented assert cls == model.__class__ return model class DummyModel(SKLearnModel): def get_model(self): - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", DummyClassifier(strategy='uniform')) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", DummyClassifier(strategy="uniform")), + ] + ) class NaiveBayesModel(SKLearnModel): @@ -244,10 +265,15 @@ def get_model(self): 1) Features = document-term matrix, with stop words removed from the term vocabulary. 2) Classifier (clf) = multinomial Naive Bayes """ - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", MultinomialNB()) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", MultinomialNB()), + ] + ) class LogisticRegressionModel(SKLearnModel): @@ -257,38 +283,43 @@ def get_model(self): 1) Features = document-term matrix, with stop words removed from the term vocabulary. 2) Classifier (clf) = multinomial logistic regression """ - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", LogisticRegression()) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", LogisticRegression()), + ] + ) class ModelManager(object): model_registry = { # TODO: Add a hook to register user-created models - 'dummy': DummyModel, - 'nb': NaiveBayesModel, - 'logreg': LogisticRegressionModel, + "dummy": DummyModel, + "nb": NaiveBayesModel, + "logreg": LogisticRegressionModel, } def __init__(self, model): model_class = self.model_registry.get(model) if not model_class: - raise ValueError('Unrecognized model: %s' % model) + raise ValueError("Unrecognized model: %s" % model) model_filepath = self.get_model_filepath(model_class) if path.exists(model_filepath): self.model = model_class.load_from_file(model_filepath) - print('%s loaded from %s' % (model_class.__name__, model_filepath)) + print("%s loaded from %s" % (model_class.__name__, model_filepath)) else: self.model = model_class() - print('%s loaded from __init__' % model_class.__name__) + print("%s loaded from __init__" % model_class.__name__) def _save_report(self, report, document): rpt = db_models.Report( name=report.name, document=document, text=report.text, - ml_model=self.model.__class__.__name__ + ml_model=self.model.__class__.__name__, ) rpt.save() @@ -304,7 +335,9 @@ def _save_report(self, report, document): for mapping in sentence.mappings: if mapping.attack_id: - obj = db_models.AttackObject.objects.get(attack_id=mapping.attack_id) + obj = db_models.AttackObject.objects.get( + attack_id=mapping.attack_id + ) else: obj = None @@ -318,21 +351,23 @@ def _save_report(self, report, document): def run_model(self, run_forever=False): while True: - jobs = db_models.DocumentProcessingJob.objects.filter(status='queued').order_by('created_on') + jobs = db_models.DocumentProcessingJob.objects.filter( + status="queued" + ).order_by("created_on") for job in jobs: filename = job.document.docfile.name - print('Processing Job #%d: %s' % (job.id, filename)) + print("Processing Job #%d: %s" % (job.id, filename)) try: report = self.model.process_job(job) with transaction.atomic(): self._save_report(report, job.document) job.delete() - print('Created report %s' % report.name) + print("Created report %s" % report.name) except Exception as ex: - job.status = 'error' + job.status = "error" job.message = str(ex) job.save() - print(f'Failed to create report for {filename}.') + print(f"Failed to create report for {filename}.") print(traceback.format_exc()) if not run_forever: @@ -340,7 +375,7 @@ def run_model(self, run_forever=False): time.sleep(1) def get_model_filepath(self, model_class): - filepath = settings.ML_MODEL_DIR + '/' + model_class.__name__ + '.pkl' + filepath = settings.ML_MODEL_DIR + "/" + model_class.__name__ + ".pkl" return filepath def train_model(self): @@ -348,7 +383,7 @@ def train_model(self): self.model.test() filepath = self.get_model_filepath(self.model.__class__) self.model.save_to_file(filepath) - print('Trained model saved to %s' % filepath) + print("Trained model saved to %s" % filepath) return @staticmethod @@ -361,7 +396,9 @@ def get_all_model_metadata(): model_metadata = ModelManager.get_model_metadata(model_key) all_model_metadata.append(model_metadata) - all_model_metadata = sorted(all_model_metadata, key=lambda i: i['average_f1_score'], reverse=True) + all_model_metadata = sorted( + all_model_metadata, key=lambda i: i["average_f1_score"], reverse=True + ) return all_model_metadata @@ -373,34 +410,38 @@ def get_model_metadata(model_key): mm = ModelManager(model_key) model_name = mm.model.__class__.__name__ if mm.model.last_trained is None: - last_trained = 'Never trained' + last_trained = "Never trained" trained_techniques_count = 0 else: - last_trained = mm.model.last_trained.strftime('%m/%d/%Y %H:%M:%S UTC') + last_trained = mm.model.last_trained.strftime("%m/%d/%Y %H:%M:%S UTC") trained_techniques_count = len(mm.model.detailed_f1_score) average_f1_score = round((mm.model.average_f1_score or 0.0) * 100, 2) stored_scores = mm.model.detailed_f1_score or [] attack_ids = set([score[0] for score in stored_scores]) - attack_techniques = db_models.AttackObject.objects.filter(attack_id__in=attack_ids) + attack_techniques = db_models.AttackObject.objects.filter( + attack_id__in=attack_ids + ) detailed_f1_score = [] for score in stored_scores: score_id = score[0] score_value = round(score[1] * 100, 2) attack_technique = attack_techniques.get(attack_id=score_id) - detailed_f1_score.append({ - 'technique': score_id, - 'technique_name': attack_technique.name, - 'attack_url': attack_technique.attack_url, - 'score': score_value - }) + detailed_f1_score.append( + { + "technique": score_id, + "technique_name": attack_technique.name, + "attack_url": attack_technique.attack_url, + "score": score_value, + } + ) model_metadata = { - 'model_key': model_key, - 'name': model_name, - 'last_trained': last_trained, - 'trained_techniques_count': trained_techniques_count, - 'average_f1_score': average_f1_score, - 'detailed_f1_score': detailed_f1_score, + "model_key": model_key, + "name": model_name, + "last_trained": last_trained, + "trained_techniques_count": trained_techniques_count, + "average_f1_score": average_f1_score, + "detailed_f1_score": detailed_f1_score, } return model_metadata diff --git a/src/tram/tram/models.py b/src/tram/tram/models.py index f1b2a7fd4c..60367ad9a7 100644 --- a/src/tram/tram/models.py +++ b/src/tram/tram/models.py @@ -9,22 +9,22 @@ from django.dispatch.dispatcher import receiver DISPOSITION_CHOICES = ( - ('accept', 'Accepted'), - ('reject', 'Rejected'), - (None, 'No Disposition'), + ("accept", "Accepted"), + ("reject", "Rejected"), + (None, "No Disposition"), ) JOB_STATUS_CHOICES = ( - ('queued', 'Queued'), - ('error', 'Error'), + ("queued", "Queued"), + ("error", "Error"), ) SENTENCE_PREVIEW_CHARS = 40 class AttackObject(models.Model): - """Attack Techniques - """ + """Attack Techniques""" + name = models.CharField(max_length=200) stix_id = models.CharField(max_length=128, unique=True) stix_type = models.CharField(max_length=128) @@ -35,7 +35,7 @@ class AttackObject(models.Model): created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) - sentences = models.ManyToManyField('Sentence', through='Mapping') + sentences = models.ManyToManyField("Sentence", through="Mapping") @classmethod def get_sentence_counts(cls, accept_threshold=0): @@ -44,20 +44,28 @@ def get_sentence_counts(cls, accept_threshold=0): return: The list of AttackTechnique objects, annotated with how many training sentences have been accepted, pending, and there are in total. """ - sentence_counts = cls.objects.annotate( - accepted_sentences=Count('sentences', filter=Q(sentences__disposition='accept')), - pending_sentences=Count('sentences', filter=Q(sentences__disposition=None)), - total_sentences=Count('sentences') - ).order_by('-accepted_sentences', 'attack_id').filter(accepted_sentences__gte=accept_threshold) + sentence_counts = ( + cls.objects.annotate( + accepted_sentences=Count( + "sentences", filter=Q(sentences__disposition="accept") + ), + pending_sentences=Count( + "sentences", filter=Q(sentences__disposition=None) + ), + total_sentences=Count("sentences"), + ) + .order_by("-accepted_sentences", "attack_id") + .filter(accepted_sentences__gte=accept_threshold) + ) return sentence_counts def __str__(self): - return '(%s) %s' % (self.attack_id, self.name) + return "(%s) %s" % (self.attack_id, self.name) class Document(models.Model): - """Store all documents that can be analyzed to create reports - """ + """Store all documents that can be analyzed to create reports""" + docfile = models.FileField() created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -68,11 +76,13 @@ def __str__(self): class DocumentProcessingJob(models.Model): - """Queue of document processing jobs - """ + """Queue of document processing jobs""" + document = models.ForeignKey(Document, on_delete=models.CASCADE) - status = models.CharField(max_length=255, default='queued', choices=JOB_STATUS_CHOICES) - message = models.CharField(max_length=16384, default='') + status = models.CharField( + max_length=255, default="queued", choices=JOB_STATUS_CHOICES + ) + message = models.CharField(max_length=16384, default="") created_by = models.ForeignKey(User, null=True, on_delete=models.SET_NULL) created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -87,12 +97,12 @@ def create_from_file(cls, f): return dpj def __str__(self): - return 'Process %s' % self.document.docfile.name + return "Process %s" % self.document.docfile.name class Report(models.Model): - """Store reports - """ + """Store reports""" + name = models.CharField(max_length=200) document = models.ForeignKey(Document, null=True, on_delete=models.CASCADE) text = models.TextField() @@ -106,8 +116,8 @@ def __str__(self): class Indicator(models.Model): - """Indicators extracted from a document for a report - """ + """Indicators extracted from a document for a report""" + report = models.ForeignKey(Report, on_delete=models.CASCADE) indicator_type = models.CharField(max_length=200) value = models.CharField(max_length=200) @@ -115,31 +125,37 @@ class Indicator(models.Model): updated_on = models.DateTimeField(auto_now=True) def __str__(self): - return '%s: %s' % (self.indicator_type, self.value) + return "%s: %s" % (self.indicator_type, self.value) class Sentence(models.Model): text = models.TextField() document = models.ForeignKey(Document, null=True, on_delete=models.CASCADE) - order = models.IntegerField(default=1000) # Sentences with lower numbers are displayed first + order = models.IntegerField( + default=1000 + ) # Sentences with lower numbers are displayed first report = models.ForeignKey(Report, on_delete=models.CASCADE) - disposition = models.CharField(max_length=200, default=None, null=True, blank=True, choices=DISPOSITION_CHOICES) + disposition = models.CharField( + max_length=200, default=None, null=True, blank=True, choices=DISPOSITION_CHOICES + ) created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) def __str__(self): - append = '' + append = "" if len(self.text) > SENTENCE_PREVIEW_CHARS: - append = '...' + append = "..." return self.text[:SENTENCE_PREVIEW_CHARS] + append class Mapping(models.Model): - """Maps sentences to Attack TTPs - """ + """Maps sentences to Attack TTPs""" + report = models.ForeignKey(Report, on_delete=models.CASCADE) sentence = models.ForeignKey(Sentence, on_delete=models.CASCADE) - attack_object = models.ForeignKey(AttackObject, on_delete=models.CASCADE, blank=True, null=True) + attack_object = models.ForeignKey( + AttackObject, on_delete=models.CASCADE, blank=True, null=True + ) confidence = models.FloatField() created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -150,15 +166,16 @@ def __str__(self): @classmethod def get_accepted_mappings(cls): # Get Attack techniques that have the required amount of positive examples - attack_objects = AttackObject.get_sentence_counts(accept_threshold=config.ML_ACCEPT_THRESHOLD) + attack_objects = AttackObject.get_sentence_counts( + accept_threshold=config.ML_ACCEPT_THRESHOLD + ) # Get mappings for the attack techniques above threshold mappings = Mapping.objects.filter(attack_object__in=attack_objects) return mappings class MLSettings(models.Model): - """Settings for Machine Learning models - """ + """Settings for Machine Learning models""" def _delete_file(path): diff --git a/src/tram/tram/serializers.py b/src/tram/tram/serializers.py index 60475ee857..f17e73bdb5 100644 --- a/src/tram/tram/serializers.py +++ b/src/tram/tram/serializers.py @@ -7,35 +7,48 @@ class AttackObjectSerializer(serializers.ModelSerializer): class Meta: model = db_models.AttackObject - fields = ['id', 'attack_id', 'name'] + fields = ["id", "attack_id", "name"] class DocumentProcessingJobSerializer(serializers.ModelSerializer): """Needs to be kept in sync with ReportSerializer for display purposes""" + name = serializers.SerializerMethodField() byline = serializers.SerializerMethodField() status = serializers.SerializerMethodField() class Meta: model = db_models.DocumentProcessingJob - fields = ['id', 'name', 'byline', 'status', 'message', 'created_by', 'created_on', 'updated_on'] - order = ['-created_on'] + fields = [ + "id", + "name", + "byline", + "status", + "message", + "created_by", + "created_on", + "updated_on", + ] + order = ["-created_on"] def get_name(self, obj): name = obj.document.docfile.name return name def get_byline(self, obj): - byline = '%s on %s' % (obj.created_by, obj.created_on.strftime('%Y-%M-%d %H:%M:%S UTC')) + byline = "%s on %s" % ( + obj.created_by, + obj.created_on.strftime("%Y-%M-%d %H:%M:%S UTC"), + ) return byline def get_status(self, obj): - if obj.status == 'queued': - return 'Queued' - elif obj.status == 'error': - return 'Error' + if obj.status == "queued": + return "Queued" + elif obj.status == "error": + return "Error" else: - return 'Unknown' + return "Unknown" class MappingSerializer(serializers.ModelSerializer): @@ -45,7 +58,7 @@ class MappingSerializer(serializers.ModelSerializer): class Meta: model = db_models.Mapping - fields = ['id', 'attack_id', 'name', 'confidence'] + fields = ["id", "attack_id", "name", "confidence"] def get_attack_id(self, obj): return obj.attack_object.attack_id @@ -63,24 +76,26 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add necessary fields - attack_object = db_models.AttackObject.objects.get(attack_id=data['attack_id']) - sentence = db_models.Sentence.objects.get(id=data['sentence']) - report = db_models.Report.objects.get(id=data['report']) - - internal_value.update({ - 'report': report, - 'sentence': sentence, - 'attack_object': attack_object, - }) + attack_object = db_models.AttackObject.objects.get(attack_id=data["attack_id"]) + sentence = db_models.Sentence.objects.get(id=data["sentence"]) + report = db_models.Report.objects.get(id=data["report"]) + + internal_value.update( + { + "report": report, + "sentence": sentence, + "attack_object": attack_object, + } + ) return internal_value def create(self, validated_data): mapping = db_models.Mapping.objects.create( - report=validated_data['report'], - sentence=validated_data['sentence'], - attack_object=validated_data['attack_object'], - confidence=validated_data['confidence'] + report=validated_data["report"], + sentence=validated_data["sentence"], + attack_object=validated_data["attack_object"], + confidence=validated_data["confidence"], ) return mapping @@ -95,12 +110,26 @@ class ReportSerializer(serializers.ModelSerializer): class Meta: model = db_models.Report - fields = ['id', 'name', 'byline', 'accepted_sentences', 'reviewing_sentences', 'total_sentences', - 'text', 'ml_model', 'created_by', 'created_on', 'updated_on', 'status'] - order = ['-created_on'] + fields = [ + "id", + "name", + "byline", + "accepted_sentences", + "reviewing_sentences", + "total_sentences", + "text", + "ml_model", + "created_by", + "created_on", + "updated_on", + "status", + ] + order = ["-created_on"] def get_accepted_sentences(self, obj): - count = db_models.Sentence.objects.filter(disposition='accept', report=obj).count() + count = db_models.Sentence.objects.filter( + disposition="accept", report=obj + ).count() return count def get_reviewing_sentences(self, obj): @@ -112,26 +141,32 @@ def get_total_sentences(self, obj): return count def get_byline(self, obj): - byline = '%s on %s' % (obj.created_by, obj.created_on.strftime('%Y-%m-%d %H:%M:%S UTC')) + byline = "%s on %s" % ( + obj.created_by, + obj.created_on.strftime("%Y-%m-%d %H:%M:%S UTC"), + ) return byline def get_status(self, obj): reviewing_sentences = self.get_reviewing_sentences(obj) - status = 'Reviewing' + status = "Reviewing" if reviewing_sentences == 0: - status = 'Accepted' + status = "Accepted" return status class ReportExportSerializer(ReportSerializer): """Defines the export format for reports. Defined separately from ReportSerializer so that: - 1. ReportSerializer and ReportExportSerializer can evolve independently - 2. The export is larger than what the REST API needs + 1. ReportSerializer and ReportExportSerializer can evolve independently + 2. The export is larger than what the REST API needs """ + sentences = serializers.SerializerMethodField() class Meta(ReportSerializer.Meta): - fields = ReportSerializer.Meta.fields + ['sentences', ] + fields = ReportSerializer.Meta.fields + [ + "sentences", + ] def get_sentences(self, obj): sentences = db_models.Sentence.objects.filter(report=obj) @@ -148,28 +183,30 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add sentences - sentence_serializers = [SentenceSerializer(data=sentence) for sentence in data.get('sentences', [])] + sentence_serializers = [ + SentenceSerializer(data=sentence) for sentence in data.get("sentences", []) + ] - internal_value.update({'sentences': sentence_serializers}) + internal_value.update({"sentences": sentence_serializers}) return internal_value def create(self, validated_data): with transaction.atomic(): report = db_models.Report.objects.create( - name=validated_data['name'], - document=None, - text=validated_data['text'], - ml_model=validated_data['ml_model'], - created_by=None, # TODO: Get user from session - ) - - for sentence in validated_data['sentences']: + name=validated_data["name"], + document=None, + text=validated_data["text"], + ml_model=validated_data["ml_model"], + created_by=None, # TODO: Get user from session + ) + + for sentence in validated_data["sentences"]: if sentence.is_valid(): - sentence.validated_data['report'] = report + sentence.validated_data["report"] = report sentence.save() else: # TODO: Handle this case better - raise Exception('Sentence validation needs to be handled better') + raise Exception("Sentence validation needs to be handled better") return report @@ -182,7 +219,7 @@ class SentenceSerializer(serializers.ModelSerializer): class Meta: model = db_models.Sentence - fields = ['id', 'text', 'order', 'disposition', 'mappings'] + fields = ["id", "text", "order", "disposition", "mappings"] def get_mappings(self, obj): mappings = db_models.Mapping.objects.filter(sentence=obj) @@ -199,27 +236,29 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add mappings - mapping_serializers = [MappingSerializer(data=mapping) for mapping in data.get('mappings', [])] + mapping_serializers = [ + MappingSerializer(data=mapping) for mapping in data.get("mappings", []) + ] - internal_value.update({'mappings': mapping_serializers}) + internal_value.update({"mappings": mapping_serializers}) return internal_value def create(self, validated_data): with transaction.atomic(): sentence = db_models.Sentence.objects.create( - text=validated_data['text'], + text=validated_data["text"], document=None, - report=validated_data['report'], - disposition=validated_data['disposition'] + report=validated_data["report"], + disposition=validated_data["disposition"], ) - for mapping in validated_data.get('mappings', []): - mapping.initial_data['sentence'] = sentence.id - mapping.initial_data['report'] = validated_data['report'].id + for mapping in validated_data.get("mappings", []): + mapping.initial_data["sentence"] = sentence.id + mapping.initial_data["report"] = validated_data["report"].id if mapping.is_valid(): mapping.save() else: # TODO: Handle this case better - raise Exception('Mapping validation needs to be handled better') + raise Exception("Mapping validation needs to be handled better") return sentence diff --git a/src/tram/tram/templates/analyze.html b/src/tram/tram/templates/analyze.html index cba23cebd2..637c266791 100644 --- a/src/tram/tram/templates/analyze.html +++ b/src/tram/tram/templates/analyze.html @@ -53,4 +53,4 @@ -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/src/tram/tram/templates/index.html b/src/tram/tram/templates/index.html index 02ebfd0aeb..98e3dc3f9f 100644 --- a/src/tram/tram/templates/index.html +++ b/src/tram/tram/templates/index.html @@ -86,4 +86,4 @@

Reports

{% endfor %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/src/tram/tram/templates/ml_home.html b/src/tram/tram/templates/ml_home.html index c66567eac0..ae7dedce0c 100644 --- a/src/tram/tram/templates/ml_home.html +++ b/src/tram/tram/templates/ml_home.html @@ -79,4 +79,4 @@

ML Settings Login Help

- \ No newline at end of file + diff --git a/src/tram/tram/templates/technique_sentences.html b/src/tram/tram/templates/technique_sentences.html index 1c40548309..858416e0a8 100644 --- a/src/tram/tram/templates/technique_sentences.html +++ b/src/tram/tram/templates/technique_sentences.html @@ -56,4 +56,4 @@ -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/src/tram/tram/templates/tram_documentation.html b/src/tram/tram/templates/tram_documentation.html index 5a1e0d4c3a..5f4c5b9b05 100644 --- a/src/tram/tram/templates/tram_documentation.html +++ b/src/tram/tram/templates/tram_documentation.html @@ -117,4 +117,4 @@

Processing Pipeline

-{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/src/tram/tram/urls.py b/src/tram/tram/urls.py index e162080e61..837ce08b50 100644 --- a/src/tram/tram/urls.py +++ b/src/tram/tram/urls.py @@ -17,32 +17,31 @@ from django.conf.urls.static import static from django.contrib import admin from django.contrib.auth import views as auth_views -from django.urls import path, include +from django.urls import include, path from django.views.generic.base import TemplateView from rest_framework.routers import DefaultRouter from tram import views - router = DefaultRouter() -router.register(r'attack', views.AttackObjectViewSet) -router.register(r'jobs', views.DocumentProcessingJobViewSet) -router.register(r'mappings', views.MappingViewSet) -router.register(r'reports', views.ReportViewSet) -router.register(r'report-export', views.ReportExportViewSet) -router.register(r'sentences', views.SentenceViewSet) +router.register(r"attack", views.AttackObjectViewSet) +router.register(r"jobs", views.DocumentProcessingJobViewSet) +router.register(r"mappings", views.MappingViewSet) +router.register(r"reports", views.ReportViewSet) +router.register(r"report-export", views.ReportExportViewSet) +router.register(r"sentences", views.SentenceViewSet) urlpatterns = [ - path('', views.index), - path('analyze//', views.analyze), - path('api/', include(router.urls)), - path('docs/', TemplateView.as_view(template_name='tram_documentation.html')), - path('login/', auth_views.LoginView.as_view()), - path('logout/', auth_views.LogoutView.as_view()), - path('upload/', views.upload), - path('admin/', admin.site.urls), - path('ml/', views.ml_home), - path('ml/techniques/', views.ml_technique_sentences), - path('ml/models/', views.ml_model_detail), + path("", views.index), + path("analyze//", views.analyze), + path("api/", include(router.urls)), + path("docs/", TemplateView.as_view(template_name="tram_documentation.html")), + path("login/", auth_views.LoginView.as_view()), + path("logout/", auth_views.LogoutView.as_view()), + path("upload/", views.upload), + path("admin/", admin.site.urls), + path("ml/", views.ml_home), + path("ml/techniques/", views.ml_technique_sentences), + path("ml/models/", views.ml_model_detail), ] + static(settings.STATIC_URL, document_root=settings.STATIC_ROOT) diff --git a/src/tram/tram/views.py b/src/tram/tram/views.py index bd0b418071..ee35e50e93 100644 --- a/src/tram/tram/views.py +++ b/src/tram/tram/views.py @@ -2,7 +2,7 @@ from constance import config from django.contrib.auth.decorators import login_required -from django.http import HttpResponse, HttpResponseBadRequest, Http404 +from django.http import Http404, HttpResponse, HttpResponseBadRequest from django.shortcuts import render from django.utils.text import slugify from rest_framework import viewsets @@ -28,7 +28,7 @@ class MappingViewSet(viewsets.ModelViewSet): def get_queryset(self): queryset = MappingViewSet.queryset - sentence_id = self.request.query_params.get('sentence-id', None) + sentence_id = self.request.query_params.get("sentence-id", None) if sentence_id: queryset = queryset.filter(sentence__id=sentence_id) @@ -46,8 +46,8 @@ class ReportExportViewSet(viewsets.ModelViewSet): def retrieve(self, request, *args, **kwargs): response = super().retrieve(request, *args, **kwargs) - filename = slugify(self.get_object().name) + '.json' - response['Content-Disposition'] = 'attachment; filename="%s"' % filename + filename = slugify(self.get_object().name) + ".json" + response["Content-Disposition"] = 'attachment; filename="%s"' % filename return response @@ -57,13 +57,15 @@ class SentenceViewSet(viewsets.ModelViewSet): def get_queryset(self): queryset = SentenceViewSet.queryset - report_id = self.request.query_params.get('report-id', None) + report_id = self.request.query_params.get("report-id", None) if report_id: queryset = queryset.filter(report__id=report_id) - attack_id = self.request.query_params.get('attack-id', None) + attack_id = self.request.query_params.get("attack-id", None) if attack_id: - sentences = Mapping.objects.filter(attack_technique__attack_id=attack_id).values('sentence') + sentences = Mapping.objects.filter( + attack_technique__attack_id=attack_id + ).values("sentence") queryset = queryset.filter(id__in=sentences) return queryset @@ -78,28 +80,28 @@ def index(request): report_serializer = serializers.ReportSerializer(reports, many=True) context = { - 'job_queue': job_serializer.data, - 'reports': report_serializer.data, + "job_queue": job_serializer.data, + "reports": report_serializer.data, } - return render(request, 'index.html', context=context) + return render(request, "index.html", context=context) @login_required def upload(request): - """Places a file into ml-pipeline for analysis - """ - if request.method != 'POST': - return HttpResponse('Request method must be POST', status=405) - - file_content_type = request.FILES['file'].content_type - if file_content_type in ('application/pdf', # .pdf files - 'text/html', # .html files - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', # .docx files - ): - DocumentProcessingJob.create_from_file(request.FILES['file']) - elif file_content_type in ('application/json', ): # .json files - json_data = json.loads(request.FILES['file'].read()) + """Places a file into ml-pipeline for analysis""" + if request.method != "POST": + return HttpResponse("Request method must be POST", status=405) + + file_content_type = request.FILES["file"].content_type + if file_content_type in ( + "application/pdf", # .pdf files + "text/html", # .html files + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", # .docx files + ): + DocumentProcessingJob.create_from_file(request.FILES["file"]) + elif file_content_type in ("application/json",): # .json files + json_data = json.loads(request.FILES["file"].read()) res = serializers.ReportExportSerializer(data=json_data) if res.is_valid(): @@ -107,9 +109,9 @@ def upload(request): else: return HttpResponseBadRequest(res.errors) else: - return HttpResponseBadRequest('Unsupported file type') + return HttpResponseBadRequest("Unsupported file type") - return HttpResponse('File saved for processing', status=200) + return HttpResponse("File saved for processing", status=200) @login_required @@ -118,19 +120,19 @@ def ml_home(request): model_metadata = base.ModelManager.get_all_model_metadata() context = { - 'techniques': techniques, - 'ML_ACCEPT_THRESHOLD': config.ML_ACCEPT_THRESHOLD, - 'ML_CONFIDENCE_THRESHOLD': config.ML_CONFIDENCE_THRESHOLD, - 'models': model_metadata - } + "techniques": techniques, + "ML_ACCEPT_THRESHOLD": config.ML_ACCEPT_THRESHOLD, + "ML_CONFIDENCE_THRESHOLD": config.ML_CONFIDENCE_THRESHOLD, + "models": model_metadata, + } - return render(request, 'ml_home.html', context) + return render(request, "ml_home.html", context) @login_required def ml_technique_sentences(request, attack_id): - context = {'attack_id': attack_id} - return render(request, 'technique_sentences.html', context) + context = {"attack_id": attack_id} + return render(request, "technique_sentences.html", context) @login_required @@ -138,20 +140,20 @@ def ml_model_detail(request, model_key): try: model_metadata = base.ModelManager.get_model_metadata(model_key) except ValueError: - raise Http404('Model does not exists') - context = {'model': model_metadata} - return render(request, 'model_detail.html', context) + raise Http404("Model does not exists") + context = {"model": model_metadata} + return render(request, "model_detail.html", context) @login_required def analyze(request, pk): report = Report.objects.get(id=pk) - techniques = AttackObject.objects.all().order_by('attack_id') + techniques = AttackObject.objects.all().order_by("attack_id") tecniques_serializer = serializers.AttackObjectSerializer(techniques, many=True) context = { - 'report_id': report.id, - 'report_name': report.name, - 'attack_techniques': tecniques_serializer.data, - } - return render(request, 'analyze.html', context) + "report_id": report.id, + "report_name": report.name, + "attack_techniques": tecniques_serializer.data, + } + return render(request, "analyze.html", context) diff --git a/src/tram/tram/wsgi.py b/src/tram/tram/wsgi.py index ae0b2192b6..1c31827d24 100644 --- a/src/tram/tram/wsgi.py +++ b/src/tram/tram/wsgi.py @@ -11,6 +11,6 @@ from django.core.wsgi import get_wsgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") application = get_wsgi_application() diff --git a/tests/conftest.py b/tests/conftest.py index 2366826411..fb7a46ce56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,19 @@ import glob -from django.core.files.base import File import pytest +from django.core.files.base import File -from tram.management.commands import attackdata, pipeline from tram import models +from tram.management.commands import attackdata, pipeline @pytest.fixture(scope="session", autouse=True) def verify_test_data_directory_is_empty(request): - files = glob.glob('data/media/tests/data/*') + files = glob.glob("data/media/tests/data/*") if len(files) > 0: - raise ValueError('data/media/tests/data/ is not empty! Remove contents to run tests.') + raise ValueError( + "data/media/tests/data/ is not empty! Remove contents to run tests." + ) @pytest.fixture @@ -23,7 +25,7 @@ def load_attack_data(): @pytest.fixture def load_small_training_data(): options = { - 'file': 'data/training/bootstrap-training-data-small.json', + "file": "data/training/bootstrap-training-data-small.json", } command = pipeline.Command() command.handle(subcommand=pipeline.LOAD_TRAINING_DATA, **options) @@ -31,7 +33,7 @@ def load_small_training_data(): @pytest.fixture def document(): - with open('tests/data/simple-test.docx', 'rb') as f: + with open("tests/data/simple-test.docx", "rb") as f: d = models.Document(docfile=File(f)) d.save() yield d @@ -41,13 +43,13 @@ def document(): @pytest.fixture def attack_object(): at = models.AttackObject( - name='Use multiple DNS infrastructures', - stix_id='attack-pattern--616238cb-990b-4c71-8f50-d8b10ed8ce6b', - stix_type='attack-pattern', - attack_id='T1327', - attack_type='technique', - attack_url='https://attack.mitre.org/techniques/T1327', - matrix='mitre-pre-attack', + name="Use multiple DNS infrastructures", + stix_id="attack-pattern--616238cb-990b-4c71-8f50-d8b10ed8ce6b", + stix_type="attack-pattern", + attack_id="T1327", + attack_type="technique", + attack_url="https://attack.mitre.org/techniques/T1327", + matrix="mitre-pre-attack", ) at.save() yield at @@ -57,9 +59,9 @@ def attack_object(): @pytest.fixture def report(document): rpt = models.Report( - name='Test report name', + name="Test report name", document=document, - text='test-document-text', + text="test-document-text", ) rpt.save() yield rpt @@ -77,9 +79,7 @@ def document_processing_job(document): @pytest.fixture def indicator(report): ind = models.Indicator( - report=report, - indicator_type='MD5', - value='54b0c58c7ce9f2a8b551351102ee0938' + report=report, indicator_type="MD5", value="54b0c58c7ce9f2a8b551351102ee0938" ) ind.save() yield ind @@ -89,7 +89,7 @@ def indicator(report): @pytest.fixture def sentence(report): s = models.Sentence( - text='test-text', + text="test-text", document=report.document, order=0, report=report, @@ -103,13 +103,13 @@ def sentence(report): @pytest.fixture def simple_training_data(report, load_attack_data): s = models.Sentence( - text='test-text', + text="test-text", report=report, document=report.document, - disposition='accept', + disposition="accept", ) s.save() - at = models.AttackTechnique.objects.get(attack_id='T1327') + at = models.AttackTechnique.objects.get(attack_id="T1327") m = models.Mapping( report=report, sentence=s, @@ -125,7 +125,7 @@ def simple_training_data(report, load_attack_data): @pytest.fixture def long_sentence(report): s = models.Sentence( - text='this sentence is long and should trigger the overflow', + text="this sentence is long and should trigger the overflow", document=report.document, order=0, report=report, diff --git a/tests/tram/test_base.py b/tests/tram/test_base.py index 1bc28a749e..48a0e5fb70 100644 --- a/tests/tram/test_base.py +++ b/tests/tram/test_base.py @@ -1,9 +1,9 @@ -from django.core.files import File -from constance import config import pytest +from constance import config +from django.core.files import File -from tram.ml import base import tram.models as db_models +from tram.ml import base @pytest.fixture @@ -14,7 +14,7 @@ def dummy_model(): class TestSentence: def test_sentence_stores_no_mapping(self): # Arrange - text = 'this is text' + text = "this is text" order = 0 mappings = None @@ -31,8 +31,8 @@ class TestMapping: def test_mapping_repr_is_correct(self): # Arrange confidence = 95.342000 - attack_id = 'T1327' - expected = 'Confidence=95.342000; Attack ID=T1327' + attack_id = "T1327" + expected = "Confidence=95.342000; Attack ID=T1327" # Act m = base.Mapping(confidence, attack_id) @@ -43,18 +43,12 @@ def test_mapping_repr_is_correct(self): class TestReport: def test_report_stores_properties(self): # Arrange - name = 'Test report' - text = 'Test report text' - sentences = [ - base.Sentence('test sentence text', 0, None) - ] + name = "Test report" + text = "Test report text" + sentences = [base.Sentence("test sentence text", 0, None)] # Act - rpt = base.Report( - name=name, - text=text, - sentences=sentences - ) + rpt = base.Report(name=name, text=text, sentences=sentences) # Assert assert rpt.name == name @@ -65,6 +59,7 @@ def test_report_stores_properties(self): @pytest.mark.django_db class TestModelWithoutAttackData: """Tests ml.base.Model via DummyModel, without the load_attack_data fixture""" + def test_get_attack_techniques_raises_if_not_initialized(self, dummy_model): # Act / Assert with pytest.raises(ValueError): @@ -72,7 +67,7 @@ def test_get_attack_techniques_raises_if_not_initialized(self, dummy_model): @pytest.mark.django_db -@pytest.mark.usefixtures('load_attack_data') +@pytest.mark.usefixtures("load_attack_data") class TestSkLearnModel: """Tests ml.base.SKLearnModel via DummyModel""" @@ -81,9 +76,14 @@ def test__sentence_tokenize_works_for_paragraph(self, dummy_model): paragraph = """Hello. My name is test. I write sentences. Tokenize, tokenize, tokenize! When will this entralling text stop, praytell? Nobody knows; the author can't stop. """ - expected = ['Hello.', 'My name is test.', 'I write sentences.', - 'Tokenize, tokenize, tokenize!', 'When will this entralling text stop, praytell?', - 'Nobody knows; the author can\'t stop.'] + expected = [ + "Hello.", + "My name is test.", + "I write sentences.", + "Tokenize, tokenize, tokenize!", + "When will this entralling text stop, praytell?", + "Nobody knows; the author can't stop.", + ] # Act sentences = dummy_model._sentence_tokenize(paragraph) @@ -91,14 +91,23 @@ def test__sentence_tokenize_works_for_paragraph(self, dummy_model): # Assert assert expected == sentences - @pytest.mark.parametrize("filepath,expected", [ - ('tests/data/AA20-302A.pdf', 'GLEMALT With a Ransomware Chaser'), - ('tests/data/AA20-302A.docx', 'Page 22 of 22 | Product ID: AA20-302A TLP:WHITE'), - ('tests/data/AA20-302A.html', 'CISA is part of the Department of Homeland Security'), - ]) + @pytest.mark.parametrize( + "filepath,expected", + [ + ("tests/data/AA20-302A.pdf", "GLEMALT With a Ransomware Chaser"), + ( + "tests/data/AA20-302A.docx", + "Page 22 of 22 | Product ID: AA20-302A TLP:WHITE", + ), + ( + "tests/data/AA20-302A.html", + "CISA is part of the Department of Homeland Security", + ), + ], + ) def test__extract_text_succeeds(self, dummy_model, filepath, expected): # Arrange - with open(filepath, 'rb') as f: + with open(filepath, "rb") as f: doc = db_models.Document(docfile=File(f)) doc.save() @@ -113,7 +122,7 @@ def test__extract_text_succeeds(self, dummy_model, filepath, expected): def test__extract_text_unknown_extension_raises_value_error(self, dummy_model): # Arrange - with open('tests/data/unknown-extension.fizzbuzz', 'rb') as f: + with open("tests/data/unknown-extension.fizzbuzz", "rb") as f: doc = db_models.Document(docfile=File(f)) doc.save() @@ -126,8 +135,8 @@ def test__extract_text_unknown_extension_raises_value_error(self, dummy_model): def test_get_report_name_succeeds(self, dummy_model): # Arrange - expected = 'Report for AA20-302A' - with open('tests/data/AA20-302A.docx', 'rb') as f: + expected = "Report for AA20-302A" + with open("tests/data/AA20-302A.docx", "rb") as f: doc = db_models.Document(docfile=File(f)) doc.save() job = db_models.DocumentProcessingJob(document=doc) @@ -148,13 +157,13 @@ def test_get_attack_objects_succeeds_after_initialization(self, dummy_model): objects = dummy_model.get_attack_object_ids() # Assert - assert 'T1327' in objects # Ensures mitre-pre-attack is available - assert 'T1497.003' in objects # Ensures mitre-attack is available - assert 'T1579' in objects # Ensures mitre-mobile-attack is available + assert "T1327" in objects # Ensures mitre-pre-attack is available + assert "T1497.003" in objects # Ensures mitre-attack is available + assert "T1579" in objects # Ensures mitre-mobile-attack is available def test_disk_round_trip_succeeds(self, dummy_model, tmpdir): # Arrange - filepath = (tmpdir + 'dummy_model.pkl').strpath + filepath = (tmpdir + "dummy_model.pkl").strpath # Act dummy_model.get_attack_object_ids() # Change the state of the DummyModel @@ -164,7 +173,9 @@ def test_disk_round_trip_succeeds(self, dummy_model, tmpdir): # Assert assert dummy_model.__class__ == dummy_model_2.__class__ - assert dummy_model.get_attack_object_ids() == dummy_model_2.get_attack_object_ids() + assert ( + dummy_model.get_attack_object_ids() == dummy_model_2.get_attack_object_ids() + ) def test_no_data_get_training_data_succeeds(self, dummy_model): # Act @@ -174,26 +185,28 @@ def test_no_data_get_training_data_succeeds(self, dummy_model): assert len(X) == 0 assert len(y) == 0 - def test_get_training_data_returns_only_accepted_sentences(self, dummy_model, report): + def test_get_training_data_returns_only_accepted_sentences( + self, dummy_model, report + ): # Arrange s1 = db_models.Sentence.objects.create( - text='sentence1', + text="sentence1", order=0, document=report.document, report=report, - disposition=None + disposition=None, ) s2 = db_models.Sentence.objects.create( - text='sentence 2', + text="sentence 2", order=1, document=report.document, report=report, - disposition='accept' + disposition="accept", ) m1 = db_models.Mapping.objects.create( report=report, sentence=s2, - attack_object=db_models.AttackObject.objects.get(attack_id='T1548'), + attack_object=db_models.AttackObject.objects.get(attack_id="T1548"), confidence=100.0, ) config.ML_ACCEPT_THRESHOLD = 0 # Set the threshold to 0 for this test @@ -213,13 +226,14 @@ def test_non_sklearn_pipeline_raises(self): class NonSKLearnPipeline(base.SKLearnModel): def get_model(self): return "This is not an sklearn.pipeline.Pipeline instance" + # Act with pytest.raises(TypeError): NonSKLearnPipeline() @pytest.mark.django_db -@pytest.mark.usefixtures('load_attack_data', 'load_small_training_data') +@pytest.mark.usefixtures("load_attack_data", "load_small_training_data") class TestsThatNeedTrainingData: """ Loading the training data is a large time cost, so this groups tests together that use @@ -234,7 +248,7 @@ class TestsThatNeedTrainingData: def test_modelmanager__init__loads_dummy_model(self): # Act - model_manager = base.ModelManager('dummy') + model_manager = base.ModelManager("dummy") # Assert assert model_manager.model.__class__ == base.DummyModel @@ -242,17 +256,18 @@ def test_modelmanager__init__loads_dummy_model(self): def test_modelmanager__init__raises_value_error_on_unknown_model(self): # Act / Assert with pytest.raises(ValueError): - base.ModelManager('this-should-raise') + base.ModelManager("this-should-raise") def test_modelmanager_train_model_doesnt_raise(self): # Arrange - model_manager = base.ModelManager('dummy') + model_manager = base.ModelManager("dummy") # Act model_manager.train_model() # Assert # TODO: Something meaningful + """ ----- End ModelManager Tests ----- """ @@ -265,7 +280,7 @@ def test_get_mappings_returns_mappings(self): config.ML_CONFIDENCE_THRESHOLD = 0 # Act - mappings = dummy_model.get_mappings('test sentence') + mappings = dummy_model.get_mappings("test sentence") # Assert for mapping in mappings: @@ -273,7 +288,7 @@ def test_get_mappings_returns_mappings(self): def test_process_job_produces_valid_report(self): # Arrange - with open('tests/data/AA20-302A.docx', 'rb') as f: + with open("tests/data/AA20-302A.docx", "rb") as f: doc = db_models.Document(docfile=File(f)) doc.save() job = db_models.DocumentProcessingJob(document=doc) @@ -303,23 +318,24 @@ def test_process_job_handles_image_based_pdf(self): is that the job is logged as "status: error". """ # Arrange - image_pdf = 'tests/data/GroupIB_Big_Airline_Heist_APT41.pdf' - with open(image_pdf, 'rb') as f: + image_pdf = "tests/data/GroupIB_Big_Airline_Heist_APT41.pdf" + with open(image_pdf, "rb") as f: processing_job = db_models.DocumentProcessingJob.create_from_file(File(f)) job_id = processing_job.id - model_manager = base.ModelManager('dummy') + model_manager = base.ModelManager("dummy") # Act model_manager.run_model() job_result = db_models.DocumentProcessingJob.objects.get(id=job_id) # Assert - assert job_result.status == 'error' + assert job_result.status == "error" assert len(job_result.message) > 0 """ ----- Begin DummyModel Tests ----- """ + def test_dummymodel_train_and_test_passes(self, dummy_model): # Act dummy_model.train() # Has no effect diff --git a/tests/tram/test_commands.py b/tests/tram/test_commands.py index 2cb19aad4a..21cf344c31 100644 --- a/tests/tram/test_commands.py +++ b/tests/tram/test_commands.py @@ -1,8 +1,8 @@ +import pytest from django.core.management import call_command from django.core.management.base import CommandError -import pytest -from tram.management.commands import pipeline, attackdata +from tram.management.commands import attackdata, pipeline from tram.ml import base from tram.models import AttackObject, Sentence @@ -10,25 +10,28 @@ class TestPipeline: def test_add_calls_create_from_file(self, mocker): # Arrange - mock_create = mocker.patch('tram.models.DocumentProcessingJob.create_from_file') - filepath = 'tests/data/simple-test.docx' + mock_create = mocker.patch("tram.models.DocumentProcessingJob.create_from_file") + filepath = "tests/data/simple-test.docx" # Act - call_command('pipeline', pipeline.ADD, file=filepath) + call_command("pipeline", pipeline.ADD, file=filepath) # Assert assert mock_create.called_once() - @pytest.mark.parametrize("subcommand,to_mock", [ - (pipeline.RUN, 'run_model'), - (pipeline.TRAIN, 'train_model'), - ]) + @pytest.mark.parametrize( + "subcommand,to_mock", + [ + (pipeline.RUN, "run_model"), + (pipeline.TRAIN, "train_model"), + ], + ) def test_subcommand_calls_correct_function(self, mocker, subcommand, to_mock): # Arrange mocked_func = mocker.patch.object(base.ModelManager, to_mock, return_value=None) # Act - call_command('pipeline', subcommand, model='dummy') + call_command("pipeline", subcommand, model="dummy") # Assert assert mocked_func.called_once() @@ -36,20 +39,22 @@ def test_subcommand_calls_correct_function(self, mocker, subcommand, to_mock): def test_incorrect_subcommand_raises_commanderror(self): # Act / Assert with pytest.raises(CommandError): - call_command('pipeline', 'incorrect-subcommand') + call_command("pipeline", "incorrect-subcommand") @pytest.mark.django_db def test_load_training_data_succeeds(self, load_attack_data): # Act - call_command('pipeline', pipeline.LOAD_TRAINING_DATA) + call_command("pipeline", pipeline.LOAD_TRAINING_DATA) # Assert - assert Sentence.objects.count() == 12588 # Count of sentences data/training/bootstrap-training-data.json + assert ( + Sentence.objects.count() == 12588 + ) # Count of sentences data/training/bootstrap-training-data.json @pytest.mark.django_db def test_run_succeeds(self, load_attack_data): # Act - call_command('pipeline', pipeline.RUN) + call_command("pipeline", pipeline.RUN) # Assert pass @@ -62,7 +67,7 @@ def test_load_succeeds(self): expected_object_count = 1461 # Act - call_command('attackdata', attackdata.LOAD) + call_command("attackdata", attackdata.LOAD) object_count = AttackObject.objects.all().count() # Assert @@ -73,8 +78,8 @@ def test_clear_succeeds(self): expected_techniques = 0 # Act - call_command('attackdata', attackdata.LOAD) - call_command('attackdata', attackdata.CLEAR) + call_command("attackdata", attackdata.LOAD) + call_command("attackdata", attackdata.CLEAR) techniques = AttackObject.objects.all().count() # Assert @@ -83,4 +88,4 @@ def test_clear_succeeds(self): def test_incorrect_subcommand_raises_commanderror(self): # Act / Assert with pytest.raises(CommandError): - call_command('attackdata', 'incorrect-subcommand') + call_command("attackdata", "incorrect-subcommand") diff --git a/tests/tram/test_models.py b/tests/tram/test_models.py index 629748718e..c554d6df60 100644 --- a/tests/tram/test_models.py +++ b/tests/tram/test_models.py @@ -5,7 +5,7 @@ class TestAttackTechnique: def test___str__renders_correctly(self, attack_object): # Arrange - expected = '(T1327) Use multiple DNS infrastructures' + expected = "(T1327) Use multiple DNS infrastructures" # Assert assert str(attack_object) == expected @@ -15,7 +15,7 @@ def test___str__renders_correctly(self, attack_object): class TestDocument: def test__str__renders_correctly(self, document): # Arrange - expected = 'tests/data/simple-test.docx' + expected = "tests/data/simple-test.docx" # Assert assert str(document) == expected @@ -25,7 +25,7 @@ def test__str__renders_correctly(self, document): class TestDocumentProcessingJob: def test__str__renders_correctly(self, document_processing_job): # Arrange - expected = 'Process tests/data/simple-test.docx' + expected = "Process tests/data/simple-test.docx" # Assert assert str(document_processing_job) == expected @@ -35,7 +35,7 @@ def test__str__renders_correctly(self, document_processing_job): class TestReport: def test__str__renders_correctly(self, report): # Arrange - expected = 'Test report name' + expected = "Test report name" # Assert assert str(report) == expected @@ -45,7 +45,7 @@ def test__str__renders_correctly(self, report): class TestIndicator: def test__str__renders_correctly(self, indicator): # Arrange - expected = 'MD5: 54b0c58c7ce9f2a8b551351102ee0938' + expected = "MD5: 54b0c58c7ce9f2a8b551351102ee0938" # Assert assert str(indicator) == expected @@ -55,14 +55,14 @@ def test__str__renders_correctly(self, indicator): class TestSentence: def test__str__renders_correctly(self, sentence): # Arrange - expected = 'test-text' + expected = "test-text" # Assert assert str(sentence) == expected def test__str__renders_long_sentence_correctly(self, long_sentence): # Arrange - expected = 'this sentence is long and should trigger...' + expected = "this sentence is long and should trigger..." # Assert assert str(long_sentence) == expected diff --git a/tests/tram/test_views.py b/tests/tram/test_views.py index f5d0a7c7ab..0bf762a567 100644 --- a/tests/tram/test_views.py +++ b/tests/tram/test_views.py @@ -1,17 +1,17 @@ import json +import pytest from django.contrib.auth.models import User from django.core.files.uploadedfile import SimpleUploadedFile from django.test import Client -import pytest from tram.models import Document, DocumentProcessingJob @pytest.fixture def user(): - user = User.objects.create_superuser(username='testuser') - user.set_password('12345') + user = User.objects.create_superuser(username="testuser") + user.set_password("12345") user.save() yield user user.delete() @@ -25,255 +25,264 @@ def client(user): @pytest.fixture def logged_in_client(client): - client.login(username='testuser', password='12345') + client.login(username="testuser", password="12345") return client @pytest.mark.django_db class TestLogin: - def test_get_login_loads_login_form(self, client): # Act - response = client.get('/login/') + response = client.get("/login/") # Assert - assert b'Login' in response.content + assert b"Login" in response.content def test_valid_login_redirects(self, client): # Arrange - data = {'username': 'testuser', - 'password': '12345'} + data = {"username": "testuser", "password": "12345"} # Act - response = client.post('/login/', data) + response = client.post("/login/", data) # Assert assert response.status_code == 302 - assert response.url == '/' + assert response.url == "/" def test_invalid_login_rerenders_login(self, client): # Arrange - data = {'username': 'not-a-real-user', - 'password': 'password'} + data = {"username": "not-a-real-user", "password": "password"} # Act - response = client.post('/login/', data) + response = client.post("/login/", data) # Assert assert response.status_code == 200 - assert b'Login' in response.content + assert b"Login" in response.content @pytest.mark.django_db class TestDocumentation: def test_documentation_loads(self, logged_in_client): # Act - response = logged_in_client.get('/docs/') + response = logged_in_client.get("/docs/") # Assert assert response.status_code == 200 - assert b'Documentation' in response.content + assert b"Documentation" in response.content @pytest.mark.django_db class TestIndex: def test_index_loads_with_no_stored_data(self, logged_in_client): # Act - response = logged_in_client.get('/') + response = logged_in_client.get("/") # Assert assert response.status_code == 200 - assert b'TRAM - Threat Report ATT&CK Mapper' in response.content + assert b"TRAM - Threat Report ATT&CK Mapper" in response.content def test_index_loads_with_one_stored_report(self, logged_in_client, report): # Act - response = logged_in_client.get('/') + response = logged_in_client.get("/") # Assert assert response.status_code == 200 - assert b'TRAM - Threat Report ATT&CK Mapper' in response.content + assert b"TRAM - Threat Report ATT&CK Mapper" in response.content - def test_index_loads_with_one_job_queued(self, logged_in_client, document_processing_job): + def test_index_loads_with_one_job_queued( + self, logged_in_client, document_processing_job + ): # Act - response = logged_in_client.get('/') + response = logged_in_client.get("/") # Assert assert response.status_code == 200 - assert b'TRAM - Threat Report ATT&CK Mapper' in response.content + assert b"TRAM - Threat Report ATT&CK Mapper" in response.content @pytest.mark.django_db class TestAnalyze: def test_analyze_loads(self, logged_in_client, report): # Act - response = logged_in_client.get('/analyze/1/') + response = logged_in_client.get("/analyze/1/") # Assert assert response.status_code == 200 - assert b'TRAM - Analyze Report' in response.content + assert b"TRAM - Analyze Report" in response.content @pytest.mark.django_db class TestUpload: def test_get_upload_returns_405(self, logged_in_client): # Act - response = logged_in_client.get('/upload/') + response = logged_in_client.get("/upload/") # Assert assert response.status_code == 405 def test_file_upload_succeeds_and_creates_job(self, logged_in_client): # Arrange - f = SimpleUploadedFile('test-report.pdf', - b'test file content', - content_type='application/pdf') - data = {'file': f} + f = SimpleUploadedFile( + "test-report.pdf", b"test file content", content_type="application/pdf" + ) + data = {"file": f} doc_count_pre = Document.objects.all().count() job_count_pre = DocumentProcessingJob.objects.all().count() # Act - response = logged_in_client.post('/upload/', data) + response = logged_in_client.post("/upload/", data) doc_count_post = Document.objects.all().count() job_count_post = DocumentProcessingJob.objects.all().count() - Document.objects.get(docfile='test-report.pdf').delete() + Document.objects.get(docfile="test-report.pdf").delete() # Assert assert response.status_code == 200 - assert b'File saved for processing' in response.content + assert b"File saved for processing" in response.content assert doc_count_pre + 1 == doc_count_post assert job_count_pre + 1 == job_count_post - def test_report_export_upload_creates_report(self, logged_in_client, load_attack_data): + def test_report_export_upload_creates_report( + self, logged_in_client, load_attack_data + ): # Act - with open('tests/data/report-for-simple-testdocx.json') as f: - response = logged_in_client.post('/upload/', {'file': f}) + with open("tests/data/report-for-simple-testdocx.json") as f: + response = logged_in_client.post("/upload/", {"file": f}) # Assert assert response.status_code == 200 def test_upload_unsupported_file_type_causes_bad_request(self, logged_in_client): # Arrange - f = SimpleUploadedFile('test-document.zip', - b'test file content', - content_type='application/zip') - data = {'file': f} + f = SimpleUploadedFile( + "test-document.zip", b"test file content", content_type="application/zip" + ) + data = {"file": f} # Act - response = logged_in_client.post('/upload/', data) + response = logged_in_client.post("/upload/", data) # Assert assert response.status_code == 400 - assert response.content == b'Unsupported file type' + assert response.content == b"Unsupported file type" @pytest.mark.django_db class TestMappingViewSet: def test_get_mappings(self, logged_in_client, mapping): # Act - response = logged_in_client.get('/api/mappings/') + response = logged_in_client.get("/api/mappings/") json_response = json.loads(response.content) # Assert assert len(json_response) == 1 - assert json_response[0]['attack_id'] == 'T1327' + assert json_response[0]["attack_id"] == "T1327" def test_get_mapping(self, logged_in_client, mapping): # Act - response = logged_in_client.get('/api/mappings/1/') + response = logged_in_client.get("/api/mappings/1/") json_response = json.loads(response.content) # Assert - assert json_response['attack_id'] == 'T1327' + assert json_response["attack_id"] == "T1327" def test_get_mappings_by_sentence(self, logged_in_client, mapping): # Act - response = logged_in_client.get('/api/mappings/?sentence-id=1') + response = logged_in_client.get("/api/mappings/?sentence-id=1") json_response = json.loads(response.content) # Assert assert len(json_response) == 1 - assert json_response[0]['attack_id'] == 'T1327' + assert json_response[0]["attack_id"] == "T1327" @pytest.mark.django_db class TestSentenceViewSet: def test_get_sentences(self, logged_in_client, sentence): # Act - response = logged_in_client.get('/api/sentences/') + response = logged_in_client.get("/api/sentences/") json_response = json.loads(response.content) # Assert assert len(json_response) == 1 - assert json_response[0]['order'] == 0 + assert json_response[0]["order"] == 0 def test_get_sentence(self, logged_in_client, sentence): # Act - response = logged_in_client.get('/api/sentences/1/') + response = logged_in_client.get("/api/sentences/1/") json_response = json.loads(response.content) # Assert - assert json_response['order'] == 0 + assert json_response["order"] == 0 def test_get_sentences_by_report(self, logged_in_client, sentence): # Act - response = logged_in_client.get('/api/sentences/?report-id=1') + response = logged_in_client.get("/api/sentences/?report-id=1") json_response = json.loads(response.content) # Assert assert len(json_response) == 1 - assert json_response[0]['order'] == 0 + assert json_response[0]["order"] == 0 @pytest.mark.django_db class TestReportExport: def test_get_report_export_succeeds(self, logged_in_client, mapping): # Act - response = logged_in_client.get('/api/report-export/1/') + response = logged_in_client.get("/api/report-export/1/") json_response = json.loads(response.content) # Assert - assert 'sentences' in json_response - assert len(json_response['sentences'][0]['mappings']) == 1 + assert "sentences" in json_response + assert len(json_response["sentences"][0]["mappings"]) == 1 - def test_bootstrap_training_data_can_be_posted_as_json_report(self, logged_in_client, load_attack_data): + def test_bootstrap_training_data_can_be_posted_as_json_report( + self, logged_in_client, load_attack_data + ): # Arrange - with open('data/training/bootstrap-training-data.json') as f: + with open("data/training/bootstrap-training-data.json") as f: json_string = f.read() # Act - response = logged_in_client.post('/api/report-export/', json_string, content_type='application/json') + response = logged_in_client.post( + "/api/report-export/", json_string, content_type="application/json" + ) # Assert assert response.status_code == 201 # HTTP 201 Created def test_report_export_update_not_implemented(self, logged_in_client): # Act - response = logged_in_client.post('/api/report-export/1/', '{}', content_type='application/json') + response = logged_in_client.post( + "/api/report-export/1/", "{}", content_type="application/json" + ) # Assert assert response.status_code == 405 # Method not allowed @pytest.mark.django_db -@pytest.mark.usefixtures('load_attack_data', ) +@pytest.mark.usefixtures( + "load_attack_data", +) class TestMl: def test_ml_home_returns_http_200_ok(self, logged_in_client): # Act - response = logged_in_client.get('/ml/') + response = logged_in_client.get("/ml/") # Assert assert response.status_code == 200 # HTTP 200 Ok def test_ml_model_detail_returns_http_200_ok(self, logged_in_client): # Act - response = logged_in_client.get('/ml/models/dummy') + response = logged_in_client.get("/ml/models/dummy") # Assert assert response.status_code == 200 # HTTP 200 Ok def test_ml_model_detail_returns_http_404_for_invalid_model(self, logged_in_client): # Act - response = logged_in_client.get('/ml/models/this-should-not-work') + response = logged_in_client.get("/ml/models/this-should-not-work") # Assert assert response.status_code == 404 # HTTP 200 Ok