From 12cd95ae9f65ab635c582f4eb6f0fd8547ab8c7b Mon Sep 17 00:00:00 2001 From: Michael Butt <869064+m3mike@users.noreply.github.com> Date: Mon, 7 Feb 2022 20:26:43 -0500 Subject: [PATCH] chore(*): reformat codebase using pre-commit (#140) * chore(*): reformat codebase using pre-commit * chore(ci): fix flake8 exclusion after reformat, add sonar exclusion for tests * fix(ci): fix nosec issue caused by reformat --- .pre-commit-config.yaml | 2 + .sonarcloud.properties | 2 + LICENSE.txt | 2 +- NOTICE.txt | 2 +- docker/README.md | 2 +- pyproject.toml | 2 + src/archive/pipeline/regex.yml | 2 +- src/scripts/reformat_training_data.py | 363 +++++++++--------- src/tram/manage.py | 4 +- src/tram/tram/admin.py | 19 +- src/tram/tram/asgi.py | 2 +- .../tram/management/commands/attackdata.py | 92 +++-- src/tram/tram/management/commands/pipeline.py | 80 ++-- src/tram/tram/ml/base.py | 193 ++++++---- src/tram/tram/models.py | 89 +++-- src/tram/tram/serializers.py | 149 ++++--- src/tram/tram/templates/analyze.html | 2 +- src/tram/tram/templates/index.html | 2 +- src/tram/tram/templates/ml_home.html | 2 +- src/tram/tram/templates/model_detail.html | 2 +- .../tram/templates/registration/login.html | 2 +- .../tram/templates/technique_sentences.html | 2 +- .../tram/templates/tram_documentation.html | 2 +- src/tram/tram/urls.py | 37 +- src/tram/tram/views.py | 86 +++-- src/tram/tram/wsgi.py | 2 +- tests/conftest.py | 46 +-- tests/tram/test_base.py | 116 +++--- tests/tram/test_commands.py | 41 +- tests/tram/test_models.py | 14 +- tests/tram/test_views.py | 141 +++---- 31 files changed, 836 insertions(+), 666 deletions(-) create mode 100644 .sonarcloud.properties diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2176289a9..a02a4961cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,5 @@ +exclude: '^.*\b(migrations)\b.*$|\.min\.js$|\.min\.css$|\.min\.*\b|\.svg|^tests/data|^data/|static/js' + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.0.1 diff --git a/.sonarcloud.properties b/.sonarcloud.properties new file mode 100644 index 0000000000..3ed4eb003d --- /dev/null +++ b/.sonarcloud.properties @@ -0,0 +1,2 @@ +# sonar settings +sonar.exclusions=tests/ diff --git a/LICENSE.txt b/LICENSE.txt index f49a4e16e6..261eeb9e9f 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/NOTICE.txt b/NOTICE.txt index b4a6e43e24..5379ff1a95 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -16,4 +16,4 @@ limitations under the License. This project makes use of ATT&CK® -ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/ \ No newline at end of file +ATT&CK® Terms of Use - https://attack.mitre.org/resources/terms-of-use/ diff --git a/docker/README.md b/docker/README.md index ec5b2503a7..64c56f63e1 100644 --- a/docker/README.md +++ b/docker/README.md @@ -24,7 +24,7 @@ services: - "8000:8000" environment: - DATA_DIRECTORY=/data - - ALLOWED_HOSTS=["example_host1", "localhost"] + - ALLOWED_HOSTS=["example_host1", "localhost"] - SECRET_KEY=Ij0WGee73k9OESwqddmSKCx6SY9aJ_7bDojs485Z6ec # your secret key here - DEBUG=True - DJANGO_SUPERUSER_USERNAME=djangoSuperuser diff --git a/pyproject.toml b/pyproject.toml index 57d0ea5efe..88c5c0ba79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ exclude = ''' | venv | migrations | node_modules + | .tox + | .git )/ ''' diff --git a/src/archive/pipeline/regex.yml b/src/archive/pipeline/regex.yml index 17cf147d06..81a500ce93 100644 --- a/src/archive/pipeline/regex.yml +++ b/src/archive/pipeline/regex.yml @@ -67,4 +67,4 @@ - name: sha512 code: sha512_hash regex: | - \b[a-fA-f0-9]{128}\b \ No newline at end of file + \b[a-fA-f0-9]{128}\b diff --git a/src/scripts/reformat_training_data.py b/src/scripts/reformat_training_data.py index 6bf74e2884..9a95972dce 100644 --- a/src/scripts/reformat_training_data.py +++ b/src/scripts/reformat_training_data.py @@ -28,173 +28,173 @@ The target format is defined by tram.serializers.ReportExportSerializer """ -from datetime import datetime -from functools import partial import json import os import sys +from datetime import datetime +from functools import partial import django -sys.path.append('src/tram/') -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') +sys.path.append("src/tram/") +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") django.setup() from tram.serializers import ReportExportSerializer # noqa: E402 -outfile = 'data/training/bootstrap-training-data.json' +outfile = "data/training/bootstrap-training-data.json" ATTACK_LOOKUP = { # A mapping of attack descriptions to technique IDs - 'drive-by compromise': 'T1189', - 'system information discovery': 'T1082', - 'new service': 'T1543', - 'service execution': 'T1569.002', - 'command-line interface': 'T1059', # Maps to: T1059 - Command and Scripting Interpreter - 'obfuscated files or information': 'T1027', - 'custom cryptographic protocol': 'T1573', # Maps to: T1573 - Encrypted Channel - 'system network configuration discovery': 'T1016', - 'web shell': 'T1505.003', - 'application window discovery': 'T1010', - 'file deletion': 'T1070.004', # Technique that became a subtechnique - 'standard application layer protocol': 'T1071', - 'web service': 'T1102', - 'exfiltration over command and control channel': 'T1041', - 'fallback channels': 'T1008', - 'bypass user account control': 'T1548.002', # Technique that became a subtechnique - 'system time discovery': 'T1124', - 'deobfuscate/decode files or information': 'T1140', - 'disabling security tools': 'T1562.001', # Maps to: T1562.001 - Impair Defenses: Disable or Modify Tools - 'registry run keys / startup folder': 'T1547.001', - 'remote file copy': 'T1105', # Maps to: T1105 - Ingress Tool Transfer - 'dll search order hijacking': 'T1574.001', - 'screen capture': 'T1113', - 'file and directory discovery': 'T1083', - 'tor': 'S0183', # Software?? - 'shortcut modification': 'T1547.009', - 'remote services': 'T1021', - 'connection proxy': 'T1090', - 'data encoding': 'T1132', - 'spearphishing link': 'T1566.002', - 'spearphishing attachment': 'T1566.001', - 'arp': 'S0099', - 'user execution': 'T1204', - 'process hollowing': 'T1055.012', - 'execution through api': 'T1106', # Maps to T1106 - Native API - 'masquerading': 'T1036', - 'code signing': 'T1553.002', - 'standard cryptographic protocol': 'T1521', - 'scripting': 'T1059', - 'remote system discovery': 'T1018', - 'credential dumping': 'T1003', - 'exploitation for client execution': 'T1203', - 'exploitation for privilege escalation': 'T1068', - 'security software discovery': 'T1518.001', - 'data from local system': 'T1533', - 'remote desktop protocol': 'T1021.001', - 'data compressed': 'T1560', # Maps to T1560 - Archive Collected Data - 'software packing': 'T1027.002', - 'ping': 'S0097', - 'brute force': 'T1110', - 'commonly used port': 'T1571', - 'modify registry': 'T1112', - 'uncommonly used port': 'T1571', - 'process injection': 'T1055', - 'timestomp': 'T1070.006', - 'windows management instrumentation': 'T1047', - 'data staged': 'T1074', - 'rundll32': 'T1218.011', - 'regsvr32': 'T1218.010', - 'account discovery': 'T1087', - 'process discovery': 'T1057', - 'clipboard data': 'T1115', - 'binary padding': 'T1027.001', - 'pass the hash': 'T1550.002', - 'network service scanning': 'T1046', - 'system service discovery': 'T1007', - 'data encrypted': 'T1486', - 'system network connections discovery': 'T1049', - 'windows admin shares': 'T1021.002', - 'system owner/user discovery': 'T1033', - 'launch agent': 'T1543.001', - 'permission groups discovery': 'T1069', - 'indicator removal on host': 'T1070', - 'input capture': 'T1056', - 'virtualization/sandbox evasion': 'T1497.001', - 'dll side-loading': 'T1574.002', - 'scheduled task': 'T1053', - 'access token manipulation': 'T1134', - 'powershell': 'T1546.013', - 'exfiltration over alternative protocol': 'T1048', - 'hidden files and directories': 'T1564.001', - 'network share discovery': 'T1135', - 'query registry': 'T1012', - 'credentials in files': 'T1552.001', - 'audio capture': 'T1123', - 'video capture': 'T1125', - 'peripheral device discovery': 'T1120', - 'spearphishing via service': 'T1566.003', - 'data encrypted for impact': 'T1486', - 'data destruction': 'T1485', - 'template injection': 'T1221', - 'inhibit system recovery': 'T1490', - 'create account': 'T1136', - 'exploitation of remote services': 'T1210', - 'valid accounts': 'T1078', - 'dynamic data exchange': 'T1559.002', - 'office application startup': 'T1137', - 'data obfuscation': 'T1001', - 'domain trust discovery': 'T1482', - 'email collection': 'T1114', - 'man in the browser': 'T1185', - 'data from removable media': 'T1025', - 'bootkit': 'T1542.003', - 'logon scripts': 'T1037', - 'execution through module load': 'T1129', - 'llmnr/nbt-ns poisoning and relay': 'T1557.001', - 'external remote services': 'T1133', - 'domain fronting': 'T1090.004', - 'sid-history injection': 'T1134.005', - 'service stop': 'T1489', - 'disk structure wipe': 'T1561.002', - 'credentials in registry': 'T1552.002', - 'appinit dlls': 'T1546.010', - 'exploit public-facing application': 'T1190', - 'remote access tools': 'T1219', - 'signed binary proxy execution': 'T1218', - 'appcert dlls': 'T1546.009', - 'winlogon helper dll': 'T1547.004', - 'file permissions modification': 'T1222', - 'hooking': 'T1056.004', - 'system firmware': 'T1542.001', - 'lsass driver': 'T1547.008', - 'distributed component object model': 'T1021.003', - 'cmstp': 'T1218.003', - 'execution guardrails': 'T1480', - 'component object model hijacking': 'T1546.015', - 'accessibility features': 'T1546.008', # TODO: Help wanted - 'keychain': 'T1555.001', - 'mshta': 'T1218.005', - 'pass the ticket': 'T1550.003', - 'kerberoasting': 'T1558.003', - 'password policy discovery': 'T1201', - 'local job scheduling': 'T1053.001', - 'windows remote management': 'T1021.006', - 'bits jobs': 'T1197', - 'data from information repositories': 'T1213', - 'lc_load_dylib addition': 'T1546.006', - 'histcontrol': 'T1562.003', - 'file system logical offsets': 'T1006', - 'regsvcs/regasm': 'T1218.009', - 'exploitation for credential access': 'T1212', - 'sudo': 'T1548.003', - 'installutil': 'T1218.004', - 'query registry ': 'T1012', - 'launchctl': 'T1569.001', - '.bash_profile and .bashrc': 'T1546.004', - 'applescript': 'T1059.002', - 'emond': 'T1546.014', - 'control panel items': 'T1218.002', - 'application shimming': 'T1546.011', + "drive-by compromise": "T1189", + "system information discovery": "T1082", + "new service": "T1543", + "service execution": "T1569.002", + "command-line interface": "T1059", # Maps to: T1059 - Command and Scripting Interpreter + "obfuscated files or information": "T1027", + "custom cryptographic protocol": "T1573", # Maps to: T1573 - Encrypted Channel + "system network configuration discovery": "T1016", + "web shell": "T1505.003", + "application window discovery": "T1010", + "file deletion": "T1070.004", # Technique that became a subtechnique + "standard application layer protocol": "T1071", + "web service": "T1102", + "exfiltration over command and control channel": "T1041", + "fallback channels": "T1008", + "bypass user account control": "T1548.002", # Technique that became a subtechnique + "system time discovery": "T1124", + "deobfuscate/decode files or information": "T1140", + "disabling security tools": "T1562.001", # Maps to: T1562.001 - Impair Defenses: Disable or Modify Tools + "registry run keys / startup folder": "T1547.001", + "remote file copy": "T1105", # Maps to: T1105 - Ingress Tool Transfer + "dll search order hijacking": "T1574.001", + "screen capture": "T1113", + "file and directory discovery": "T1083", + "tor": "S0183", # Software?? + "shortcut modification": "T1547.009", + "remote services": "T1021", + "connection proxy": "T1090", + "data encoding": "T1132", + "spearphishing link": "T1566.002", + "spearphishing attachment": "T1566.001", + "arp": "S0099", + "user execution": "T1204", + "process hollowing": "T1055.012", + "execution through api": "T1106", # Maps to T1106 - Native API + "masquerading": "T1036", + "code signing": "T1553.002", + "standard cryptographic protocol": "T1521", + "scripting": "T1059", + "remote system discovery": "T1018", + "credential dumping": "T1003", + "exploitation for client execution": "T1203", + "exploitation for privilege escalation": "T1068", + "security software discovery": "T1518.001", + "data from local system": "T1533", + "remote desktop protocol": "T1021.001", + "data compressed": "T1560", # Maps to T1560 - Archive Collected Data + "software packing": "T1027.002", + "ping": "S0097", + "brute force": "T1110", + "commonly used port": "T1571", + "modify registry": "T1112", + "uncommonly used port": "T1571", + "process injection": "T1055", + "timestomp": "T1070.006", + "windows management instrumentation": "T1047", + "data staged": "T1074", + "rundll32": "T1218.011", + "regsvr32": "T1218.010", + "account discovery": "T1087", + "process discovery": "T1057", + "clipboard data": "T1115", + "binary padding": "T1027.001", + "pass the hash": "T1550.002", + "network service scanning": "T1046", + "system service discovery": "T1007", + "data encrypted": "T1486", + "system network connections discovery": "T1049", + "windows admin shares": "T1021.002", + "system owner/user discovery": "T1033", + "launch agent": "T1543.001", + "permission groups discovery": "T1069", + "indicator removal on host": "T1070", + "input capture": "T1056", + "virtualization/sandbox evasion": "T1497.001", + "dll side-loading": "T1574.002", + "scheduled task": "T1053", + "access token manipulation": "T1134", + "powershell": "T1546.013", + "exfiltration over alternative protocol": "T1048", + "hidden files and directories": "T1564.001", + "network share discovery": "T1135", + "query registry": "T1012", + "credentials in files": "T1552.001", + "audio capture": "T1123", + "video capture": "T1125", + "peripheral device discovery": "T1120", + "spearphishing via service": "T1566.003", + "data encrypted for impact": "T1486", + "data destruction": "T1485", + "template injection": "T1221", + "inhibit system recovery": "T1490", + "create account": "T1136", + "exploitation of remote services": "T1210", + "valid accounts": "T1078", + "dynamic data exchange": "T1559.002", + "office application startup": "T1137", + "data obfuscation": "T1001", + "domain trust discovery": "T1482", + "email collection": "T1114", + "man in the browser": "T1185", + "data from removable media": "T1025", + "bootkit": "T1542.003", + "logon scripts": "T1037", + "execution through module load": "T1129", + "llmnr/nbt-ns poisoning and relay": "T1557.001", + "external remote services": "T1133", + "domain fronting": "T1090.004", + "sid-history injection": "T1134.005", + "service stop": "T1489", + "disk structure wipe": "T1561.002", + "credentials in registry": "T1552.002", + "appinit dlls": "T1546.010", + "exploit public-facing application": "T1190", + "remote access tools": "T1219", + "signed binary proxy execution": "T1218", + "appcert dlls": "T1546.009", + "winlogon helper dll": "T1547.004", + "file permissions modification": "T1222", + "hooking": "T1056.004", + "system firmware": "T1542.001", + "lsass driver": "T1547.008", + "distributed component object model": "T1021.003", + "cmstp": "T1218.003", + "execution guardrails": "T1480", + "component object model hijacking": "T1546.015", + "accessibility features": "T1546.008", # TODO: Help wanted + "keychain": "T1555.001", + "mshta": "T1218.005", + "pass the ticket": "T1550.003", + "kerberoasting": "T1558.003", + "password policy discovery": "T1201", + "local job scheduling": "T1053.001", + "windows remote management": "T1021.006", + "bits jobs": "T1197", + "data from information repositories": "T1213", + "lc_load_dylib addition": "T1546.006", + "histcontrol": "T1562.003", + "file system logical offsets": "T1006", + "regsvcs/regasm": "T1218.009", + "exploitation for credential access": "T1212", + "sudo": "T1548.003", + "installutil": "T1218.004", + "query registry ": "T1012", + "launchctl": "T1569.001", + ".bash_profile and .bashrc": "T1546.004", + "applescript": "T1059.002", + "emond": "T1546.014", + "control panel items": "T1218.002", + "application shimming": "T1546.011", } @@ -215,14 +215,14 @@ def to_report_export_serializer_json(self): """Creates a dict that can be used to create a serializers.ReportExportSerializer instance """ - utc_now = datetime.utcnow().isoformat() + 'Z' + utc_now = datetime.utcnow().isoformat() + "Z" res_json = { - 'name': 'Bootstrap Training Data', - 'text': 'There is no text for this report. These sentences were mapped by human analysts.', - 'ml_model': 'humans', - 'created_on': utc_now, - 'updated_on': utc_now, - 'sentences': [] + "name": "Bootstrap Training Data", + "text": "There is no text for this report. These sentences were mapped by human analysts.", + "ml_model": "humans", + "created_on": utc_now, + "updated_on": utc_now, + "sentences": [], } order = 0 @@ -231,21 +231,18 @@ def to_report_export_serializer_json(self): continue sentence = { - 'text': sentence_text, - 'order': order, - 'disposition': 'accept', - 'mappings': [], + "text": sentence_text, + "order": order, + "disposition": "accept", + "mappings": [], } order += 1 for mapping in mappings: - mapping = { - 'attack_id': mapping, - 'confidence': '100.0' - } - sentence['mappings'].append(mapping) - res_json['sentences'].append(sentence) + mapping = {"attack_id": mapping, "confidence": "100.0"} + sentence["mappings"].append(mapping) + res_json["sentences"].append(sentence) return res_json @@ -258,20 +255,24 @@ def get_attack_id(description): def main(): - with open('data/training/archive/all_analyzed_reports.json') as f: + with open("data/training/archive/all_analyzed_reports.json") as f: all_analyzed_reports = json.load(f) - with open('data/training/archive/negative_data.json') as f: + with open("data/training/archive/negative_data.json") as f: negative_data = json.load(f) training_data = TrainingData() # Add the positives for key, value in all_analyzed_reports.items(): - if key.endswith('-multi'): # It's a multi-mapping, value is a dictionary - for sentence in value['sentances']: # Sentences is misspelled in the source data - map(partial(training_data.add_mapping, sentence), - [ATTACK_LOOKUP[name.lower()] for name in value['technique_names']]) + if key.endswith("-multi"): # It's a multi-mapping, value is a dictionary + for sentence in value[ + "sentances" + ]: # Sentences is misspelled in the source data + map( + partial(training_data.add_mapping, sentence), + [ATTACK_LOOKUP[name.lower()] for name in value["technique_names"]], + ) else: # It's a single-mapping, value is a list of sentences technique_id = get_attack_id(key) for sentence in value: @@ -285,11 +286,11 @@ def main(): res = ReportExportSerializer(data=res_json) res.is_valid(raise_exception=True) - with open(outfile, 'w') as f: + with open(outfile, "w") as f: json.dump(res.initial_data, f, indent=4) - print('Wrote data to %s' % outfile) + print("Wrote data to %s" % outfile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/tram/manage.py b/src/tram/manage.py index ee76876f16..a83fedd617 100755 --- a/src/tram/manage.py +++ b/src/tram/manage.py @@ -6,7 +6,7 @@ def main(): """Run administrative tasks.""" - os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') + os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") try: from django.core.management import execute_from_command_line except ImportError as exc: @@ -18,5 +18,5 @@ def main(): execute_from_command_line(sys.argv) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/tram/tram/admin.py b/src/tram/tram/admin.py index b39459fcdc..359142be98 100644 --- a/src/tram/tram/admin.py +++ b/src/tram/tram/admin.py @@ -1,7 +1,14 @@ from django.contrib import admin -from tram.models import AttackObject, Document, DocumentProcessingJob, \ - Indicator, Mapping, Report, Sentence +from tram.models import ( + AttackObject, + Document, + DocumentProcessingJob, + Indicator, + Mapping, + Report, + Sentence, +) class IndicatorInline(admin.TabularInline): @@ -17,11 +24,11 @@ class MappingInline(admin.TabularInline): class SentenceInline(admin.TabularInline): extra = 0 model = Sentence - readonly_fields = ('text', 'document', 'order') + readonly_fields = ("text", "document", "order") class AttackObjectAdmin(admin.ModelAdmin): - readonly_fields = ('name', 'stix_id', 'attack_id', 'attack_url', 'matrix') + readonly_fields = ("name", "stix_id", "attack_id", "attack_url", "matrix") class DocumentAdmin(admin.ModelAdmin): @@ -30,11 +37,11 @@ class DocumentAdmin(admin.ModelAdmin): class ReportAdmin(admin.ModelAdmin): inlines = [IndicatorInline, MappingInline] - readonly_fields = ('document', 'text', 'ml_model') + readonly_fields = ("document", "text", "ml_model") class SentenceAdmin(admin.ModelAdmin): - readonly_fields = ('text', 'document', 'order') + readonly_fields = ("text", "document", "order") admin.site.register(AttackObject, AttackObjectAdmin) diff --git a/src/tram/tram/asgi.py b/src/tram/tram/asgi.py index f0128d4eab..a559a941e9 100644 --- a/src/tram/tram/asgi.py +++ b/src/tram/tram/asgi.py @@ -11,6 +11,6 @@ from django.core.asgi import get_asgi_application -os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tram.settings') +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "tram.settings") application = get_asgi_application() diff --git a/src/tram/tram/management/commands/attackdata.py b/src/tram/tram/management/commands/attackdata.py index b793128c98..934b0d21d0 100644 --- a/src/tram/tram/management/commands/attackdata.py +++ b/src/tram/tram/management/commands/attackdata.py @@ -5,58 +5,66 @@ from tram.models import AttackObject -LOAD = 'load' -CLEAR = 'clear' +LOAD = "load" +CLEAR = "clear" STIX_TYPE_TO_ATTACK_TYPE = { - 'attack-pattern': 'technique', - 'course-of-action': 'mitigation', - 'intrusion-set': 'group', - 'malware': 'software', - 'tool': 'software', - 'x-mitre-tactic': 'tactic', + "attack-pattern": "technique", + "course-of-action": "mitigation", + "intrusion-set": "group", + "malware": "software", + "tool": "software", + "x-mitre-tactic": "tactic", } class Command(BaseCommand): - help = 'Machine learning pipeline commands' + help = "Machine learning pipeline commands" def add_arguments(self, parser): - sp = parser.add_subparsers(title='subcommands', - dest='subcommand', - required=True) - sp_load = sp.add_parser(LOAD, help='Load ATT&CK Data into the Database') # noqa: F841 - sp_clear = sp.add_parser(CLEAR, help='Clear ATT&CK Data from the Database') # noqa: F841 + sp = parser.add_subparsers( + title="subcommands", dest="subcommand", required=True + ) + sp_load = sp.add_parser( # noqa: F841 + LOAD, help="Load ATT&CK Data into the Database" + ) + sp_clear = sp.add_parser( # noqa: F841 + CLEAR, help="Clear ATT&CK Data from the Database" + ) def clear_attack_data(self): deleted = AttackObject.objects.all().delete() - print(f'Deleted {deleted[0]} Attack objects') + print(f"Deleted {deleted[0]} Attack objects") def create_attack_object(self, obj): - for external_reference in obj['external_references']: - if external_reference['source_name'] not in ('mitre-attack', 'mitre-pre-attack', 'mitre-mobile-attack'): + for external_reference in obj["external_references"]: + if external_reference["source_name"] not in ( + "mitre-attack", + "mitre-pre-attack", + "mitre-mobile-attack", + ): continue - attack_id = external_reference['external_id'] - attack_url = external_reference['url'] - matrix = external_reference['source_name'] + attack_id = external_reference["external_id"] + attack_url = external_reference["url"] + matrix = external_reference["source_name"] assert attack_id is not None assert attack_url is not None assert matrix is not None - stix_type = obj['type'] + stix_type = obj["type"] attack_type = STIX_TYPE_TO_ATTACK_TYPE[stix_type] obj, created = AttackObject.objects.get_or_create( - name=obj['name'], - stix_id=obj['id'], + name=obj["name"], + stix_id=obj["id"], stix_type=stix_type, attack_id=attack_id, attack_type=attack_type, attack_url=attack_url, - matrix=matrix + matrix=matrix, ) return obj, created @@ -65,21 +73,27 @@ def load_attack_data(self, filepath): created_stats = {} skipped_stats = {} - with open(filepath, 'r') as f: + with open(filepath, "r") as f: attack_json = json.load(f) - assert attack_json['spec_version'] == '2.0' - assert attack_json['type'] == 'bundle' + assert attack_json["spec_version"] == "2.0" + assert attack_json["type"] == "bundle" - for obj in attack_json['objects']: - obj_type = obj['type'] + for obj in attack_json["objects"]: + obj_type = obj["type"] # TODO: Skip deprecated objects - if obj.get('revoked', False): # Skip revoked objects + if obj.get("revoked", False): # Skip revoked objects skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 continue - if obj_type in ('relationship', 'course-of-action', 'identity', 'x-mitre-matrix', 'marking-definition'): + if obj_type in ( + "relationship", + "course-of-action", + "identity", + "x-mitre-matrix", + "marking-definition", + ): skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 continue @@ -92,21 +106,23 @@ def load_attack_data(self, filepath): except ValueError: # Value error means unsupported object type skipped_stats[obj_type] = skipped_stats.get(obj_type, 0) + 1 - print(f'Load stats for {filepath}:') + print(f"Load stats for {filepath}:") for k, v in created_stats.items(): - print(f'\tCreated {v} {k} objects') + print(f"\tCreated {v} {k} objects") for k, v in skipped_stats.items(): - print(f'\tSkipped {v} {k} objects') + print(f"\tSkipped {v} {k} objects") def handle(self, *args, **options): - subcommand = options['subcommand'] + subcommand = options["subcommand"] if subcommand == LOAD: # Note - as of ATT&CK v8.2 # Techniques are unique among files, but # Groups are not unique among files - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/enterprise-attack.json') - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/mobile-attack.json') - self.load_attack_data(settings.DATA_DIRECTORY / 'attack/pre-attack.json') + self.load_attack_data( + settings.DATA_DIRECTORY / "attack/enterprise-attack.json" + ) + self.load_attack_data(settings.DATA_DIRECTORY / "attack/mobile-attack.json") + self.load_attack_data(settings.DATA_DIRECTORY / "attack/pre-attack.json") elif subcommand == CLEAR: self.clear_attack_data() diff --git a/src/tram/tram/management/commands/pipeline.py b/src/tram/tram/management/commands/pipeline.py index 9bebb61da8..2cffc3b019 100644 --- a/src/tram/tram/management/commands/pipeline.py +++ b/src/tram/tram/management/commands/pipeline.py @@ -4,68 +4,80 @@ from django.core.files import File from django.core.management.base import BaseCommand - -from tram.ml import base import tram.models as db_models from tram import serializers +from tram.ml import base - -ADD = 'add' -RUN = 'run' -TRAIN = 'train' -LOAD_TRAINING_DATA = 'load-training-data' +ADD = "add" +RUN = "run" +TRAIN = "train" +LOAD_TRAINING_DATA = "load-training-data" class Command(BaseCommand): - help = 'Machine learning pipeline commands' + help = "Machine learning pipeline commands" def add_arguments(self, parser): - sp = parser.add_subparsers(title='subcommands', - dest='subcommand', - required=True) - sp_run = sp.add_parser(RUN, help='Run the ML Pipeline') - sp_run.add_argument('--model', default='logreg', help='Select the ML model.') - sp_run.add_argument('--run-forever', default=False, action='store_true', - help='Specify whether to run forever, or quit when there are no more jobs to process') - sp_train = sp.add_parser(TRAIN, help='Train the ML Pipeline') # noqa: F841 - sp_train.add_argument('--model', default='logreg', help='Select the ML model.') - sp_add = sp.add_parser(ADD, help='Add a document for processing by the ML pipeline') - sp_add.add_argument('--file', required=True, help='Specify the file to be added') - sp_load = sp.add_parser(LOAD_TRAINING_DATA, help='Load training data. Must be formatted as a Report Export.') - sp_load.add_argument('--file', default='data/training/bootstrap-training-data.json', - help='Training data file to load. Defaults: data/training/bootstrap-training-data.json') + sp = parser.add_subparsers( + title="subcommands", dest="subcommand", required=True + ) + sp_run = sp.add_parser(RUN, help="Run the ML Pipeline") + sp_run.add_argument("--model", default="logreg", help="Select the ML model.") + sp_run.add_argument( + "--run-forever", + default=False, + action="store_true", + help="Specify whether to run forever, or quit when there are no more jobs to process", + ) + sp_train = sp.add_parser(TRAIN, help="Train the ML Pipeline") # noqa: F841 + sp_train.add_argument("--model", default="logreg", help="Select the ML model.") + sp_add = sp.add_parser( + ADD, help="Add a document for processing by the ML pipeline" + ) + sp_add.add_argument( + "--file", required=True, help="Specify the file to be added" + ) + sp_load = sp.add_parser( + LOAD_TRAINING_DATA, + help="Load training data. Must be formatted as a Report Export.", + ) + sp_load.add_argument( + "--file", + default="data/training/bootstrap-training-data.json", + help="Training data file to load. Defaults: data/training/bootstrap-training-data.json", + ) def handle(self, *args, **options): - subcommand = options['subcommand'] + subcommand = options["subcommand"] if subcommand == ADD: - filepath = options['file'] - with open(filepath, 'rb') as f: + filepath = options["file"] + with open(filepath, "rb") as f: django_file = File(f) db_models.DocumentProcessingJob.create_from_file(django_file) - self.stdout.write(f'Added file to ML Pipeline: {filepath}') + self.stdout.write(f"Added file to ML Pipeline: {filepath}") return if subcommand == LOAD_TRAINING_DATA: - filepath = options['file'] - self.stdout.write(f'Loading training data from {filepath}') - with open(filepath, 'r') as f: + filepath = options["file"] + self.stdout.write(f"Loading training data from {filepath}") + with open(filepath, "r") as f: res = serializers.ReportExportSerializer(data=json.load(f)) res.is_valid(raise_exception=True) res.save() return - model = options['model'] + model = options["model"] model_manager = base.ModelManager(model) if subcommand == RUN: - self.stdout.write(f'Running ML Pipeline with Model: {model}') - return model_manager.run_model(options['run_forever']) + self.stdout.write(f"Running ML Pipeline with Model: {model}") + return model_manager.run_model(options["run_forever"]) elif subcommand == TRAIN: - self.stdout.write(f'Training ML Model: {model}') + self.stdout.write(f"Training ML Model: {model}") start = time.time() return_value = model_manager.train_model() end = time.time() elapsed = end - start - self.stdout.write(f'Trained ML model in {elapsed} seconds') + self.stdout.write(f"Trained ML model in {elapsed} seconds") return return_value diff --git a/src/tram/tram/ml/base.py b/src/tram/tram/ml/base.py index 9fa9ccc02f..6a7b9acbe0 100644 --- a/src/tram/tram/ml/base.py +++ b/src/tram/tram/ml/base.py @@ -1,20 +1,20 @@ -from abc import ABC, abstractmethod -from datetime import datetime, timezone -from io import BytesIO -from os import path import pathlib import pickle +import re import time import traceback +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from io import BytesIO +from os import path -from bs4 import BeautifulSoup -from constance import config -from django.db import transaction -from django.conf import settings import docx import nltk import pdfplumber -import re +from bs4 import BeautifulSoup +from constance import config +from django.conf import settings +from django.db import transaction from sklearn.dummy import DummyClassifier from sklearn.feature_extraction.text import CountVectorizer from sklearn.linear_model import LogisticRegression @@ -40,7 +40,7 @@ def __init__(self, confidence=0.0, attack_id=None): self.attack_id = attack_id def __repr__(self): - return 'Confidence=%f; Attack ID=%s' % (self.confidence, self.attack_id) + return "Confidence=%f; Attack ID=%s" % (self.confidence, self.attack_id) class Report(object): @@ -55,6 +55,7 @@ class SKLearnModel(ABC): TODO: 1. Move text extraction and tokenization out of the SKLearnModel """ + def __init__(self): self.techniques_model = self.get_model() self.last_trained = None @@ -62,12 +63,13 @@ def __init__(self): self.detailed_f1_score = None if not isinstance(self.techniques_model, Pipeline): - raise TypeError('get_model() must return an sklearn.pipeline.Pipeline instance') + raise TypeError( + "get_model() must return an sklearn.pipeline.Pipeline instance" + ) @abstractmethod def get_model(self): - """Returns an sklearn.Pipeline that has fit() and predict() methods - """ + """Returns an sklearn.Pipeline that has fit() and predict() methods""" def train(self): """ @@ -86,8 +88,9 @@ def test(self): X, y = self.get_training_data() # Create training set and test set - X_train, X_test, y_train, y_test = \ - train_test_split(X, y, test_size=0.2, shuffle=True, random_state=0, stratify=y) + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.2, shuffle=True, random_state=0, stratify=y + ) # Train model test_model = self.get_model() @@ -100,26 +103,28 @@ def test(self): # Calculate an f1 score for each technique labels = sorted(list(set(y))) scores = f1_score(y_test, y_predicted, labels=list(set(y)), average=None) - self.detailed_f1_score = sorted(zip(labels, scores), key=lambda t: t[1], reverse=True) + self.detailed_f1_score = sorted( + zip(labels, scores), key=lambda t: t[1], reverse=True + ) # Average F1 score across techniques, weighted by the # of training examples per technique - weighted_f1 = f1_score(y_test, y_predicted, average='weighted') + weighted_f1 = f1_score(y_test, y_predicted, average="weighted") self.average_f1_score = weighted_f1 def _get_report_name(self, job): name = pathlib.Path(job.document.docfile.path).name - return 'Report for %s' % name + return "Report for %s" % name def _extract_text(self, document): suffix = pathlib.Path(document.docfile.path).suffix - if suffix == '.pdf': + if suffix == ".pdf": text = self._extract_pdf_text(document) - elif suffix == '.docx': + elif suffix == ".docx": text = self._extract_docx_text(document) - elif suffix == '.html': + elif suffix == ".html": text = self._extract_html_text(document) else: - raise ValueError('Unknown file suffix: %s' % suffix) + raise ValueError("Unknown file suffix: %s" % suffix) return text @@ -132,8 +137,12 @@ def lemmatize(self, sentence): lemma = nltk.stem.WordNetLemmatizer() # Lemmatize each word in sentence - lemmatized_sentence = ' '.join([lemma.lemmatize(w) for w in sentence.rstrip().split()]) - lemmatized_sentence = re.sub(r'\d+', '', lemmatized_sentence) # Remove digits with regex + lemmatized_sentence = " ".join( + [lemma.lemmatize(w) for w in sentence.rstrip().split()] + ) + lemmatized_sentence = re.sub( + r"\d+", "", lemmatized_sentence + ) # Remove digits with regex return lemmatized_sentence @@ -153,9 +162,14 @@ def get_training_data(self): return X, y def get_attack_object_ids(self): - objects = [obj.attack_id for obj in db_models.AttackObject.objects.all().order_by('attack_id')] + objects = [ + obj.attack_id + for obj in db_models.AttackObject.objects.all().order_by("attack_id") + ] if len(objects) == 0: - raise ValueError('Zero techniques found. Maybe run `python manage.py attackdata load` ?') + raise ValueError( + "Zero techniques found. Maybe run `python manage.py attackdata load` ?" + ) return objects def get_mappings(self, sentence): @@ -165,7 +179,9 @@ def get_mappings(self, sentence): mappings = [] techniques = self.techniques_model.classes_ - probs = self.techniques_model.predict_proba([sentence])[0] # Probability is a range between 0-1 + probs = self.techniques_model.predict_proba([sentence])[ + 0 + ] # Probability is a range between 0-1 # Create a list of tuples of (confidence, technique) confidences_and_techniques = zip(probs, techniques) @@ -185,7 +201,7 @@ def _sentence_tokenize(self, text): def _extract_pdf_text(self, document): with pdfplumber.open(BytesIO(document.docfile.read())) as pdf: - text = ''.join(page.extract_text() for page in pdf.pages) + text = "".join(page.extract_text() for page in pdf.pages) return text @@ -197,7 +213,7 @@ def _extract_html_text(self, document): def _extract_docx_text(self, document): parsed_docx = docx.Document(BytesIO(document.docfile.read())) - text = ' '.join([paragraph.text for paragraph in parsed_docx.paragraphs]) + text = " ".join([paragraph.text for paragraph in parsed_docx.paragraphs]) return text def process_job(self, job): @@ -217,24 +233,29 @@ def process_job(self, job): return report def save_to_file(self, filepath): - with open(filepath, 'wb') as f: + with open(filepath, "wb") as f: pickle.dump(self, f) @classmethod def load_from_file(cls, filepath): - with open(filepath, 'rb') as f: - model = pickle.load(f) # nosec - Accept the risk until a better design is implemented - + with open(filepath, "rb") as f: + model = pickle.load(f) # nosec + # accept risk until better design implemented assert cls == model.__class__ return model class DummyModel(SKLearnModel): def get_model(self): - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", DummyClassifier(strategy='uniform')) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", DummyClassifier(strategy="uniform")), + ] + ) class NaiveBayesModel(SKLearnModel): @@ -244,10 +265,15 @@ def get_model(self): 1) Features = document-term matrix, with stop words removed from the term vocabulary. 2) Classifier (clf) = multinomial Naive Bayes """ - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", MultinomialNB()) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", MultinomialNB()), + ] + ) class LogisticRegressionModel(SKLearnModel): @@ -257,38 +283,43 @@ def get_model(self): 1) Features = document-term matrix, with stop words removed from the term vocabulary. 2) Classifier (clf) = multinomial logistic regression """ - return Pipeline([ - ("features", CountVectorizer(lowercase=True, stop_words='english', min_df=3)), - ("clf", LogisticRegression()) - ]) + return Pipeline( + [ + ( + "features", + CountVectorizer(lowercase=True, stop_words="english", min_df=3), + ), + ("clf", LogisticRegression()), + ] + ) class ModelManager(object): model_registry = { # TODO: Add a hook to register user-created models - 'dummy': DummyModel, - 'nb': NaiveBayesModel, - 'logreg': LogisticRegressionModel, + "dummy": DummyModel, + "nb": NaiveBayesModel, + "logreg": LogisticRegressionModel, } def __init__(self, model): model_class = self.model_registry.get(model) if not model_class: - raise ValueError('Unrecognized model: %s' % model) + raise ValueError("Unrecognized model: %s" % model) model_filepath = self.get_model_filepath(model_class) if path.exists(model_filepath): self.model = model_class.load_from_file(model_filepath) - print('%s loaded from %s' % (model_class.__name__, model_filepath)) + print("%s loaded from %s" % (model_class.__name__, model_filepath)) else: self.model = model_class() - print('%s loaded from __init__' % model_class.__name__) + print("%s loaded from __init__" % model_class.__name__) def _save_report(self, report, document): rpt = db_models.Report( name=report.name, document=document, text=report.text, - ml_model=self.model.__class__.__name__ + ml_model=self.model.__class__.__name__, ) rpt.save() @@ -304,7 +335,9 @@ def _save_report(self, report, document): for mapping in sentence.mappings: if mapping.attack_id: - obj = db_models.AttackObject.objects.get(attack_id=mapping.attack_id) + obj = db_models.AttackObject.objects.get( + attack_id=mapping.attack_id + ) else: obj = None @@ -318,21 +351,23 @@ def _save_report(self, report, document): def run_model(self, run_forever=False): while True: - jobs = db_models.DocumentProcessingJob.objects.filter(status='queued').order_by('created_on') + jobs = db_models.DocumentProcessingJob.objects.filter( + status="queued" + ).order_by("created_on") for job in jobs: filename = job.document.docfile.name - print('Processing Job #%d: %s' % (job.id, filename)) + print("Processing Job #%d: %s" % (job.id, filename)) try: report = self.model.process_job(job) with transaction.atomic(): self._save_report(report, job.document) job.delete() - print('Created report %s' % report.name) + print("Created report %s" % report.name) except Exception as ex: - job.status = 'error' + job.status = "error" job.message = str(ex) job.save() - print(f'Failed to create report for {filename}.') + print(f"Failed to create report for {filename}.") print(traceback.format_exc()) if not run_forever: @@ -340,7 +375,7 @@ def run_model(self, run_forever=False): time.sleep(1) def get_model_filepath(self, model_class): - filepath = settings.ML_MODEL_DIR + '/' + model_class.__name__ + '.pkl' + filepath = settings.ML_MODEL_DIR + "/" + model_class.__name__ + ".pkl" return filepath def train_model(self): @@ -348,7 +383,7 @@ def train_model(self): self.model.test() filepath = self.get_model_filepath(self.model.__class__) self.model.save_to_file(filepath) - print('Trained model saved to %s' % filepath) + print("Trained model saved to %s" % filepath) return @staticmethod @@ -361,7 +396,9 @@ def get_all_model_metadata(): model_metadata = ModelManager.get_model_metadata(model_key) all_model_metadata.append(model_metadata) - all_model_metadata = sorted(all_model_metadata, key=lambda i: i['average_f1_score'], reverse=True) + all_model_metadata = sorted( + all_model_metadata, key=lambda i: i["average_f1_score"], reverse=True + ) return all_model_metadata @@ -373,34 +410,38 @@ def get_model_metadata(model_key): mm = ModelManager(model_key) model_name = mm.model.__class__.__name__ if mm.model.last_trained is None: - last_trained = 'Never trained' + last_trained = "Never trained" trained_techniques_count = 0 else: - last_trained = mm.model.last_trained.strftime('%m/%d/%Y %H:%M:%S UTC') + last_trained = mm.model.last_trained.strftime("%m/%d/%Y %H:%M:%S UTC") trained_techniques_count = len(mm.model.detailed_f1_score) average_f1_score = round((mm.model.average_f1_score or 0.0) * 100, 2) stored_scores = mm.model.detailed_f1_score or [] attack_ids = set([score[0] for score in stored_scores]) - attack_techniques = db_models.AttackObject.objects.filter(attack_id__in=attack_ids) + attack_techniques = db_models.AttackObject.objects.filter( + attack_id__in=attack_ids + ) detailed_f1_score = [] for score in stored_scores: score_id = score[0] score_value = round(score[1] * 100, 2) attack_technique = attack_techniques.get(attack_id=score_id) - detailed_f1_score.append({ - 'technique': score_id, - 'technique_name': attack_technique.name, - 'attack_url': attack_technique.attack_url, - 'score': score_value - }) + detailed_f1_score.append( + { + "technique": score_id, + "technique_name": attack_technique.name, + "attack_url": attack_technique.attack_url, + "score": score_value, + } + ) model_metadata = { - 'model_key': model_key, - 'name': model_name, - 'last_trained': last_trained, - 'trained_techniques_count': trained_techniques_count, - 'average_f1_score': average_f1_score, - 'detailed_f1_score': detailed_f1_score, + "model_key": model_key, + "name": model_name, + "last_trained": last_trained, + "trained_techniques_count": trained_techniques_count, + "average_f1_score": average_f1_score, + "detailed_f1_score": detailed_f1_score, } return model_metadata diff --git a/src/tram/tram/models.py b/src/tram/tram/models.py index f1b2a7fd4c..60367ad9a7 100644 --- a/src/tram/tram/models.py +++ b/src/tram/tram/models.py @@ -9,22 +9,22 @@ from django.dispatch.dispatcher import receiver DISPOSITION_CHOICES = ( - ('accept', 'Accepted'), - ('reject', 'Rejected'), - (None, 'No Disposition'), + ("accept", "Accepted"), + ("reject", "Rejected"), + (None, "No Disposition"), ) JOB_STATUS_CHOICES = ( - ('queued', 'Queued'), - ('error', 'Error'), + ("queued", "Queued"), + ("error", "Error"), ) SENTENCE_PREVIEW_CHARS = 40 class AttackObject(models.Model): - """Attack Techniques - """ + """Attack Techniques""" + name = models.CharField(max_length=200) stix_id = models.CharField(max_length=128, unique=True) stix_type = models.CharField(max_length=128) @@ -35,7 +35,7 @@ class AttackObject(models.Model): created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) - sentences = models.ManyToManyField('Sentence', through='Mapping') + sentences = models.ManyToManyField("Sentence", through="Mapping") @classmethod def get_sentence_counts(cls, accept_threshold=0): @@ -44,20 +44,28 @@ def get_sentence_counts(cls, accept_threshold=0): return: The list of AttackTechnique objects, annotated with how many training sentences have been accepted, pending, and there are in total. """ - sentence_counts = cls.objects.annotate( - accepted_sentences=Count('sentences', filter=Q(sentences__disposition='accept')), - pending_sentences=Count('sentences', filter=Q(sentences__disposition=None)), - total_sentences=Count('sentences') - ).order_by('-accepted_sentences', 'attack_id').filter(accepted_sentences__gte=accept_threshold) + sentence_counts = ( + cls.objects.annotate( + accepted_sentences=Count( + "sentences", filter=Q(sentences__disposition="accept") + ), + pending_sentences=Count( + "sentences", filter=Q(sentences__disposition=None) + ), + total_sentences=Count("sentences"), + ) + .order_by("-accepted_sentences", "attack_id") + .filter(accepted_sentences__gte=accept_threshold) + ) return sentence_counts def __str__(self): - return '(%s) %s' % (self.attack_id, self.name) + return "(%s) %s" % (self.attack_id, self.name) class Document(models.Model): - """Store all documents that can be analyzed to create reports - """ + """Store all documents that can be analyzed to create reports""" + docfile = models.FileField() created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -68,11 +76,13 @@ def __str__(self): class DocumentProcessingJob(models.Model): - """Queue of document processing jobs - """ + """Queue of document processing jobs""" + document = models.ForeignKey(Document, on_delete=models.CASCADE) - status = models.CharField(max_length=255, default='queued', choices=JOB_STATUS_CHOICES) - message = models.CharField(max_length=16384, default='') + status = models.CharField( + max_length=255, default="queued", choices=JOB_STATUS_CHOICES + ) + message = models.CharField(max_length=16384, default="") created_by = models.ForeignKey(User, null=True, on_delete=models.SET_NULL) created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -87,12 +97,12 @@ def create_from_file(cls, f): return dpj def __str__(self): - return 'Process %s' % self.document.docfile.name + return "Process %s" % self.document.docfile.name class Report(models.Model): - """Store reports - """ + """Store reports""" + name = models.CharField(max_length=200) document = models.ForeignKey(Document, null=True, on_delete=models.CASCADE) text = models.TextField() @@ -106,8 +116,8 @@ def __str__(self): class Indicator(models.Model): - """Indicators extracted from a document for a report - """ + """Indicators extracted from a document for a report""" + report = models.ForeignKey(Report, on_delete=models.CASCADE) indicator_type = models.CharField(max_length=200) value = models.CharField(max_length=200) @@ -115,31 +125,37 @@ class Indicator(models.Model): updated_on = models.DateTimeField(auto_now=True) def __str__(self): - return '%s: %s' % (self.indicator_type, self.value) + return "%s: %s" % (self.indicator_type, self.value) class Sentence(models.Model): text = models.TextField() document = models.ForeignKey(Document, null=True, on_delete=models.CASCADE) - order = models.IntegerField(default=1000) # Sentences with lower numbers are displayed first + order = models.IntegerField( + default=1000 + ) # Sentences with lower numbers are displayed first report = models.ForeignKey(Report, on_delete=models.CASCADE) - disposition = models.CharField(max_length=200, default=None, null=True, blank=True, choices=DISPOSITION_CHOICES) + disposition = models.CharField( + max_length=200, default=None, null=True, blank=True, choices=DISPOSITION_CHOICES + ) created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) def __str__(self): - append = '' + append = "" if len(self.text) > SENTENCE_PREVIEW_CHARS: - append = '...' + append = "..." return self.text[:SENTENCE_PREVIEW_CHARS] + append class Mapping(models.Model): - """Maps sentences to Attack TTPs - """ + """Maps sentences to Attack TTPs""" + report = models.ForeignKey(Report, on_delete=models.CASCADE) sentence = models.ForeignKey(Sentence, on_delete=models.CASCADE) - attack_object = models.ForeignKey(AttackObject, on_delete=models.CASCADE, blank=True, null=True) + attack_object = models.ForeignKey( + AttackObject, on_delete=models.CASCADE, blank=True, null=True + ) confidence = models.FloatField() created_on = models.DateTimeField(auto_now_add=True) updated_on = models.DateTimeField(auto_now=True) @@ -150,15 +166,16 @@ def __str__(self): @classmethod def get_accepted_mappings(cls): # Get Attack techniques that have the required amount of positive examples - attack_objects = AttackObject.get_sentence_counts(accept_threshold=config.ML_ACCEPT_THRESHOLD) + attack_objects = AttackObject.get_sentence_counts( + accept_threshold=config.ML_ACCEPT_THRESHOLD + ) # Get mappings for the attack techniques above threshold mappings = Mapping.objects.filter(attack_object__in=attack_objects) return mappings class MLSettings(models.Model): - """Settings for Machine Learning models - """ + """Settings for Machine Learning models""" def _delete_file(path): diff --git a/src/tram/tram/serializers.py b/src/tram/tram/serializers.py index 60475ee857..f17e73bdb5 100644 --- a/src/tram/tram/serializers.py +++ b/src/tram/tram/serializers.py @@ -7,35 +7,48 @@ class AttackObjectSerializer(serializers.ModelSerializer): class Meta: model = db_models.AttackObject - fields = ['id', 'attack_id', 'name'] + fields = ["id", "attack_id", "name"] class DocumentProcessingJobSerializer(serializers.ModelSerializer): """Needs to be kept in sync with ReportSerializer for display purposes""" + name = serializers.SerializerMethodField() byline = serializers.SerializerMethodField() status = serializers.SerializerMethodField() class Meta: model = db_models.DocumentProcessingJob - fields = ['id', 'name', 'byline', 'status', 'message', 'created_by', 'created_on', 'updated_on'] - order = ['-created_on'] + fields = [ + "id", + "name", + "byline", + "status", + "message", + "created_by", + "created_on", + "updated_on", + ] + order = ["-created_on"] def get_name(self, obj): name = obj.document.docfile.name return name def get_byline(self, obj): - byline = '%s on %s' % (obj.created_by, obj.created_on.strftime('%Y-%M-%d %H:%M:%S UTC')) + byline = "%s on %s" % ( + obj.created_by, + obj.created_on.strftime("%Y-%M-%d %H:%M:%S UTC"), + ) return byline def get_status(self, obj): - if obj.status == 'queued': - return 'Queued' - elif obj.status == 'error': - return 'Error' + if obj.status == "queued": + return "Queued" + elif obj.status == "error": + return "Error" else: - return 'Unknown' + return "Unknown" class MappingSerializer(serializers.ModelSerializer): @@ -45,7 +58,7 @@ class MappingSerializer(serializers.ModelSerializer): class Meta: model = db_models.Mapping - fields = ['id', 'attack_id', 'name', 'confidence'] + fields = ["id", "attack_id", "name", "confidence"] def get_attack_id(self, obj): return obj.attack_object.attack_id @@ -63,24 +76,26 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add necessary fields - attack_object = db_models.AttackObject.objects.get(attack_id=data['attack_id']) - sentence = db_models.Sentence.objects.get(id=data['sentence']) - report = db_models.Report.objects.get(id=data['report']) - - internal_value.update({ - 'report': report, - 'sentence': sentence, - 'attack_object': attack_object, - }) + attack_object = db_models.AttackObject.objects.get(attack_id=data["attack_id"]) + sentence = db_models.Sentence.objects.get(id=data["sentence"]) + report = db_models.Report.objects.get(id=data["report"]) + + internal_value.update( + { + "report": report, + "sentence": sentence, + "attack_object": attack_object, + } + ) return internal_value def create(self, validated_data): mapping = db_models.Mapping.objects.create( - report=validated_data['report'], - sentence=validated_data['sentence'], - attack_object=validated_data['attack_object'], - confidence=validated_data['confidence'] + report=validated_data["report"], + sentence=validated_data["sentence"], + attack_object=validated_data["attack_object"], + confidence=validated_data["confidence"], ) return mapping @@ -95,12 +110,26 @@ class ReportSerializer(serializers.ModelSerializer): class Meta: model = db_models.Report - fields = ['id', 'name', 'byline', 'accepted_sentences', 'reviewing_sentences', 'total_sentences', - 'text', 'ml_model', 'created_by', 'created_on', 'updated_on', 'status'] - order = ['-created_on'] + fields = [ + "id", + "name", + "byline", + "accepted_sentences", + "reviewing_sentences", + "total_sentences", + "text", + "ml_model", + "created_by", + "created_on", + "updated_on", + "status", + ] + order = ["-created_on"] def get_accepted_sentences(self, obj): - count = db_models.Sentence.objects.filter(disposition='accept', report=obj).count() + count = db_models.Sentence.objects.filter( + disposition="accept", report=obj + ).count() return count def get_reviewing_sentences(self, obj): @@ -112,26 +141,32 @@ def get_total_sentences(self, obj): return count def get_byline(self, obj): - byline = '%s on %s' % (obj.created_by, obj.created_on.strftime('%Y-%m-%d %H:%M:%S UTC')) + byline = "%s on %s" % ( + obj.created_by, + obj.created_on.strftime("%Y-%m-%d %H:%M:%S UTC"), + ) return byline def get_status(self, obj): reviewing_sentences = self.get_reviewing_sentences(obj) - status = 'Reviewing' + status = "Reviewing" if reviewing_sentences == 0: - status = 'Accepted' + status = "Accepted" return status class ReportExportSerializer(ReportSerializer): """Defines the export format for reports. Defined separately from ReportSerializer so that: - 1. ReportSerializer and ReportExportSerializer can evolve independently - 2. The export is larger than what the REST API needs + 1. ReportSerializer and ReportExportSerializer can evolve independently + 2. The export is larger than what the REST API needs """ + sentences = serializers.SerializerMethodField() class Meta(ReportSerializer.Meta): - fields = ReportSerializer.Meta.fields + ['sentences', ] + fields = ReportSerializer.Meta.fields + [ + "sentences", + ] def get_sentences(self, obj): sentences = db_models.Sentence.objects.filter(report=obj) @@ -148,28 +183,30 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add sentences - sentence_serializers = [SentenceSerializer(data=sentence) for sentence in data.get('sentences', [])] + sentence_serializers = [ + SentenceSerializer(data=sentence) for sentence in data.get("sentences", []) + ] - internal_value.update({'sentences': sentence_serializers}) + internal_value.update({"sentences": sentence_serializers}) return internal_value def create(self, validated_data): with transaction.atomic(): report = db_models.Report.objects.create( - name=validated_data['name'], - document=None, - text=validated_data['text'], - ml_model=validated_data['ml_model'], - created_by=None, # TODO: Get user from session - ) - - for sentence in validated_data['sentences']: + name=validated_data["name"], + document=None, + text=validated_data["text"], + ml_model=validated_data["ml_model"], + created_by=None, # TODO: Get user from session + ) + + for sentence in validated_data["sentences"]: if sentence.is_valid(): - sentence.validated_data['report'] = report + sentence.validated_data["report"] = report sentence.save() else: # TODO: Handle this case better - raise Exception('Sentence validation needs to be handled better') + raise Exception("Sentence validation needs to be handled better") return report @@ -182,7 +219,7 @@ class SentenceSerializer(serializers.ModelSerializer): class Meta: model = db_models.Sentence - fields = ['id', 'text', 'order', 'disposition', 'mappings'] + fields = ["id", "text", "order", "disposition", "mappings"] def get_mappings(self, obj): mappings = db_models.Mapping.objects.filter(sentence=obj) @@ -199,27 +236,29 @@ def to_internal_value(self, data): internal_value = super().to_internal_value(data) # Keeps model fields # Add mappings - mapping_serializers = [MappingSerializer(data=mapping) for mapping in data.get('mappings', [])] + mapping_serializers = [ + MappingSerializer(data=mapping) for mapping in data.get("mappings", []) + ] - internal_value.update({'mappings': mapping_serializers}) + internal_value.update({"mappings": mapping_serializers}) return internal_value def create(self, validated_data): with transaction.atomic(): sentence = db_models.Sentence.objects.create( - text=validated_data['text'], + text=validated_data["text"], document=None, - report=validated_data['report'], - disposition=validated_data['disposition'] + report=validated_data["report"], + disposition=validated_data["disposition"], ) - for mapping in validated_data.get('mappings', []): - mapping.initial_data['sentence'] = sentence.id - mapping.initial_data['report'] = validated_data['report'].id + for mapping in validated_data.get("mappings", []): + mapping.initial_data["sentence"] = sentence.id + mapping.initial_data["report"] = validated_data["report"].id if mapping.is_valid(): mapping.save() else: # TODO: Handle this case better - raise Exception('Mapping validation needs to be handled better') + raise Exception("Mapping validation needs to be handled better") return sentence diff --git a/src/tram/tram/templates/analyze.html b/src/tram/tram/templates/analyze.html index cba23cebd2..637c266791 100644 --- a/src/tram/tram/templates/analyze.html +++ b/src/tram/tram/templates/analyze.html @@ -53,4 +53,4 @@