Skip to content

Commit

Permalink
feat(stages): Add helper method to get a stage by name or ID (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
setu4993 authored Jun 5, 2024
1 parent 1fbd895 commit 3aa515e
Show file tree
Hide file tree
Showing 9 changed files with 356 additions and 210 deletions.
264 changes: 123 additions & 141 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ packages = [{include = "galileo_protect", from = "src"}]

[tool.poetry.dependencies]
python = "^3.8.1,<3.13"
galileo-core = "^1.3.0"
galileo-core = "^1.8.0"

langchain-core = { version = ">=0.1.52,<0.3.0", optional = true }

Expand All @@ -25,7 +25,7 @@ pytest-cov = "^5.0.0"
pytest-xdist = "^3.5.0"
pytest-socket = "^0.7.0"
pytest-asyncio = "^0.23.6"
galileo-core = { extras = ["testing"], version = "^1.2.0" }
galileo-core = { extras = ["testing"], version = "^1.8.0" }


[tool.poetry.group.dev.dependencies]
Expand Down
2 changes: 1 addition & 1 deletion src/galileo_protect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Ruleset,
Stage,
)
from galileo_protect.stage import create_stage, pause_stage, resume_stage
from galileo_protect.stage import create_stage, get_stage, pause_stage, resume_stage

if is_dependency_available("langchain_core"):
from galileo_protect.langchain import ProtectParser, ProtectTool
Expand Down
66 changes: 65 additions & 1 deletion src/galileo_protect/stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional
from typing import Dict, Optional

from galileo_core.constants.request_method import RequestMethod
from galileo_core.helpers.project import get_project_from_name
from galileo_core.utils.name import ts_name
from pydantic import UUID4

Expand Down Expand Up @@ -36,6 +37,69 @@ def create_stage(
return stage


def get_stage(
project_id: Optional[UUID4] = None,
project_name: Optional[str] = None,
stage_id: Optional[UUID4] = None,
stage_name: Optional[str] = None,
config: Optional[ProtectConfig] = None,
) -> StageResponse:
"""
Get a stage by ID or name.
Parameters
----------
project_id : Optional[UUID4], optional
Project ID, by default, we will try to get it from the config.
project_name : Optional[str], optional
Project name, by default we will try to get it from the server if the project
ID is not provided.
stage_id : Optional[UUID4], optional
Stage ID, by default we will try to get it from the config.
stage_name : Optional[str], optional
Stage name, by default we will try to get it from the config.
config : Optional[ProtectConfig], optional
Protect config, by default we will get it from the env vars or the local
config file.
Returns
-------
StageResponse
The stage response.
Raises
------
ValueError
If the project ID is not provided or found.
"""
config = config or ProtectConfig.get()
project_id = project_id or config.project_id
stage_id = stage_id or config.stage_id
stage_name = stage_name or config.stage_name
if project_id is None:
if project_name:
project = get_project_from_name(project_name, config=config, raise_if_missing=True)
assert project is not None, "Project should not be None."
project_id = project.id
else:
raise ValueError("Project ID or name must be provided to get a stage.")
params: Dict[str, str] = dict()
if stage_id:
params["stage_id"] = str(stage_id)
if stage_name:
params["stage_name"] = stage_name
if not params:
raise ValueError("Stage ID or name must be provided to get a stage.")
stage = StageResponse.model_validate(
config.api_client.request(RequestMethod.GET, Routes.stages.format(project_id=project_id), params=params)
)
config.project_id = project_id
config.stage_id = stage.id
config.stage_name = stage.name
config.write()
return stage


def pause_stage(
project_id: Optional[UUID4] = None, stage_id: Optional[UUID4] = None, config: Optional[ProtectConfig] = None
) -> None:
Expand Down
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from galileo_core.constants.request_method import RequestMethod
from galileo_core.constants.routes import Routes as CoreRoutes
from galileo_core.schemas.protect.response import Response
from galileo_core.schemas.protect.response import Response, TraceMetadata
from pytest import MonkeyPatch, fixture

from galileo_protect.constants.routes import Routes
Expand Down Expand Up @@ -57,7 +57,11 @@ def curry(
@fixture
def mock_invoke(mock_request: Mock) -> Generator[None, None, None]:
matcher = mock_request(
RequestMethod.POST, Routes.invoke, json=Response(text=A_PROTECT_INPUT, status="NOT_TRIGGERED").model_dump()
RequestMethod.POST,
Routes.invoke,
json=Response(text=A_PROTECT_INPUT, status="NOT_TRIGGERED", trace_metadata=TraceMetadata()).model_dump(
mode="json"
),
)
yield matcher
assert matcher.called
9 changes: 9 additions & 0 deletions tests/data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
A_CONSOLE_URL = "https://console.test.rungalileo.io/"
A_PROJECT_NAME = "project_name"
A_STAGE_NAME = "stage_name"
A_JWT_TOKEN = "secret_jwt_token"
A_PROTECT_INPUT = "invoke"
A_TRACE_METADATA_DICT = {
"trace_metadata": {
"id": "57f7ec49-8e44-42cb-8825-4a971e44b252",
"received_at": 1717538501568372000,
"response_at": 1717538501568372000,
"execution_time": 0.46,
}
}
19 changes: 10 additions & 9 deletions tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pytest import CaptureFixture, mark

from galileo_protect.langchain import ProtectParser
from tests.data import A_TRACE_METADATA_DICT


class ProtectLLM(LLM):
Expand All @@ -27,14 +28,14 @@ def _call(
@mark.parametrize(
["output", "ignore_trigger", "expected_return", "expected_call_count"],
[
[dumps({"text": "foo"}), False, "foo", 1],
[dumps({"text": "foo"}), True, "foo", 1],
[dumps({"text": "timeout", "status": "TIMEOUT"}), False, "timeout", 1],
[dumps({"text": "timeout", "status": "TIMEOUT"}), True, "timeout", 1],
[dumps({"text": "success", "status": "SUCCESS"}), False, "success", 1],
[dumps({"text": "success", "status": "SUCCESS"}), True, "success", 1],
[dumps({"text": "triggered", "status": "TRIGGERED"}), False, "triggered", 0],
[dumps({"text": "triggered", "status": "TRIGGERED"}), True, "triggered", 1],
[dumps({"text": "foo", **A_TRACE_METADATA_DICT}), False, "foo", 1],
[dumps({"text": "foo", **A_TRACE_METADATA_DICT}), True, "foo", 1],
[dumps({"text": "timeout", "status": "TIMEOUT", **A_TRACE_METADATA_DICT}), False, "timeout", 1],
[dumps({"text": "timeout", "status": "TIMEOUT", **A_TRACE_METADATA_DICT}), True, "timeout", 1],
[dumps({"text": "success", "status": "SUCCESS", **A_TRACE_METADATA_DICT}), False, "success", 1],
[dumps({"text": "success", "status": "SUCCESS", **A_TRACE_METADATA_DICT}), True, "success", 1],
[dumps({"text": "triggered", "status": "TRIGGERED", **A_TRACE_METADATA_DICT}), False, "triggered", 0],
[dumps({"text": "triggered", "status": "TRIGGERED", **A_TRACE_METADATA_DICT}), True, "triggered", 1],
],
)
def test_parser(output: str, ignore_trigger: bool, expected_return: str, expected_call_count: int) -> None:
Expand All @@ -48,6 +49,6 @@ def test_parser(output: str, ignore_trigger: bool, expected_return: str, expecte
@mark.parametrize(["echo_output", "expected_output"], [[True, "> Raw response: foo\n"], [False, ""]])
def test_echo(echo_output: bool, expected_output: str, capsys: CaptureFixture) -> None:
parser = ProtectParser(chain=ProtectLLM(), echo_output=echo_output)
parser.parser(dumps({"text": "foo"}))
parser.parser(dumps({"text": "foo", **A_TRACE_METADATA_DICT}))
captured = capsys.readouterr()
assert captured.out == expected_output
10 changes: 6 additions & 4 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from galileo_core.schemas.core.project import ProjectType

from galileo_protect.project import create_project
from tests.data import A_PROJECT_NAME


def test_create_project(set_validated_config: Callable, mock_request: Callable) -> None:
config = set_validated_config()
project_name = "foo-bar"
project_id = uuid4()
matcher_get = mock_request(RequestMethod.GET, CoreRoutes.projects + f"?project_name={project_name}", json=[])
matcher_get = mock_request(
RequestMethod.GET, CoreRoutes.projects, params=dict(project_name=A_PROJECT_NAME), json=[]
)
matcher_post = mock_request(
RequestMethod.POST,
CoreRoutes.projects,
json={"id": str(project_id), "type": ProjectType.protect, "name": project_name},
json={"id": str(project_id), "type": ProjectType.protect, "name": A_PROJECT_NAME},
)
create_project(name=project_name, config=config)
create_project(name=A_PROJECT_NAME, config=config)
assert matcher_get.called
assert matcher_post.called
# Verify that the project ID was set in the config.
Expand Down
Loading

0 comments on commit 3aa515e

Please sign in to comment.