Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Callable-like types in get_request_model (eg streaming_callback) + Fix component output serialization #41

Merged
merged 3 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions src/hayhooks/server/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pandas import DataFrame
from pydantic import BaseModel, ConfigDict, create_model

from hayhooks.server.utils.create_valid_type import handle_unsupported_types


Expand Down Expand Up @@ -31,10 +30,12 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
except TypeError as e:
print(f"ERROR at {component_name!r}, {name}: {typedef}")
raise e
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)

if input_type is not None:
component_model[name] = (
input_type,
typedef.get("default_value", ...),
)
request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)

return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)
Expand All @@ -61,30 +62,34 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)


def convert_value_to_dict(value):
"""Convert a single value to a dictionary if possible"""
if hasattr(value, "to_dict"):
if "init_parameters" in value.to_dict():
return value.to_dict()["init_parameters"]
return value.to_dict()
elif hasattr(value, "model_dump"):
return value.model_dump()
elif isinstance(value, dict):
return {k: convert_value_to_dict(v) for k, v in value.items()}
elif isinstance(value, list):
return [convert_value_to_dict(item) for item in value]
else:
return value


def convert_component_output(component_output):
"""
Converts outputs from a component as a dict so that it can be validated against response model

Component output has this form:
Converts component outputs to dictionaries that can be validated against response model.
Handles nested structures recursively.

"documents":[
{"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."}
]
Args:
component_output: Dict with component outputs

Returns:
Dict with all nested objects converted to dictionaries
"""
result = {}
for output_name, data in component_output.items():

def get_value(data):
if hasattr(data, "to_dict") and "init_parameters" in data.to_dict():
return data.to_dict()["init_parameters"]
elif hasattr(data, "to_dict"):
return data.to_dict()
else:
return data

if type(data) is list:
result[output_name] = [get_value(d) for d in data]
else:
result[output_name] = get_value(data)
return result
if isinstance(component_output, dict):
return {name: convert_value_to_dict(data) for name, data in component_output.items()}

return convert_value_to_dict(component_output)
48 changes: 33 additions & 15 deletions src/hayhooks/server/utils/create_valid_type.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,39 @@
from collections.abc import Callable as CallableABC
from inspect import isclass
from types import GenericAlias
from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints
from typing import Callable, Dict, Optional, Union, get_args, get_origin, get_type_hints


def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
def is_callable_type(t):
"""Check if a type is any form of callable"""
if t in (Callable, CallableABC):
return True

# Check origin type
origin = get_origin(t)
if origin in (Callable, CallableABC):
return True

# Handle Optional/Union types
if origin in (Union, type(Optional[int])): # type(Optional[int]) handles runtime Optional type
args = get_args(t)
return any(is_callable_type(arg) for arg in args)

return False


def handle_unsupported_types(
type_: type, types_mapping: Dict[type, type], skip_callables: bool = True
) -> Union[GenericAlias, type, None]:
"""
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.

:param type_: Type to replace if not supported
:param types_mapping: Mapping of types to replace
"""

def _handle_generics(t_) -> GenericAlias:
"""
Handle generics recursively
"""
def handle_generics(t_) -> Union[GenericAlias, None]:
"""Handle generics recursively"""
if is_callable_type(t_) and skip_callables:
return None

child_typing = []
for t in get_args(t_):
if t in types_mapping:
Expand All @@ -26,20 +45,19 @@ def _handle_generics(t_) -> GenericAlias:
child_typing.append(result)

if len(child_typing) == 2 and child_typing[1] is type(None):
# because TypedDict can't handle union types with None
# rewrite them as Optional[type]
return Optional[child_typing[0]]
else:
return GenericAlias(get_origin(t_), tuple(child_typing))

if is_callable_type(type_) and skip_callables:
return None

if isclass(type_):
new_type = {}
for arg_name, arg_type in get_type_hints(type_).items():
if get_args(arg_type):
new_type[arg_name] = _handle_generics(arg_type)
new_type[arg_name] = handle_generics(arg_type)
else:
new_type[arg_name] = arg_type

return type_

return _handle_generics(type_)
return handle_generics(type_)
5 changes: 3 additions & 2 deletions src/hayhooks/server/utils/deploy_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse

from hayhooks.server.pipelines import registry
from hayhooks.server.pipelines.models import (
PipelineDefinition,
convert_component_output,
get_request_model,
get_response_model,
convert_component_output,
)


def deploy_pipeline_def(app, pipeline_def: PipelineDefinition):
try:
pipe = registry.add(pipeline_def.name, pipeline_def.source_code)
Expand Down
43 changes: 43 additions & 0 deletions tests/test_convert_component_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from hayhooks.server.pipelines.models import convert_component_output
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails


def test_convert_component_output_with_nested_models():
sample_response = [
{
'model': 'gpt-4o-mini-2024-07-18',
'index': 0,
'finish_reason': 'stop',
'usage': {
'completion_tokens': 52,
'prompt_tokens': 29,
'total_tokens': 81,
'completion_tokens_details': CompletionTokensDetails(
accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0
),
'prompt_tokens_details': PromptTokensDetails(audio_tokens=0, cached_tokens=0),
},
}
]

converted_output = convert_component_output(sample_response)

assert converted_output == [
{
'model': 'gpt-4o-mini-2024-07-18',
'index': 0,
'finish_reason': 'stop',
'usage': {
'completion_tokens': 52,
'prompt_tokens': 29,
'total_tokens': 81,
'completion_tokens_details': {
'accepted_prediction_tokens': 0,
'audio_tokens': 0,
'reasoning_tokens': 0,
'rejected_prediction_tokens': 0,
},
'prompt_tokens_details': {'audio_tokens': 0, 'cached_tokens': 0},
},
}
]
54 changes: 54 additions & 0 deletions tests/test_handle_callable_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from collections.abc import Callable as CallableABC
from types import NoneType
from typing import Any, Callable, Optional, Union

import haystack
import pytest

from hayhooks.server.pipelines.models import get_request_model
from hayhooks.server.utils.create_valid_type import is_callable_type


@pytest.mark.parametrize(
"t, expected",
[
(Callable, True),
(CallableABC, True),
(Callable[[int], str], True),
(Callable[..., Any], True),
(int, False),
(str, False),
(Any, False),
(Union[int, str], False),
(Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]], True),
],
)
def test_is_callable_type(t, expected):
assert is_callable_type(t) == expected


def test_skip_callables_when_creating_pipeline_models():
pipeline_name = "test_pipeline"
pipeline_inputs = {
"generator": {
"system_prompt": {"type": Optional[str], "is_mandatory": False, "default_value": None},
"streaming_callback": {
"type": Optional[Callable[[haystack.dataclasses.streaming_chunk.StreamingChunk], NoneType]],
"is_mandatory": False,
"default_value": None,
},
"generation_kwargs": {
"type": Optional[dict[str, Any]],
"is_mandatory": False,
"default_value": None,
},
}
}

request_model = get_request_model(pipeline_name, pipeline_inputs)

# This line used to throw an error because the Callable type was not handled correctly
# by the handle_unsupported_types function
assert request_model.model_json_schema() is not None
assert request_model.__name__ == "Test_pipelineRunRequest"
assert "streaming_callback" not in request_model.model_json_schema()["$defs"]["ComponentParams"]["properties"]