diff --git a/empire/server/api/app.py b/empire/server/api/app.py index 01ab6791e..e50e5afb4 100644 --- a/empire/server/api/app.py +++ b/empire/server/api/app.py @@ -66,6 +66,7 @@ def initialize(secure: bool = False, ip: str = "0.0.0.0", port: int = 1337): from empire.server.api.v2.plugin import plugin_api, plugin_task_api from empire.server.api.v2.profile import profile_api from empire.server.api.v2.stager import stager_api, stager_template_api + from empire.server.api.v2.tag import tag_api from empire.server.api.v2.user import user_api from empire.server.server import main @@ -97,6 +98,7 @@ def shutdown_event(): v2App.include_router(meta_api.router) v2App.include_router(plugin_task_api.router) v2App.include_router(plugin_api.router) + v2App.include_router(tag_api.router) v2App.add_middleware( EmpireCORSMiddleware, diff --git a/empire/server/api/v2/agent/agent_api.py b/empire/server/api/v2/agent/agent_api.py index 4f67b1467..ec81a41df 100644 --- a/empire/server/api/v2/agent/agent_api.py +++ b/empire/server/api/v2/agent/agent_api.py @@ -24,6 +24,7 @@ NotFoundResponse, OrderDirection, ) +from empire.server.api.v2.tag import tag_api from empire.server.core.config import empire_config from empire.server.core.db import models from empire.server.server import main @@ -50,6 +51,9 @@ async def get_agent(uid: str, db: Session = Depends(get_db)): raise HTTPException(404, f"Agent not found for id {uid}") +tag_api.add_endpoints_to_taggable(router, "/{uid}/tags", get_agent) + + @router.get("/checkins", response_model=AgentCheckIns) def read_agent_checkins_all( db: Session = Depends(get_db), diff --git a/empire/server/api/v2/agent/agent_dto.py b/empire/server/api/v2/agent/agent_dto.py index 3dea2d590..667cb14ee 100644 --- a/empire/server/api/v2/agent/agent_dto.py +++ b/empire/server/api/v2/agent/agent_dto.py @@ -5,6 +5,7 @@ from pydantic import BaseModel from empire.server.api.v2.shared_dto import PROXY_ID +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag from empire.server.core.db import models @@ -48,6 +49,7 @@ def domain_to_dto_agent(agent: models.Agent): archived=agent.archived, # Could make this a typed class later to match the schema proxies=to_proxy_dto(agent.proxies), + tags=list(map(lambda x: domain_to_dto_tag(x), agent.tags)), ) @@ -111,6 +113,7 @@ class Agent(BaseModel): archived: bool stale: bool proxies: Optional[Dict] + tags: List[Tag] class Agents(BaseModel): diff --git a/empire/server/api/v2/agent/agent_task_api.py b/empire/server/api/v2/agent/agent_task_api.py index 32aae1a41..9dfa90693 100644 --- a/empire/server/api/v2/agent/agent_task_api.py +++ b/empire/server/api/v2/agent/agent_task_api.py @@ -38,6 +38,8 @@ NotFoundResponse, OrderDirection, ) +from empire.server.api.v2.tag import tag_api +from empire.server.api.v2.tag.tag_dto import TagStr from empire.server.core.agent_service import AgentService from empire.server.core.agent_task_service import AgentTaskService from empire.server.core.db import models @@ -83,6 +85,9 @@ async def get_task( ) +tag_api.add_endpoints_to_taggable(router, "/{agent_id}/tasks/{uid}/tags", get_task) + + @router.get("/tasks", response_model=AgentTasks) async def read_tasks_all_agents( limit: int = -1, @@ -96,6 +101,7 @@ async def read_tasks_all_agents( status: Optional[AgentTaskStatus] = None, agents: Optional[List[str]] = Query(None), users: Optional[List[int]] = Query(None), + tags: Optional[List[TagStr]] = Query(None), query: Optional[str] = None, db: Session = Depends(get_db), ): @@ -103,6 +109,7 @@ async def read_tasks_all_agents( db, agents=agents, users=users, + tags=tags, limit=limit, offset=(page - 1) * limit, include_full_input=include_full_input, @@ -145,6 +152,7 @@ async def read_tasks( order_direction: OrderDirection = OrderDirection.desc, status: Optional[AgentTaskStatus] = None, users: Optional[List[int]] = Query(None), + tags: Optional[List[TagStr]] = Query(None), db: Session = Depends(get_db), db_agent: models.Agent = Depends(get_agent), query: Optional[str] = None, @@ -153,6 +161,7 @@ async def read_tasks( db, agents=[db_agent.session_id], users=users, + tags=tags, limit=limit, offset=(page - 1) * limit, include_full_input=include_full_input, diff --git a/empire/server/api/v2/agent/agent_task_dto.py b/empire/server/api/v2/agent/agent_task_dto.py index 95375c15d..7500ef414 100644 --- a/empire/server/api/v2/agent/agent_task_dto.py +++ b/empire/server/api/v2/agent/agent_task_dto.py @@ -8,6 +8,7 @@ DownloadDescription, domain_to_dto_download_description, ) +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag from empire.server.core.db import models @@ -41,6 +42,7 @@ def domain_to_dto_task( status=task.status, created_at=task.created_at, updated_at=task.updated_at, + tags=list(map(lambda x: domain_to_dto_tag(x), task.tags)), ) @@ -59,6 +61,7 @@ class AgentTask(BaseModel): status: models.AgentTaskStatus created_at: datetime updated_at: datetime + tags: List[Tag] class AgentTasks(BaseModel): diff --git a/empire/server/api/v2/credential/credential_api.py b/empire/server/api/v2/credential/credential_api.py index 341f66ee6..770d4c1f8 100644 --- a/empire/server/api/v2/credential/credential_api.py +++ b/empire/server/api/v2/credential/credential_api.py @@ -16,6 +16,7 @@ ) from empire.server.api.v2.shared_dependencies import get_db from empire.server.api.v2.shared_dto import BadRequestResponse, NotFoundResponse +from empire.server.api.v2.tag import tag_api from empire.server.core.db import models from empire.server.server import main @@ -41,6 +42,9 @@ async def get_credential(uid: int, db: Session = Depends(get_db)): raise HTTPException(404, f"Credential not found for id {uid}") +tag_api.add_endpoints_to_taggable(router, "/{uid}/tags", get_credential) + + @router.get("/{uid}", response_model=Credential) async def read_credential( uid: int, db_credential: models.Credential = Depends(get_credential) diff --git a/empire/server/api/v2/credential/credential_dto.py b/empire/server/api/v2/credential/credential_dto.py index e1c92b9a1..598d8c5f2 100644 --- a/empire/server/api/v2/credential/credential_dto.py +++ b/empire/server/api/v2/credential/credential_dto.py @@ -3,6 +3,8 @@ from pydantic import BaseModel +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag + def domain_to_dto_credential(credential): return Credential( @@ -17,6 +19,7 @@ def domain_to_dto_credential(credential): notes=credential.notes, created_at=credential.created_at, updated_at=credential.updated_at, + tags=list(map(lambda x: domain_to_dto_tag(x), credential.tags)), ) @@ -32,6 +35,7 @@ class Credential(BaseModel): notes: Optional[str] created_at: datetime updated_at: datetime + tags: List[Tag] class Credentials(BaseModel): diff --git a/empire/server/api/v2/download/download_api.py b/empire/server/api/v2/download/download_api.py index bc690ce83..a6ff1045f 100644 --- a/empire/server/api/v2/download/download_api.py +++ b/empire/server/api/v2/download/download_api.py @@ -20,6 +20,8 @@ NotFoundResponse, OrderDirection, ) +from empire.server.api.v2.tag import tag_api +from empire.server.api.v2.tag.tag_dto import TagStr from empire.server.core.db import models from empire.server.server import main @@ -59,8 +61,12 @@ async def download_download( return FileResponse(db_download.location, filename=filename) +tag_api.add_endpoints_to_taggable(router, "/{uid}/tags", get_download) + + @router.get( "/{uid}", + response_model=Download, ) async def read_download( uid: int, @@ -79,10 +85,12 @@ async def read_downloads( order_by: DownloadOrderOptions = DownloadOrderOptions.updated_at, query: Optional[str] = None, sources: Optional[List[DownloadSourceFilter]] = Query(None), + tags: Optional[List[TagStr]] = Query(None), ): downloads, total = download_service.get_all( db=db, download_types=sources, + tags=tags, q=query, limit=limit, offset=(page - 1) * limit, diff --git a/empire/server/api/v2/download/download_dto.py b/empire/server/api/v2/download/download_dto.py index e729ef0b7..f271c22c4 100644 --- a/empire/server/api/v2/download/download_dto.py +++ b/empire/server/api/v2/download/download_dto.py @@ -4,6 +4,8 @@ from pydantic import BaseModel +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag + def removeprefix(value: str, prefix: str) -> str: if value.startswith(prefix): @@ -21,6 +23,7 @@ def domain_to_dto_download(download): size=download.size, created_at=download.created_at, updated_at=download.updated_at, + tags=list(map(lambda x: domain_to_dto_tag(x), download.tags)), ) @@ -46,6 +49,7 @@ class Download(BaseModel): size: int created_at: datetime updated_at: datetime + tags: List[Tag] class Downloads(BaseModel): diff --git a/empire/server/api/v2/listener/listener_api.py b/empire/server/api/v2/listener/listener_api.py index 2634e38f8..1386e1381 100644 --- a/empire/server/api/v2/listener/listener_api.py +++ b/empire/server/api/v2/listener/listener_api.py @@ -14,6 +14,7 @@ ) from empire.server.api.v2.shared_dependencies import get_db from empire.server.api.v2.shared_dto import BadRequestResponse, NotFoundResponse +from empire.server.api.v2.tag import tag_api from empire.server.core.db import models from empire.server.server import main @@ -39,6 +40,9 @@ async def get_listener(uid: int, db: Session = Depends(get_db)): raise HTTPException(404, f"Listener not found for id {uid}") +tag_api.add_endpoints_to_taggable(router, "/{uid}/tags", get_listener) + + @router.get("/{uid}", response_model=Listener) async def read_listener(uid: int, db_listener: models.Listener = Depends(get_listener)): return domain_to_dto_listener(db_listener) diff --git a/empire/server/api/v2/listener/listener_dto.py b/empire/server/api/v2/listener/listener_dto.py index 0b0e55c83..848310641 100644 --- a/empire/server/api/v2/listener/listener_dto.py +++ b/empire/server/api/v2/listener/listener_dto.py @@ -4,6 +4,7 @@ from pydantic import BaseModel from empire.server.api.v2.shared_dto import Author, CustomOptionSchema, to_value_type +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag def domain_to_dto_template(listener, uid: str): @@ -59,6 +60,7 @@ def domain_to_dto_listener(listener): enabled=listener.enabled, options=options, created_at=listener.created_at, + tags=list(map(lambda x: domain_to_dto_tag(x), listener.tags)), ) @@ -249,6 +251,7 @@ class Listener(BaseModel): template: str options: Dict[str, str] created_at: datetime + tags: List[Tag] class Listeners(BaseModel): diff --git a/empire/server/api/v2/plugin/plugin_task_api.py b/empire/server/api/v2/plugin/plugin_task_api.py index 008847bd8..f01924ac2 100644 --- a/empire/server/api/v2/plugin/plugin_task_api.py +++ b/empire/server/api/v2/plugin/plugin_task_api.py @@ -19,6 +19,8 @@ NotFoundResponse, OrderDirection, ) +from empire.server.api.v2.tag import tag_api +from empire.server.api.v2.tag.tag_dto import TagStr from empire.server.core.db import models from empire.server.core.db.models import PluginTaskStatus from empire.server.core.download_service import DownloadService @@ -59,6 +61,9 @@ async def get_task(uid: int, db: Session = Depends(get_db), plugin=Depends(get_p ) +tag_api.add_endpoints_to_taggable(router, "/{plugin_id}/tasks/{uid}/tags", get_task) + + @router.get("/tasks", response_model=PluginTasks) async def read_tasks_all_plugins( limit: int = -1, @@ -71,6 +76,7 @@ async def read_tasks_all_plugins( status: Optional[PluginTaskStatus] = None, plugins: Optional[List[str]] = Query(None), users: Optional[List[int]] = Query(None), + tags: Optional[List[TagStr]] = Query(None), query: Optional[str] = None, db: Session = Depends(get_db), ): @@ -78,6 +84,7 @@ async def read_tasks_all_plugins( db, plugins=plugins, users=users, + tags=tags, limit=limit, offset=(page - 1) * limit, include_full_input=include_full_input, @@ -116,6 +123,7 @@ async def read_tasks( order_direction: OrderDirection = OrderDirection.desc, status: Optional[PluginTaskStatus] = None, users: Optional[List[int]] = Query(None), + tags: Optional[List[TagStr]] = Query(None), db: Session = Depends(get_db), plugin=Depends(get_plugin), query: Optional[str] = None, @@ -124,6 +132,7 @@ async def read_tasks( db, plugins=[plugin.info["Name"]], users=users, + tags=tags, limit=limit, offset=(page - 1) * limit, include_full_input=include_full_input, diff --git a/empire/server/api/v2/plugin/plugin_task_dto.py b/empire/server/api/v2/plugin/plugin_task_dto.py index 902fad9be..16b9941dc 100644 --- a/empire/server/api/v2/plugin/plugin_task_dto.py +++ b/empire/server/api/v2/plugin/plugin_task_dto.py @@ -8,6 +8,7 @@ DownloadDescription, domain_to_dto_download_description, ) +from empire.server.api.v2.tag.tag_dto import Tag, domain_to_dto_tag from empire.server.core.db import models @@ -37,6 +38,7 @@ def domain_to_dto_plugin_task( status=task.status, created_at=task.created_at, updated_at=task.updated_at, + tags=list(map(lambda x: domain_to_dto_tag(x), task.tags)), ) @@ -52,6 +54,7 @@ class PluginTask(BaseModel): status: Optional[models.PluginTaskStatus] created_at: datetime updated_at: datetime + tags: List[Tag] class PluginTasks(BaseModel): diff --git a/empire/server/api/v2/tag/__init__.py b/empire/server/api/v2/tag/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/empire/server/api/v2/tag/tag_api.py b/empire/server/api/v2/tag/tag_api.py new file mode 100644 index 000000000..8ea1071fd --- /dev/null +++ b/empire/server/api/v2/tag/tag_api.py @@ -0,0 +1,121 @@ +import math +from typing import List, Optional, Union + +from fastapi import Depends, HTTPException, Query +from sqlalchemy.orm import Session +from starlette.responses import Response +from starlette.status import HTTP_201_CREATED, HTTP_204_NO_CONTENT + +from empire.server.api.api_router import APIRouter +from empire.server.api.jwt_auth import get_current_active_user +from empire.server.api.v2.shared_dependencies import get_db +from empire.server.api.v2.shared_dto import ( + BadRequestResponse, + NotFoundResponse, + OrderDirection, +) +from empire.server.api.v2.tag.tag_dto import ( + TagOrderOptions, + TagRequest, + Tags, + TagSourceFilter, + domain_to_dto_tag, +) +from empire.server.core.db import models +from empire.server.server import main + +tag_service = main.tagsv2 + + +router = APIRouter( + prefix="/api/v2/tags", + tags=["tags"], + responses={ + 404: {"description": "Not found", "model": NotFoundResponse}, + 400: {"description": "Bad request", "model": BadRequestResponse}, + }, + dependencies=[Depends(get_current_active_user)], +) + + +@router.get("/") +async def get_tags( + db: Session = Depends(get_db), + limit: int = -1, + page: int = 1, + order_direction: OrderDirection = OrderDirection.asc, + order_by: TagOrderOptions = TagOrderOptions.updated_at, + query: Optional[str] = None, + sources: Optional[List[TagSourceFilter]] = Query(None), +): + tags, total = tag_service.get_all( + db=db, + tag_types=sources, + q=query, + limit=limit, + offset=(page - 1) * limit, + order_by=order_by, + order_direction=order_direction, + ) + + tags_converted = list(map(lambda x: domain_to_dto_tag(x), tags)) + + return Tags( + records=tags_converted, + page=page, + total_pages=math.ceil(total / limit) if limit > 0 else page, + limit=limit, + total=total, + ) + + +def add_endpoints_to_taggable(router, path, get_taggable): + async def get_tag(tag_id: int, db: Session = Depends(get_db)): + tag = tag_service.get_by_id(db, tag_id) + + if tag: + return tag + + raise HTTPException(404, f"Tag not found for id {tag_id}") + + async def add_tag( + uid: Union[int, str], + tag_req: TagRequest, + db_taggable=Depends(get_taggable), + db: Session = Depends(get_db), + ): + tag = tag_service.add_tag(db, db_taggable, tag_req) + + return domain_to_dto_tag(tag) + + async def update_tag( + uid: Union[int, str], + tag_req: TagRequest, + db_taggable=Depends(get_taggable), + db_tag: models.Tag = Depends(get_tag), + db: Session = Depends(get_db), + ): + tag = tag_service.update_tag(db, db_tag, db_taggable, tag_req) + + return domain_to_dto_tag(tag) + + async def delete_tag( + uid: Union[int, str], + tag_id: int, + db_taggable=Depends(get_taggable), + db: Session = Depends(get_db), + ): + tag_service.delete_tag(db, db_taggable, tag_id) + + return Response(status_code=HTTP_204_NO_CONTENT) + + router.add_api_route( + path, endpoint=add_tag, methods=["POST"], status_code=HTTP_201_CREATED + ) + router.add_api_route(path + "/{tag_id}", endpoint=update_tag, methods=["PUT"]) + router.add_api_route( + path + "/{tag_id}", + endpoint=delete_tag, + methods=["DELETE"], + status_code=HTTP_204_NO_CONTENT, + ) diff --git a/empire/server/api/v2/tag/tag_dto.py b/empire/server/api/v2/tag/tag_dto.py new file mode 100644 index 000000000..9c3efdd8f --- /dev/null +++ b/empire/server/api/v2/tag/tag_dto.py @@ -0,0 +1,59 @@ +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel, constr + +from empire.server.core.db import models + +# Validate the string contains 1 colon +TagStr = constr(regex=r"^[^:]+:[^:]+$") + +# Validate the string has no colons +TagStrNoColon = constr(regex=r"^[^:]+$") + + +class TagSourceFilter(str, Enum): + listener = "listener" + agent = "agent" + agent_task = "agent_task" + plugin_task = "plugin_task" + download = "download" + credential = "credential" + + +class Tag(BaseModel): + id: int + name: str + value: str + label: str + color: Optional[str] + + +class Tags(BaseModel): + records: List[Tag] + limit: int + page: int + total_pages: int + total: int + + +class TagRequest(BaseModel): + name: TagStrNoColon + value: TagStrNoColon + color: Optional[str] + + +class TagOrderOptions(str, Enum): + name = "name" + created_at = "created_at" + updated_at = "updated_at" + + +def domain_to_dto_tag(tag: models.Tag): + return Tag( + id=tag.id, + name=tag.name, + value=tag.value, + label=f"{tag.name}:{tag.value}", + color=tag.color, + ) diff --git a/empire/server/common/empire.py b/empire/server/common/empire.py index 6a5c45edc..85002c846 100755 --- a/empire/server/common/empire.py +++ b/empire/server/common/empire.py @@ -34,6 +34,7 @@ from empire.server.core.profile_service import ProfileService from empire.server.core.stager_service import StagerService from empire.server.core.stager_template_service import StagerTemplateService +from empire.server.core.tag_service import TagService from empire.server.core.user_service import UserService from empire.server.utils import data_util @@ -89,6 +90,7 @@ def __init__(self, args=None): self.agentfilesv2 = AgentFileService(self) self.agentsv2 = AgentService(self) self.pluginsv2 = PluginService(self) + self.tagsv2 = TagService(self) self.pluginsv2.startup() hooks_internal.initialize() diff --git a/empire/server/core/agent_task_service.py b/empire/server/core/agent_task_service.py index ac9b291ff..d4378021b 100644 --- a/empire/server/core/agent_task_service.py +++ b/empire/server/core/agent_task_service.py @@ -44,6 +44,7 @@ def get_tasks( db: Session, agents: List[str] = None, users: List[int] = None, + tags: List[str] = None, limit: int = -1, offset: int = 0, include_full_input: bool = False, @@ -63,7 +64,19 @@ def get_tasks( query = query.filter(models.AgentTask.agent_id.in_(agents)) if users: - query = query.filter(models.AgentTask.user_id.in_(users)) + user_filters = [models.AgentTask.user_id.in_(users)] + if 0 in users: + user_filters.append(models.AgentTask.user_id.is_(None)) + query = query.filter(or_(*user_filters)) + + if tags: + tags_split = [tag.split(":", 1) for tag in tags] + query = query.join(models.AgentTask.tags).filter( + and_( + models.Tag.name.in_([tag[0] for tag in tags_split]), + models.Tag.value.in_([tag[1] for tag in tags_split]), + ) + ) query_options = [ joinedload(models.AgentTask.user), @@ -299,7 +312,7 @@ def create_task_module( db: Session, agent: models.Agent, module_req: ModulePostRequest, - user_id: int, + user_id: int = 0, ): module_req.options["Agent"] = agent.session_id resp, err = self.module_service.execute_module( diff --git a/empire/server/core/db/models.py b/empire/server/core/db/models.py index 98d5f5531..f22c1978a 100644 --- a/empire/server/core/db/models.py +++ b/empire/server/core/db/models.py @@ -67,7 +67,6 @@ def get_database_config(): Column("download_id", Integer, ForeignKey("downloads.id")), ) - stager_download_assc = Table( "stager_download_assc", Base.metadata, @@ -82,6 +81,52 @@ def get_database_config(): Column("download_id", Integer, ForeignKey("downloads.id")), ) +listener_tag_assc = Table( + "listener_tag_assc", + Base.metadata, + Column("listener_id", Integer, ForeignKey("listeners.id")), + Column("tag_id", Integer, ForeignKey("tags.id")), +) + +agent_tag_assc = Table( + "agent_tag_assc", + Base.metadata, + Column("agent_id", String(255), ForeignKey("agents.session_id")), + Column("tag_id", Integer, ForeignKey("tags.id")), +) + +agent_task_tag_assc = Table( + "agent_task_tag_assc", + Base.metadata, + Column("agent_task_id", Integer), + Column("agent_id", String(255)), + Column("tag_id", Integer, ForeignKey("tags.id")), + ForeignKeyConstraint( + ("agent_task_id", "agent_id"), ("agent_tasks.id", "agent_tasks.agent_id") + ), +) + +plugin_task_tag_assc = Table( + "plugin_task_tag_assc", + Base.metadata, + Column("plugin_task_id", Integer, ForeignKey("plugin_tasks.id")), + Column("tag_id", Integer, ForeignKey("tags.id")), +) + +credential_tag_assc = Table( + "credential_tag_assc", + Base.metadata, + Column("credential_id", Integer, ForeignKey("credentials.id")), + Column("tag_id", Integer, ForeignKey("tags.id")), +) + +download_tag_assc = Table( + "download_tag_assc", + Base.metadata, + Column("download_id", Integer, ForeignKey("downloads.id")), + Column("tag_id", Integer, ForeignKey("tags.id")), +) + class User(Base): __tablename__ = "users" @@ -111,6 +156,7 @@ class Listener(Base): enabled = Column(Boolean, nullable=False) options = Column(JSON) created_at = Column(UtcDateTime, nullable=False, default=utcnow()) + tags = relationship("Tag", secondary=listener_tag_assc) def __repr__(self): return "" % (self.name) @@ -187,6 +233,7 @@ class Agent(Base): proxies = Column(JSON) socks = Column(Boolean) socks_port = Column(Integer) + tags = relationship("Tag", secondary=agent_tag_assc) @hybrid_property def lastseen_time(self): @@ -297,6 +344,7 @@ class Credential(Base): updated_at = Column( UtcDateTime, default=utcnow(), onupdate=utcnow(), nullable=False ) + tags = relationship("Tag", secondary=credential_tag_assc) def __repr__(self): return "" % (self.id) @@ -318,6 +366,7 @@ class Download(Base): updated_at = Column( UtcDateTime, default=utcnow(), onupdate=utcnow(), nullable=False ) + tags = relationship("Tag", secondary=download_tag_assc) def get_base64_file(self): with open(self.location, "rb") as f: @@ -359,6 +408,7 @@ class AgentTask(Base): task_name = Column(Text) status = Column(Enum(AgentTaskStatus), index=True) downloads = relationship("Download", secondary=agent_task_download_assc) + tags = relationship("Tag", secondary=agent_task_tag_assc) def __repr__(self): return "" % (self.id) @@ -394,6 +444,7 @@ class PluginTask(Base): task_name = Column(Text) status = Column(Enum(PluginTaskStatus), index=True) downloads = relationship("Download", secondary=plugin_task_download_assc) + tags = relationship("Tag", secondary=plugin_task_tag_assc) def __repr__(self): return "" % (self.id) @@ -484,3 +535,15 @@ class ObfuscationConfig(Base): module = Column(String(255)) enabled = Column(Boolean) preobfuscatable = Column(Boolean) + + +class Tag(Base): + __tablename__ = "tags" + id = Column(Integer, Sequence("tag_seq"), primary_key=True) + name = Column(String(255), nullable=False) + value = Column(String(255), nullable=False) + color = Column(String(12), nullable=True) + created_at = Column(UtcDateTime, nullable=False, default=utcnow()) + updated_at = Column( + UtcDateTime, nullable=False, onupdate=utcnow(), default=utcnow() + ) diff --git a/empire/server/core/download_service.py b/empire/server/core/download_service.py index df30be09a..829af5911 100644 --- a/empire/server/core/download_service.py +++ b/empire/server/core/download_service.py @@ -1,5 +1,6 @@ import os import shutil +from operator import and_ from pathlib import Path from typing import List, Optional, Tuple @@ -28,7 +29,8 @@ def get_by_id(db: Session, uid: int): def get_all( db: Session, download_types: Optional[List[DownloadSourceFilter]], - q: str, + tags: List[str] = None, + q: str = None, limit: int = -1, offset: int = 0, order_by: DownloadOrderOptions = DownloadOrderOptions.updated_at, @@ -79,6 +81,15 @@ def get_all( ) ) + if tags: + tags_split = [tag.split(":", 1) for tag in tags] + query = query.join(models.Download.tags).filter( + and_( + models.Tag.name.in_([tag[0] for tag in tags_split]), + models.Tag.value.in_([tag[1] for tag in tags_split]), + ) + ) + if order_by == DownloadOrderOptions.filename: order_by_prop = func.lower(models.Download.filename) elif order_by == DownloadOrderOptions.location: diff --git a/empire/server/core/hooks.py b/empire/server/core/hooks.py index 0f052cd4e..b1a4a6977 100644 --- a/empire/server/core/hooks.py +++ b/empire/server/core/hooks.py @@ -37,6 +37,14 @@ class Hooks(object): # Its arguments are (db: Session, agent: models.Agent) AFTER_AGENT_CHECKIN_HOOK = "after_agent_checkin_hook" + # This event is triggered after a tag is created. + # Its arguments are (db: Session, tag: models.Tag, taggable: Union[models.Agent, models.Listener, etc]) + AFTER_TAG_CREATED_HOOK = "after_tag_created_hook" + + # This event is triggered after a tag is updated. + # Its arguments are (db: Session, tag: models.Tag, taggable: Union[models.Agent, models.Listener, etc]) + AFTER_TAG_UPDATED_HOOK = "after_tag_updated_hook" + def __init__(self): self.hooks: Dict[str, Dict[str, Callable]] = {} self.filters: Dict[str, Dict[str, Callable]] = {} diff --git a/empire/server/core/plugin_service.py b/empire/server/core/plugin_service.py index 8096fc02a..e0ee8ca88 100644 --- a/empire/server/core/plugin_service.py +++ b/empire/server/core/plugin_service.py @@ -6,7 +6,7 @@ from datetime import datetime from typing import List, Optional -from sqlalchemy import func, or_ +from sqlalchemy import and_, func, or_ from sqlalchemy.orm import Session, joinedload, undefer from empire.server.api.v2.plugin.plugin_dto import PluginExecutePostRequest @@ -184,6 +184,7 @@ def get_tasks( db: Session, plugins: List[str] = None, users: List[int] = None, + tags: List[str] = None, limit: int = -1, offset: int = 0, include_full_input: bool = False, @@ -202,11 +203,24 @@ def get_tasks( query = query.filter(models.PluginTask.plugin_id.in_(plugins)) if users: - query = query.filter(models.PluginTask.user_id.in_(users)) + user_filters = [models.PluginTask.user_id.in_(users)] + if 0 in users: + user_filters.append(models.PluginTask.user_id.is_(None)) + query = query.filter(or_(*user_filters)) + + if tags: + tags_split = [tag.split(":", 1) for tag in tags] + query = query.join(models.PluginTask.tags).filter( + and_( + models.Tag.name.in_([tag[0] for tag in tags_split]), + models.Tag.value.in_([tag[1] for tag in tags_split]), + ) + ) query_options = [ joinedload(models.PluginTask.user), ] + if include_full_input: query_options.append(undefer(models.PluginTask.input_full)) if include_output: diff --git a/empire/server/core/tag_service.py b/empire/server/core/tag_service.py new file mode 100644 index 000000000..6e72140a9 --- /dev/null +++ b/empire/server/core/tag_service.py @@ -0,0 +1,147 @@ +import logging +from typing import List, Optional, Union + +from sqlalchemy import func, or_ +from sqlalchemy.orm import Session + +from empire.server.api.v2.shared_dto import OrderDirection +from empire.server.api.v2.tag.tag_dto import TagOrderOptions, TagSourceFilter +from empire.server.core.db import models +from empire.server.core.hooks import hooks + +log = logging.getLogger(__name__) + + +class TagService(object): + def __init__(self, main_menu): + self.main_menu = main_menu + + def get_by_id(self, db: Session, tag_id: int): + return db.query(models.Tag).filter(models.Tag.id == tag_id).first() + + def get_all( + self, + db: Session, + tag_types: Optional[List[TagSourceFilter]], + q: str, + limit: int = -1, + offset: int = 0, + order_by: TagOrderOptions = TagOrderOptions.updated_at, + order_direction: OrderDirection = OrderDirection.desc, + ): + query = db.query(models.Tag, func.count(models.Tag.id).over().label("total")) + + tag_types = tag_types or [] + sub = [] + if TagSourceFilter.agent_task in tag_types: + sub.append(db.query(models.agent_task_tag_assc.c.tag_id.label("tag_id"))) + if TagSourceFilter.plugin_task in tag_types: + sub.append(db.query(models.plugin_task_tag_assc.c.tag_id.label("tag_id"))) + if TagSourceFilter.agent in tag_types: + sub.append(db.query(models.agent_tag_assc.c.tag_id.label("tag_id"))) + if TagSourceFilter.listener in tag_types: + sub.append(db.query(models.listener_tag_assc.c.tag_id.label("tag_id"))) + if TagSourceFilter.download in tag_types: + sub.append(db.query(models.download_tag_assc.c.tag_id.label("tag_id"))) + if TagSourceFilter.credential in tag_types: + sub.append(db.query(models.credential_tag_assc.c.tag_id.label("tag_id"))) + + subquery = None + if len(sub) > 0: + subquery = sub[0] + if len(sub) > 1: + subquery = subquery.union(*sub[1:]) + subquery = subquery.subquery() + + if subquery is not None: + query = query.join(subquery, subquery.c.tag_id == models.Tag.id) + + if q: + query = query.filter( + or_( + models.Tag.name.like(f"%{q}%"), + ) + ) + + if order_by == TagOrderOptions.name: + order_by_prop = func.lower(models.Tag.name) + elif order_by == TagOrderOptions.created_at: + order_by_prop = models.Tag.created_at + else: + order_by_prop = models.Tag.updated_at + + if order_direction == OrderDirection.asc: + query = query.order_by(order_by_prop.asc()) + else: + query = query.order_by(order_by_prop.desc()) + + if limit > 0: + query = query.limit(limit).offset(offset) + + results = query.all() + + total = 0 if len(results) == 0 else results[0].total + results = list(map(lambda x: x[0], results)) + + return results, total + + def add_tag( + self, + db: Session, + taggable: Union[ + models.Listener, + models.Agent, + models.AgentTask, + models.PluginTask, + models.Credential, + models.Download, + ], + tag_req, + ): + tag = models.Tag(name=tag_req.name, value=tag_req.value, color=tag_req.color) + taggable.tags.append(tag) + db.flush() + + hooks.run_hooks(hooks.AFTER_TAG_CREATED_HOOK, db, tag, taggable) + + return tag + + def update_tag( + self, + db: Session, + db_tag: models.Tag, + taggable: Union[ + models.Listener, + models.Agent, + models.AgentTask, + models.PluginTask, + models.Credential, + models.Download, + ], + tag_req, + ): + db_tag.name = tag_req.name + db_tag.value = tag_req.value + db_tag.color = tag_req.color + db.flush() + + hooks.run_hooks(hooks.AFTER_TAG_UPDATED_HOOK, db, db_tag, taggable) + + return db_tag + + def delete_tag( + self, + db: Session, + taggable: Union[ + models.Listener, + models.Agent, + models.AgentTask, + models.PluginTask, + models.Credential, + models.Download, + ], + tag_id: int, + ): + if tag_id in [tag.id for tag in taggable.tags]: + taggable.tags = [tag for tag in taggable.tags if tag.id != tag_id] + db.query(models.Tag).filter(models.Tag.id == tag_id).delete() diff --git a/empire/test/conftest.py b/empire/test/conftest.py index b7940362b..0c8e221fb 100644 --- a/empire/test/conftest.py +++ b/empire/test/conftest.py @@ -1,6 +1,7 @@ import os import shutil import sys +from contextlib import suppress from importlib import reload from pathlib import Path @@ -62,6 +63,7 @@ def client(): from empire.server.api.v2.plugin import plugin_api, plugin_task_api from empire.server.api.v2.profile import profile_api from empire.server.api.v2.stager import stager_api, stager_template_api + from empire.server.api.v2.tag import tag_api from empire.server.api.v2.user import user_api v2App = FastAPI() @@ -84,6 +86,7 @@ def client(): v2App.include_router(process_api.router) v2App.include_router(download_api.router) v2App.include_router(meta_api.router) + v2App.include_router(tag_api.router) yield TestClient(v2App) @@ -221,6 +224,23 @@ def base_listener_non_fixture(): } +@pytest.fixture(scope="module", autouse=True) +def listener(client, admin_auth_header): + # not using fixture because scope issues + response = client.post( + "/api/v2/listeners/", + headers=admin_auth_header, + json=base_listener_non_fixture(), + ) + + yield response.json() + + with suppress(Exception): + client.delete( + f"/api/v2/listeners/{response.json()['id']}", headers=admin_auth_header + ) + + @pytest.fixture(scope="function") def base_stager(): return { @@ -299,6 +319,7 @@ def host(session_local, models): yield host_id with session_local.begin() as db: + db.query(models.Agent).filter(models.Agent.host_id == host_id).delete() db.query(models.Host).filter(models.Host.id == host_id).delete() @@ -347,6 +368,79 @@ def agent(session_local, models, host, main): db.query(models.Agent).filter(models.Agent.session_id == agent_id).delete() +@pytest.fixture(scope="function") +def agent_task(client, admin_auth_header, agent): + resp = client.post( + f"/api/v2/agents/{agent}/tasks/shell", + headers=admin_auth_header, + json={"command": 'echo "HELLO WORLD"'}, + ) + + yield resp.json() + + # No need to delete the task, it will be deleted when the agent is deleted + # After the test. + + +@pytest.fixture(scope="module") +def plugin_name(): + return "basic_reporting" + + +@pytest.fixture(scope="function") +def plugin_task(main, session_local, models, plugin_name): + with session_local.begin() as db: + plugin_task = models.PluginTask( + plugin_id=plugin_name, + input="This is the trimmed input for the task.", + input_full="This is the full input for the task.", + user_id=1, + ) + db.add(plugin_task) + db.flush() + task_id = plugin_task.id + + yield task_id + + with session_local.begin() as db: + db.query(models.PluginTask).delete() + + +@pytest.fixture(scope="function") +def credential(client, admin_auth_header): + resp = client.post( + "/api/v2/credentials/", + headers=admin_auth_header, + json={ + "credtype": "hash", + "domain": "the-domain", + "username": "user", + "password": "hunter2", + "host": "host1", + }, + ) + + yield resp.json()["id"] + + client.delete(f"/api/v2/credentials/{resp.json()['id']}", headers=admin_auth_header) + + +@pytest.fixture(scope="function") +def download(client, admin_auth_header): + response = client.post( + "/api/v2/downloads", + headers=admin_auth_header, + files={ + "file": ( + "test-upload-2.yaml", + open("./empire/test/test-upload-2.yaml", "r").read(), + ) + }, + ) + + yield response.json()["id"] + + @pytest.fixture(scope="session") def server_config_dict(): # load the config file diff --git a/empire/test/test_agent_task_api.py b/empire/test/test_agent_task_api.py index 072761baf..c3f3d9000 100644 --- a/empire/test/test_agent_task_api.py +++ b/empire/test/test_agent_task_api.py @@ -2,24 +2,6 @@ import pytest -from empire.test.conftest import base_listener_non_fixture - - -@pytest.fixture(scope="module", autouse=True) -def listener(client, admin_auth_header): - # not using fixture because scope issues - response = client.post( - "/api/v2/listeners/", - headers=admin_auth_header, - json=base_listener_non_fixture(), - ) - - yield response.json() - - client.delete( - f"/api/v2/listeners/{response.json()['id']}", headers=admin_auth_header - ) - @pytest.fixture(scope="module", autouse=True) def agent_low_version(db, models, main): diff --git a/empire/test/test_listener_api.py b/empire/test/test_listener_api.py index 82910cd7d..2ad4fd5f3 100644 --- a/empire/test/test_listener_api.py +++ b/empire/test/test_listener_api.py @@ -1,6 +1,3 @@ -my_globals = {"listener_id": 0} - - def test_get_listener_templates(client, admin_auth_header): response = client.get( "/api/v2/listener-templates/", @@ -24,9 +21,11 @@ def test_get_listener_template(client, admin_auth_header): def test_create_listener_validation_fails_required_field( client, base_listener, admin_auth_header ): + base_listener_copy = base_listener.copy() + base_listener_copy["name"] = "temp123" base_listener["options"]["Port"] = "" response = client.post( - "/api/v2/listeners/", headers=admin_auth_header, json=base_listener + "/api/v2/listeners/", headers=admin_auth_header, json=base_listener_copy ) assert response.status_code == 400 assert response.json()["detail"] == "required option missing: Port" @@ -45,46 +44,56 @@ def test_create_listener_validation_fails_required_field( def test_create_listener_custom_validation_fails( client, base_listener, admin_auth_header ): - base_listener["options"]["Host"] = "https://securedomain.com" + base_listener_copy = base_listener.copy() + base_listener_copy["name"] = "temp123" + base_listener_copy["options"]["Host"] = "https://securedomain.com" response = client.post( - "/api/v2/listeners/", headers=admin_auth_header, json=base_listener + "/api/v2/listeners/", headers=admin_auth_header, json=base_listener_copy ) assert response.status_code == 400 assert response.json()["detail"] == "[!] HTTPS selected but no CertPath specified." def test_create_listener_template_not_found(client, base_listener, admin_auth_header): - base_listener["template"] = "qwerty" + base_listener_copy = base_listener.copy() + base_listener_copy["name"] = "temp123" + base_listener_copy["template"] = "qwerty" response = client.post( - "/api/v2/listeners/", headers=admin_auth_header, json=base_listener + "/api/v2/listeners/", headers=admin_auth_header, json=base_listener_copy ) assert response.status_code == 400 assert response.json()["detail"] == "Listener Template qwerty not found" def test_create_listener(client, base_listener, admin_auth_header): + base_listener_copy = base_listener.copy() + base_listener_copy["name"] = "temp123" + base_listener_copy["options"]["Port"] = "1234" + # test that it ignore extra params - base_listener["options"]["xyz"] = "xyz" + base_listener_copy["options"]["xyz"] = "xyz" response = client.post( - "/api/v2/listeners/", headers=admin_auth_header, json=base_listener + "/api/v2/listeners/", headers=admin_auth_header, json=base_listener_copy ) assert response.status_code == 201 assert response.json()["options"].get("xyz") is None - assert response.json()["options"]["Name"] == base_listener["options"]["Name"] - assert response.json()["options"]["Port"] == base_listener["options"]["Port"] + assert response.json()["options"]["Name"] == base_listener_copy["name"] + assert response.json()["options"]["Port"] == base_listener_copy["options"]["Port"] assert ( response.json()["options"]["DefaultJitter"] - == base_listener["options"]["DefaultJitter"] + == base_listener_copy["options"]["DefaultJitter"] ) assert ( response.json()["options"]["DefaultDelay"] - == base_listener["options"]["DefaultDelay"] + == base_listener_copy["options"]["DefaultDelay"] ) - my_globals["listener_id"] = response.json()["id"] + client.delete( + f"/api/v2/listeners/{response.json()['id']}", headers=admin_auth_header + ) def test_create_listener_name_conflict(client, base_listener, admin_auth_header): @@ -98,13 +107,13 @@ def test_create_listener_name_conflict(client, base_listener, admin_auth_header) ) -def test_get_listener(client, admin_auth_header): +def test_get_listener(client, admin_auth_header, listener): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.status_code == 200 - assert response.json()["id"] == my_globals["listener_id"] + assert response.json()["id"] == listener["id"] def test_get_listener_not_found(client, admin_auth_header): @@ -125,15 +134,15 @@ def test_update_listener_not_found(client, base_listener, admin_auth_header): assert response.json()["detail"] == "Listener not found for id 9999" -def test_update_listener_blocks_while_enabled(client, admin_auth_header): +def test_update_listener_blocks_while_enabled(client, admin_auth_header, listener): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["enabled"] is True response = client.put( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, json=response.json(), ) @@ -141,9 +150,11 @@ def test_update_listener_blocks_while_enabled(client, admin_auth_header): assert response.json()["detail"] == "Listener must be disabled before modifying" -def test_update_listener_allows_and_disables_while_enabled(client, admin_auth_header): +def test_update_listener_allows_and_disables_while_enabled( + client, admin_auth_header, listener +): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["enabled"] is True @@ -153,7 +164,7 @@ def test_update_listener_allows_and_disables_while_enabled(client, admin_auth_he new_port = str(int(listener["options"]["Port"]) + 1) listener["options"]["Port"] = new_port response = client.put( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, json=listener, ) @@ -162,9 +173,9 @@ def test_update_listener_allows_and_disables_while_enabled(client, admin_auth_he assert response.json()["options"]["Port"] == new_port -def test_update_listener_allows_while_disabled(client, admin_auth_header): +def test_update_listener_allows_while_disabled(client, admin_auth_header, listener): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", headers=admin_auth_header + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header ) assert response.json()["enabled"] is False @@ -175,7 +186,7 @@ def test_update_listener_allows_while_disabled(client, admin_auth_header): listener["options"]["xyz"] = "xyz" response = client.put( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, json=listener, ) @@ -186,11 +197,12 @@ def test_update_listener_allows_while_disabled(client, admin_auth_header): def test_update_listener_name_conflict(client, base_listener, admin_auth_header): + base_listener_copy = base_listener.copy() # Create a second listener. - base_listener["name"] = "new-listener-2" - base_listener["options"]["Port"] = "1299" + base_listener_copy["name"] = "new-listener-2" + base_listener_copy["options"]["Port"] = "1299" response = client.post( - "/api/v2/listeners/", headers=admin_auth_header, json=base_listener + "/api/v2/listeners/", headers=admin_auth_header, json=base_listener_copy ) assert response.status_code == 201 @@ -216,9 +228,11 @@ def test_update_listener_name_conflict(client, base_listener, admin_auth_header) ) -def test_update_listener_reverts_if_validation_fails(client, admin_auth_header): +def test_update_listener_reverts_if_validation_fails( + client, admin_auth_header, listener +): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["enabled"] is False @@ -235,14 +249,16 @@ def test_update_listener_reverts_if_validation_fails(client, admin_auth_header): assert response.json()["detail"] == "required option missing: Port" response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", headers=admin_auth_header + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header ) assert response.json()["options"]["BindIP"] == "0.0.0.0" -def test_update_listener_reverts_if_custom_validation_fails(client, admin_auth_header): +def test_update_listener_reverts_if_custom_validation_fails( + client, admin_auth_header, listener +): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["enabled"] is False @@ -259,15 +275,17 @@ def test_update_listener_reverts_if_custom_validation_fails(client, admin_auth_h assert response.json()["detail"] == "[!] HTTPS selected but no CertPath specified." response = client.get( - f"/api/v2/listeners/ {my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["options"]["BindIP"] == "0.0.0.0" -def test_update_listener_allows_and_enables_while_disabled(client, admin_auth_header): +def test_update_listener_allows_and_enables_while_disabled( + client, admin_auth_header, listener +): response = client.get( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, ) assert response.json()["enabled"] is False @@ -277,7 +295,7 @@ def test_update_listener_allows_and_enables_while_disabled(client, admin_auth_he listener["enabled"] = True listener["options"]["Port"] = new_port response = client.put( - f"/api/v2/listeners/{my_globals['listener_id']}", + f"/api/v2/listeners/{listener['id']}", headers=admin_auth_header, json=listener, ) @@ -293,50 +311,45 @@ def test_get_listeners(client, admin_auth_header): assert len(response.json()["records"]) == 2 -def test_delete_listener_while_enabled(client, admin_auth_header): - response = client.get("/api/v2/listeners", headers=admin_auth_header) - assert response.status_code == 200 - assert len(response.json()["records"]) == 2 - - to_delete = list( - filter(lambda x: x["enabled"] is True, response.json()["records"]) - )[0] - assert to_delete["enabled"] is True +def test_delete_listener_while_enabled(client, admin_auth_header, base_listener): + to_delete = base_listener.copy() + to_delete["name"] = "to-delete" + to_delete["options"]["Port"] = "1299" + response = client.post( + "/api/v2/listeners/", headers=admin_auth_header, json=to_delete + ) + assert response.status_code == 201 + to_delete_id = response.json()["id"] response = client.delete( - f"/api/v2/listeners/{to_delete['id']}", headers=admin_auth_header + f"/api/v2/listeners/{to_delete_id}", headers=admin_auth_header ) assert response.status_code == 204 response = client.get( - "/api/v2/listeners", - headers=admin_auth_header, + f"/api/v2/listeners/{to_delete_id}", headers=admin_auth_header ) - assert response.status_code == 200 - assert len(response.json()["records"]) == 1 - assert response.json()["records"][0]["id"] != to_delete["id"] + assert response.status_code == 404 -def test_delete_listener_while_disabled(client, admin_auth_header): - response = client.get( - "/api/v2/listeners", - headers=admin_auth_header, - ) - assert response.status_code == 200 - assert len(response.json()["records"]) == 1 - to_delete = response.json()["records"][0] - assert to_delete["enabled"] is False +def test_delete_listener_while_disabled(client, admin_auth_header, base_listener): + to_delete = base_listener.copy() + to_delete["name"] = "to-delete" + to_delete["options"]["Port"] = "1298" + + response = client.post( + "/api/v2/listeners/", headers=admin_auth_header, json=to_delete + ) + assert response.status_code == 201 + to_delete_id = response.json()["id"] response = client.delete( - f"/api/v2/listeners/{to_delete['id']}", - headers=admin_auth_header, + f"/api/v2/listeners/{to_delete_id}", headers=admin_auth_header ) assert response.status_code == 204 response = client.get( - "/api/v2/listeners", - headers=admin_auth_header, + f"/api/v2/listeners/{to_delete_id}", headers=admin_auth_header ) - assert response.status_code == 200 - assert len(response.json()["records"]) == 0 + assert response.status_code == 404 diff --git a/empire/test/test_plugin_task_api.py b/empire/test/test_plugin_task_api.py index 5a512bfed..26d497c60 100644 --- a/empire/test/test_plugin_task_api.py +++ b/empire/test/test_plugin_task_api.py @@ -1,11 +1,6 @@ import pytest -@pytest.fixture(scope="module") -def plugin_name(): - return "basic_reporting" - - @pytest.fixture(scope="module", autouse=True) def plugin_task_1(main, db, models, plugin_name): db.add( diff --git a/empire/test/test_stager_api.py b/empire/test/test_stager_api.py index 20adf9da4..648588e54 100644 --- a/empire/test/test_stager_api.py +++ b/empire/test/test_stager_api.py @@ -1,19 +1,15 @@ import pytest -from empire.test.conftest import base_listener_non_fixture - -my_globals = {"stager_id_1": 0, "stager_id_2": 0} - @pytest.fixture(scope="module", autouse=True) -def create_listener(client, admin_auth_header): - # not using fixture because scope issues - response = client.post( - "/api/v2/listeners/", - headers=admin_auth_header, - json=base_listener_non_fixture(), - ) - return response.json() +def cleanup_stagers(session_local, models): + yield + + with session_local.begin() as db: + db.query(models.stager_download_assc).delete() + db.query(models.upload_download_assc).delete() + db.query(models.Stager).delete() + db.query(models.Download).delete() def test_get_stager_templates(client, admin_auth_header): @@ -93,7 +89,7 @@ def test_create_stager_one_liner(client, base_stager, admin_auth_header): response.json().get("downloads", [])[0]["link"].startswith("/api/v2/downloads") ) - my_globals["stager_id_1"] = response.json()["id"] + client.delete(f"/api/v2/stagers/{response.json()['id']}", headers=admin_auth_header) def test_create_obfuscated_stager_one_liner(client, base_stager, admin_auth_header): @@ -113,7 +109,7 @@ def test_create_obfuscated_stager_one_liner(client, base_stager, admin_auth_head response.json().get("downloads", [])[0]["link"].startswith("/api/v2/downloads") ) - my_globals["stager_id_1"] = response.json()["id"] + client.delete(f"/api/v2/stagers/{response.json()['id']}", headers=admin_auth_header) def test_create_stager_file(client, base_stager_2, admin_auth_header): @@ -130,10 +126,16 @@ def test_create_stager_file(client, base_stager_2, admin_auth_header): response.json().get("downloads", [])[0]["link"].startswith("/api/v2/downloads") ) - my_globals["stager_id_2"] = response.json()["id"] + client.delete(f"/api/v2/stagers/{response.json()['id']}", headers=admin_auth_header) def test_create_stager_name_conflict(client, base_stager, admin_auth_header): + response = client.post( + "/api/v2/stagers/?save=true", headers=admin_auth_header, json=base_stager + ) + assert response.status_code == 201 + stager_id = response.json()["id"] + response = client.post( "/api/v2/stagers/?save=true", headers=admin_auth_header, json=base_stager ) @@ -143,6 +145,8 @@ def test_create_stager_name_conflict(client, base_stager, admin_auth_header): == f'Stager with name {base_stager["name"]} already exists.' ) + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + def test_create_stager_save_false(client, base_stager, admin_auth_header): response = client.post( @@ -156,13 +160,22 @@ def test_create_stager_save_false(client, base_stager, admin_auth_header): ) -def test_get_stager(client, admin_auth_header): +def test_get_stager(client, admin_auth_header, base_stager): + response = client.post( + "/api/v2/stagers/?save=true", headers=admin_auth_header, json=base_stager + ) + stager_id = response.json()["id"] + + assert response.status_code == 201 + response = client.get( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, ) assert response.status_code == 200 - assert response.json()["id"] == my_globals["stager_id_1"] + assert response.json()["id"] == stager_id + + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) def test_get_stager_not_found(client, admin_auth_header): @@ -182,9 +195,17 @@ def test_update_stager_not_found(client, base_stager, admin_auth_header): assert response.json()["detail"] == "Stager not found for id 9999" -def test_download_stager_one_liner(client, admin_auth_header): +def test_download_stager_one_liner(client, admin_auth_header, base_stager): + response = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager, + ) + assert response.status_code == 201 + stager_id = response.json()["id"] + response = client.get( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, ) response = client.get( @@ -195,10 +216,20 @@ def test_download_stager_one_liner(client, admin_auth_header): assert response.headers.get("content-type").split(";")[0] == "text/plain" assert response.text.startswith("powershell -noP -sta") + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + + +def test_download_stager_file(client, admin_auth_header, base_stager_2): + response = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager_2, + ) + assert response.status_code == 201 + stager_id = response.json()["id"] -def test_download_stager_file(client, admin_auth_header): response = client.get( - f"/api/v2/stagers/{my_globals['stager_id_2']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, ) response = client.get( @@ -212,10 +243,22 @@ def test_download_stager_file(client, admin_auth_header): ] assert type(response.content) == bytes + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + + +def test_update_stager_allows_edits_and_generates_new_file( + client, admin_auth_header, base_stager +): + response = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager, + ) + assert response.status_code == 201 + stager_id = response.json()["id"] -def test_update_stager_allows_edits_and_generates_new_file(client, admin_auth_header): response = client.get( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, ) assert response.status_code == 200 @@ -226,7 +269,7 @@ def test_update_stager_allows_edits_and_generates_new_file(client, admin_auth_he stager["options"]["Base64"] = "False" response = client.put( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, json=stager, ) @@ -234,16 +277,36 @@ def test_update_stager_allows_edits_and_generates_new_file(client, admin_auth_he assert response.json()["options"]["Base64"] == "False" assert response.json()["name"] == original_name + "_updated!" + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + + +def test_update_stager_name_conflict(client, admin_auth_header, base_stager): + response = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager, + ) + assert response.status_code == 201 + stager_id = response.json()["id"] -def test_update_stager_name_conflict(client, admin_auth_header): response = client.get( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, ) assert response.status_code == 200 + base_stager_2 = base_stager.copy() + base_stager_2["name"] = "test_stager_2" + response2 = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager_2, + ) + assert response2.status_code == 201 + stager_id_2 = response2.json()["id"] + response2 = client.get( - f"/api/v2/stagers/{my_globals['stager_id_2']}", + f"/api/v2/stagers/{stager_id_2}", headers=admin_auth_header, ) assert response.status_code == 200 @@ -252,7 +315,7 @@ def test_update_stager_name_conflict(client, admin_auth_header): stager_1["name"] = stager_2["name"] response = client.put( - f"/api/v2/stagers/{my_globals['stager_id_1']}", + f"/api/v2/stagers/{stager_id}", headers=admin_auth_header, json=stager_1, ) @@ -263,39 +326,54 @@ def test_update_stager_name_conflict(client, admin_auth_header): == f"Stager with name {stager_2['name']} already exists." ) - -def test_get_stagers(client, admin_auth_header): - response = client.get( - "/api/v2/stagers", - headers=admin_auth_header, - ) - - assert response.status_code == 200 - assert len(response.json()["records"]) == 3 + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + client.delete(f"/api/v2/stagers/{stager_id_2}", headers=admin_auth_header) -def test_delete_stager(client, admin_auth_header): - response = client.get( - "/api/v2/stagers", +def test_get_stagers(client, admin_auth_header, base_stager): + response = client.post( + "/api/v2/stagers/?save=true", headers=admin_auth_header, + json=base_stager, ) - assert response.status_code == 200 - assert len(response.json()["records"]) == 3 + assert response.status_code == 201 + stager_id = response.json()["id"] - to_delete = response.json()["records"][0] - response = client.delete( - f"/api/v2/stagers/{to_delete['id']}", + base_stager_2 = base_stager.copy() + base_stager_2["name"] = "test_stager_2" + response = client.post( + "/api/v2/stagers/?save=true", headers=admin_auth_header, + json=base_stager_2, ) - assert response.status_code == 204 + assert response.status_code == 201 + stager_id_2 = response.json()["id"] response = client.get( "/api/v2/stagers", headers=admin_auth_header, ) + assert response.status_code == 200 assert len(response.json()["records"]) == 2 - assert response.json()["records"][0]["id"] != to_delete["id"] + assert response.json()["records"][0]["id"] == stager_id + assert response.json()["records"][1]["id"] == stager_id_2 + + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + client.delete(f"/api/v2/stagers/{stager_id_2}", headers=admin_auth_header) + + +def test_delete_stager(client, admin_auth_header, base_stager): + response = client.post( + "/api/v2/stagers/?save=true", + headers=admin_auth_header, + json=base_stager, + ) + assert response.status_code == 201 + stager_id = response.json()["id"] + + response = client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) + assert response.status_code == 204 def test_pyinstaller_stager_creation(client, pyinstaller_stager, admin_auth_header): @@ -330,3 +408,5 @@ def test_pyinstaller_stager_creation(client, pyinstaller_stager, admin_auth_head # Check if the downloaded file is not empty assert len(response.content) > 0 + + client.delete(f"/api/v2/stagers/{stager_id}", headers=admin_auth_header) diff --git a/empire/test/test_tags_api.py b/empire/test/test_tags_api.py new file mode 100644 index 000000000..5800afb52 --- /dev/null +++ b/empire/test/test_tags_api.py @@ -0,0 +1,581 @@ +import pytest + +from empire.server.core.db.models import PluginTaskStatus + + +def _test_add_tag(client, admin_auth_header, path, taggable_id): + resp = client.post( + f"{path}/{taggable_id}/tags", + headers=admin_auth_header, + json={"name": "test:tag", "value": "test:value"}, + ) + assert resp.status_code == 422 + assert resp.json() == { + "detail": [ + { + "ctx": {"pattern": "^[^:]+$"}, + "loc": ["body", "name"], + "msg": 'string does not match regex "^[^:]+$"', + "type": "value_error.str.regex", + }, + { + "ctx": {"pattern": "^[^:]+$"}, + "loc": ["body", "value"], + "msg": 'string does not match regex "^[^:]+$"', + "type": "value_error.str.regex", + }, + ] + } + + resp = client.post( + f"{path}/{taggable_id}/tags", + headers=admin_auth_header, + json={"name": "test_tag", "value": "test_value"}, + ) + + expected_tag_1 = { + "name": "test_tag", + "value": "test_value", + "color": None, + "label": "test_tag:test_value", + } + + assert resp.status_code == 201 + actual_tag_1 = resp.json() + actual_tag_1.pop("id") + assert actual_tag_1 == expected_tag_1 + + resp = client.get(f"{path}/{taggable_id}", headers=admin_auth_header) + assert resp.status_code == 200 + + actual_tags = resp.json()["tags"] + assert len(actual_tags) == 1 + + actual_tags[0].pop("id") + assert actual_tags == [expected_tag_1] + + resp = client.post( + f"{path}/{taggable_id}/tags", + headers=admin_auth_header, + json={ + "name": "test_tag", + "value": "test_value", + "color": "#0000FF", + }, + ) + + expected_tag_2 = { + "name": "test_tag", + "value": "test_value", + "color": "#0000FF", + "label": "test_tag:test_value", + } + + assert resp.status_code == 201 + actual_tag_2 = resp.json() + actual_tag_2.pop("id") + assert actual_tag_2 == expected_tag_2 + + resp = client.get(f"{path}/{taggable_id}", headers=admin_auth_header) + assert resp.status_code == 200 + + actual_tags = resp.json()["tags"] + assert len(actual_tags) == 2 + + for tag in actual_tags: + tag.pop("id") + + assert actual_tags == [expected_tag_1, expected_tag_2] + + for tag in resp.json()["tags"]: + resp = client.delete( + f"{path}/{taggable_id}/tags/{tag['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + +def _test_update_tag(client, admin_auth_header, path, taggable_id): + resp = client.post( + f"{path}/{taggable_id}/tags", + headers=admin_auth_header, + json={"name": "test_tag", "value": "test_value"}, + ) + + assert resp.status_code == 201 + + expected_tag = { + "name": "test_tag_updated", + "value": "test_value_updated", + "color": "#0000FF", + "label": "test_tag_updated:test_value_updated", + } + + resp_bad = client.put( + f"{path}/{taggable_id}/tags/{resp.json()['id']}", + headers=admin_auth_header, + json={"name": "test:tag", "value": "test:value"}, + ) + assert resp_bad.status_code == 422 + assert resp_bad.json() == { + "detail": [ + { + "ctx": {"pattern": "^[^:]+$"}, + "loc": ["body", "name"], + "msg": 'string does not match regex "^[^:]+$"', + "type": "value_error.str.regex", + }, + { + "ctx": {"pattern": "^[^:]+$"}, + "loc": ["body", "value"], + "msg": 'string does not match regex "^[^:]+$"', + "type": "value_error.str.regex", + }, + ] + } + + resp = client.put( + f"{path}/{taggable_id}/tags/{resp.json()['id']}", + headers=admin_auth_header, + json=expected_tag, + ) + + assert resp.status_code == 200 + + actual_tag = resp.json() + actual_tag.pop("id") + assert actual_tag == expected_tag + + resp = client.delete( + f"{path}/{taggable_id}/tags/{resp.json()['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + +def _test_delete_tag(client, admin_auth_header, path, taggable_id): + resp = client.post( + f"{path}/{taggable_id}/tags", + headers=admin_auth_header, + json={"name": "test_tag", "value": "test_value"}, + ) + + assert resp.status_code == 201 + + resp = client.delete( + f"{path}/{taggable_id}/tags/{resp.json()['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + resp = client.get(f"{path}/{taggable_id}", headers=admin_auth_header) + assert resp.status_code == 200 + assert resp.json()["tags"] == [] + + +def test_listener_add_tag(client, admin_auth_header, listener): + _test_add_tag(client, admin_auth_header, "/api/v2/listeners", listener["id"]) + + +def test_agent_add_tag(client, admin_auth_header, agent): + _test_add_tag(client, admin_auth_header, "/api/v2/agents", agent) + + +def test_agent_task_add_tag(client, admin_auth_header, agent_task): + _test_add_tag( + client, + admin_auth_header, + f"/api/v2/agents/{agent_task['agent_id']}/tasks", + agent_task["id"], + ) + + +def test_plugin_task_add_tag(client, admin_auth_header, plugin_task): + _test_add_tag( + client, + admin_auth_header, + "/api/v2/plugins/basic_reporting/tasks", + plugin_task, + ) + + +def test_credential_add_tag(client, admin_auth_header, credential): + _test_add_tag(client, admin_auth_header, "/api/v2/credentials", credential) + + +def test_download_add_tag(client, admin_auth_header, download): + _test_add_tag(client, admin_auth_header, "/api/v2/downloads", download) + + +def test_listener_update_tag(client, admin_auth_header, listener): + _test_update_tag(client, admin_auth_header, "/api/v2/listeners", listener["id"]) + + +def test_agent_update_tag(client, admin_auth_header, agent): + _test_update_tag(client, admin_auth_header, "/api/v2/agents", agent) + + +def test_agent_task_update_tag(client, admin_auth_header, agent_task): + _test_update_tag( + client, + admin_auth_header, + f"/api/v2/agents/{agent_task['agent_id']}/tasks", + agent_task["id"], + ) + + +def test_plugin_task_update_tag(client, admin_auth_header, plugin_task): + _test_update_tag( + client, + admin_auth_header, + "/api/v2/plugins/basic_reporting/tasks", + plugin_task, + ) + + +def test_credential_update_tag(client, admin_auth_header, credential): + _test_update_tag(client, admin_auth_header, "/api/v2/credentials", credential) + + +def test_download_update_tag(client, admin_auth_header, download): + _test_update_tag(client, admin_auth_header, "/api/v2/downloads", download) + + +def test_listener_delete_tag(client, admin_auth_header, listener): + _test_delete_tag(client, admin_auth_header, "/api/v2/listeners", listener["id"]) + + +def test_agent_delete_tag(client, admin_auth_header, agent): + _test_delete_tag(client, admin_auth_header, "/api/v2/agents", agent) + + +def test_agent_task_delete_tag(client, admin_auth_header, agent_task): + _test_delete_tag( + client, + admin_auth_header, + f"/api/v2/agents/{agent_task['agent_id']}/tasks", + agent_task["id"], + ) + + +def test_plugin_task_delete_tag(client, admin_auth_header, plugin_task): + _test_delete_tag( + client, + admin_auth_header, + "/api/v2/plugins/basic_reporting/tasks", + plugin_task, + ) + + +def test_credential_delete_tag(client, admin_auth_header, credential): + _test_delete_tag(client, admin_auth_header, "/api/v2/credentials", credential) + + +def test_download_delete_tag(client, admin_auth_header, download): + _test_delete_tag(client, admin_auth_header, "/api/v2/downloads", download) + + +@pytest.fixture(scope="function") +def _create_tags( + client, + admin_auth_header, + listener, + agent, + agent_task, + plugin_task, + credential, + download, +): + paths = [ + "/api/v2/listeners", + "/api/v2/agents", + f"/api/v2/agents/{agent_task['agent_id']}/tasks", + "/api/v2/plugins/basic_reporting/tasks", + "/api/v2/credentials", + "/api/v2/downloads", + ] + cleanup = [] + expected_tags = [] + for taggable in zip( + [listener, agent, agent_task, plugin_task, credential, download], + paths, + ): + if isinstance(taggable[0], dict): + taggable_id = taggable[0]["id"] + else: + taggable_id = taggable[0] + resp = client.post( + f"{taggable[1]}/{taggable_id}/tags", + headers=admin_auth_header, + json={"name": f"test_tag_{taggable[1]}", "value": "test_value"}, + ) + assert resp.status_code == 201 + + res = resp.json() + cleanup.append(f"{taggable[1]}/{taggable_id}/tags/{res['id']}") + res.pop("id") + expected_tags.append(res) + + yield expected_tags + + for tag in cleanup: + resp = client.delete(tag, headers=admin_auth_header) + assert resp.status_code == 204 + + +def test_get_tags(client, admin_auth_header, _create_tags): + expected_tags = _create_tags + resp = client.get("/api/v2/tags?order_by=name", headers=admin_auth_header) + assert resp.status_code == 200 + + actual_tags = resp.json()["records"] + for tag in actual_tags: + tag.pop("id") + + expected_tags = sorted(expected_tags, key=lambda k: k["name"]) + assert actual_tags == expected_tags + + +@pytest.fixture(scope="function") +def _create_agent_tasks_with_tags(client, admin_auth_header, agent): + agent_id = agent + agent_tasks = [] + tags = [] + for i in range(3): + resp = client.post( + f"/api/v2/agents/{agent_id}/tasks/shell", + headers=admin_auth_header, + json={"command": f"whoami_{i}"}, + ) + assert resp.status_code == 201 + agent_tasks.append(resp.json()) + + for i, agent_task in enumerate(agent_tasks): + resp = client.post( + f"/api/v2/agents/{agent_id}/tasks/{agent_task['id']}/tags", + headers=admin_auth_header, + json={"name": f"test_tag_{i}", "value": f"test_value_{i}"}, + ) + assert resp.status_code == 201 + tags.append((agent_task, resp.json())) + + yield agent_tasks + + for task, tag in tags: + resp = client.delete( + f"/api/v2/agents/{agent_id}/tasks/{task['id']}/tags/{tag['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + for agent_task in agent_tasks: + resp = client.delete( + f"/api/v2/agents/{agent_id}/tasks/{agent_task['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + +def test_get_agent_tasks_tag_filter( + client, admin_auth_header, agent, _create_agent_tasks_with_tags +): + resp = client.get(f"/api/v2/agents/{agent}/tasks", headers=admin_auth_header) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 3 + + resp = client.get( + f"/api/v2/agents/{agent}/tasks?tags=test_tag_0:test_value_0", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 1 + assert resp.json()["records"][0]["input"] == "whoami_0" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_0" + + resp = client.get( + f"/api/v2/agents/{agent}/tasks?tags=test_tag_0:test_value_0&tags=test_tag_1:test_value_1", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 2 + assert resp.json()["records"][1]["input"] == "whoami_0" + assert resp.json()["records"][1]["tags"][0]["name"] == "test_tag_0" + assert resp.json()["records"][0]["input"] == "whoami_1" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_1" + + # Test tag value bad + resp = client.get( + f"/api/v2/agents/{agent}/tasks?tags=test_tag_0", headers=admin_auth_header + ) + + assert resp.status_code == 422 + assert ( + resp.json()["detail"][0]["msg"] == 'string does not match regex "^[^:]+:[^:]+$"' + ) + + +@pytest.fixture(scope="function") +def _create_plugin_tasks_with_tags( + models, session_local, client, admin_auth_header, plugin_name +): + plugin_tasks = [] + tags = [] + for i in range(3): + plugin_task = models.PluginTask( + plugin_id=plugin_name, + input=f"input {i}", + input_full=f"input {i}", + user_id=None, + status=PluginTaskStatus.completed, + ) + with session_local.begin() as db: + db.add(plugin_task) + db.flush() + plugin_tasks.append({"id": plugin_task.id}) + + for i, plugin_task in enumerate(plugin_tasks): + resp = client.post( + f"/api/v2/plugins/{plugin_name}/tasks/{plugin_task['id']}/tags", + headers=admin_auth_header, + json={"name": f"test_tag_{i}", "value": f"test_value_{i}"}, + ) + assert resp.status_code == 201 + tags.append((plugin_task, resp.json())) + + yield plugin_tasks + + for task, tag in tags: + resp = client.delete( + f"/api/v2/plugins/{plugin_name}/tasks/{task['id']}/tags/{tag['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + with session_local.begin() as db: + db.query(models.PluginTask).delete() + + +def test_get_plugin_tasks_tag_filter( + client, admin_auth_header, plugin_name, _create_plugin_tasks_with_tags +): + resp = client.get(f"/api/v2/plugins/{plugin_name}/tasks", headers=admin_auth_header) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 3 + + resp = client.get( + f"/api/v2/plugins/{plugin_name}/tasks?tags=test_tag_0:test_value_0", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 1 + assert resp.json()["records"][0]["input"] == "input 0" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_0" + + resp = client.get( + f"/api/v2/plugins/{plugin_name}/tasks?tags=test_tag_0:test_value_0&tags=test_tag_1:test_value_1", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 2 + assert resp.json()["records"][1]["input"] == "input 0" + assert resp.json()["records"][1]["tags"][0]["name"] == "test_tag_0" + assert resp.json()["records"][0]["input"] == "input 1" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_1" + + # Test tag value bad + resp = client.get( + f"/api/v2/plugins/{plugin_name}/tasks?tags=test_tag_0", + headers=admin_auth_header, + ) + + assert resp.status_code == 422 + assert ( + resp.json()["detail"][0]["msg"] == 'string does not match regex "^[^:]+:[^:]+$"' + ) + + +@pytest.fixture(scope="function") +def _create_downloads_with_tags(models, session_local, client, admin_auth_header): + downloads = [] + tags = [] + with session_local.begin() as db: + # Unsure why this is needed, but it is. + # Some other test must be adding a download and not removing it. + db.query(models.upload_download_assc).delete() + db.query(models.Download).delete() + + for i in range(3): + download = models.Download( + location=f"path/{i}", filename=f"filename_{i}", size=1 + ) + with session_local.begin() as db: + db.add(download) + db.flush() + downloads.append({"id": download.id}) + + for i, download in enumerate(downloads): + resp = client.post( + f"/api/v2/downloads/{download['id']}/tags", + headers=admin_auth_header, + json={"name": f"test_tag_{i}", "value": f"test_value_{i}"}, + ) + assert resp.status_code == 201 + tags.append(resp.json()) + + yield downloads + + for tag in tags: + resp = client.delete( + f"/api/v2/downloads/{downloads[0]['id']}/tags/{tag['id']}", + headers=admin_auth_header, + ) + assert resp.status_code == 204 + + with session_local.begin() as db: + db.query(models.download_tag_assc).delete() + db.query(models.Download).delete() + + +def test_get_downloads_tag_filter( + client, admin_auth_header, _create_downloads_with_tags +): + resp = client.get("/api/v2/downloads/", headers=admin_auth_header) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 3 + + resp = client.get( + "/api/v2/downloads?tags=test_tag_0:test_value_0", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 1 + assert resp.json()["records"][0]["location"] == "path/0" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_0" + + resp = client.get( + "/api/v2/downloads?tags=test_tag_0:test_value_0&tags=test_tag_1:test_value_1", + headers=admin_auth_header, + ) + + assert resp.status_code == 200 + assert len(resp.json()["records"]) == 2 + assert resp.json()["records"][1]["location"] == "path/1" + assert resp.json()["records"][1]["tags"][0]["name"] == "test_tag_1" + assert resp.json()["records"][0]["location"] == "path/0" + assert resp.json()["records"][0]["tags"][0]["name"] == "test_tag_0" + + # Test tag value bad + resp = client.get("/api/v2/downloads?tags=test_tag_0", headers=admin_auth_header) + + assert resp.status_code == 422 + assert ( + resp.json()["detail"][0]["msg"] == 'string does not match regex "^[^:]+:[^:]+$"' + )