Skip to content

Commit

Permalink
fix(schema): Use ExecutionSchema from galileo-core
Browse files Browse the repository at this point in the history
  • Loading branch information
setu4993 committed Jul 30, 2024
1 parent 754b2c2 commit ad41758
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/galileo_protect/langchain.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Optional, Sequence, Type

from galileo_core.schemas.protect.execution_status import ExecutionStatus
from galileo_core.schemas.protect.response import Response
from langchain_core.runnables.base import Runnable
from langchain_core.tools import BaseTool
Expand Down Expand Up @@ -87,7 +88,7 @@ def parser(self, response_raw_json: str) -> str:
text = response.text
if self.echo_output:
print(f"> Raw response: {text}")
if response.status == "TRIGGERED" and not self.ignore_trigger:
if response.status == ExecutionStatus.triggered and not self.ignore_trigger:
return text
else:
return self.chain.invoke(text)
7 changes: 4 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from galileo_core.constants.request_method import RequestMethod
from galileo_core.constants.routes import Routes as CoreRoutes
from galileo_core.schemas.protect.execution_status import ExecutionStatus
from galileo_core.schemas.protect.response import Response, TraceMetadata
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from galileo_core.schemas.protect.ruleset import Ruleset
Expand Down Expand Up @@ -66,9 +67,9 @@ def mock_invoke(mock_request: Mock) -> Generator[None, None, None]:
matcher = mock_request(
RequestMethod.POST,
Routes.invoke,
json=Response(text=A_PROTECT_INPUT, status="NOT_TRIGGERED", trace_metadata=TraceMetadata()).model_dump(
mode="json"
),
json=Response(
text=A_PROTECT_INPUT, status=ExecutionStatus.not_triggered, trace_metadata=TraceMetadata()
).model_dump(mode="json"),
)
yield matcher
assert matcher.called
Expand Down
3 changes: 2 additions & 1 deletion tests/test_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from unittest.mock import Mock
from uuid import uuid4

from galileo_core.schemas.protect.execution_status import ExecutionStatus
from galileo_core.schemas.protect.response import Response
from pytest import mark

Expand Down Expand Up @@ -132,5 +133,5 @@ def test_langchain_tool(
assert isinstance(response_json, str)
response = Response.model_validate_json(response_json)
assert response.text is not None
assert response.status == "NOT_TRIGGERED"
assert response.status == ExecutionStatus.not_triggered
assert mock_invoke.called
39 changes: 30 additions & 9 deletions tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, List, Optional
from unittest.mock import patch

from galileo_core.schemas.protect.execution_status import ExecutionStatus
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from pytest import CaptureFixture, mark
Expand All @@ -28,14 +29,34 @@ def _call(
@mark.parametrize(
["output", "ignore_trigger", "expected_return", "expected_call_count"],
[
[dumps({"text": "foo", "status": "NOT_TRIGGERED", **A_TRACE_METADATA_DICT}), False, "foo", 1],
[dumps({"text": "foo", "status": "NOT_TRIGGERED", **A_TRACE_METADATA_DICT}), True, "foo", 1],
[dumps({"text": "timeout", "status": "TIMEOUT", **A_TRACE_METADATA_DICT}), False, "timeout", 1],
[dumps({"text": "timeout", "status": "TIMEOUT", **A_TRACE_METADATA_DICT}), True, "timeout", 1],
[dumps({"text": "success", "status": "SUCCESS", **A_TRACE_METADATA_DICT}), False, "success", 1],
[dumps({"text": "success", "status": "SUCCESS", **A_TRACE_METADATA_DICT}), True, "success", 1],
[dumps({"text": "triggered", "status": "TRIGGERED", **A_TRACE_METADATA_DICT}), False, "triggered", 0],
[dumps({"text": "triggered", "status": "TRIGGERED", **A_TRACE_METADATA_DICT}), True, "triggered", 1],
[dumps({"text": "foo", "status": ExecutionStatus.not_triggered, **A_TRACE_METADATA_DICT}), False, "foo", 1],
[dumps({"text": "foo", "status": ExecutionStatus.not_triggered, **A_TRACE_METADATA_DICT}), True, "foo", 1],
[dumps({"text": "timeout", "status": ExecutionStatus.timeout, **A_TRACE_METADATA_DICT}), False, "timeout", 1],
[dumps({"text": "timeout", "status": ExecutionStatus.timeout, **A_TRACE_METADATA_DICT}), True, "timeout", 1],
[
dumps({"text": "success", "status": ExecutionStatus.not_triggered, **A_TRACE_METADATA_DICT}),
False,
"success",
1,
],
[
dumps({"text": "success", "status": ExecutionStatus.not_triggered, **A_TRACE_METADATA_DICT}),
True,
"success",
1,
],
[
dumps({"text": "triggered", "status": ExecutionStatus.triggered, **A_TRACE_METADATA_DICT}),
False,
"triggered",
0,
],
[
dumps({"text": "triggered", "status": ExecutionStatus.triggered, **A_TRACE_METADATA_DICT}),
True,
"triggered",
1,
],
],
)
def test_parser(output: str, ignore_trigger: bool, expected_return: str, expected_call_count: int) -> None:
Expand All @@ -49,6 +70,6 @@ def test_parser(output: str, ignore_trigger: bool, expected_return: str, expecte
@mark.parametrize(["echo_output", "expected_output"], [[True, "> Raw response: foo\n"], [False, ""]])
def test_echo(echo_output: bool, expected_output: str, capsys: CaptureFixture) -> None:
parser = ProtectParser(chain=ProtectLLM(), echo_output=echo_output)
parser.parser(dumps({"text": "foo", "status": "NOT_TRIGGERED", **A_TRACE_METADATA_DICT}))
parser.parser(dumps({"text": "foo", "status": ExecutionStatus.not_triggered, **A_TRACE_METADATA_DICT}))
captured = capsys.readouterr()
assert captured.out == expected_output

0 comments on commit ad41758

Please sign in to comment.