Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
vblagoje committed Dec 20, 2024
1 parent 2dbb8d2 commit b35c098
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 16 deletions.
4 changes: 2 additions & 2 deletions haystack_experimental/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from haystack.utils import deserialize_callable, serialize_callable
from pydantic import TypeAdapter, create_model

from haystack_experimental.tools import create_tool_parameters_schema
from haystack_experimental.tools.component_schema import _create_tool_parameters_schema

with LazyImport(message="Run 'pip install jsonschema'") as jsonschema_import:
from jsonschema import Draft202012Validator
Expand Down Expand Up @@ -225,7 +225,7 @@ def from_component(cls, component: Component, name: str, description: str) -> "T
raise ValueError(message)

# Create the tools schema from the component run method parameters
tool_schema = create_tool_parameters_schema(component)
tool_schema = _create_tool_parameters_schema(component)

def component_invoker(**kwargs):
"""
Expand Down
4 changes: 0 additions & 4 deletions haystack_experimental/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from .component_schema import create_tool_parameters_schema

__all__ = ["create_tool_parameters_schema"]
20 changes: 10 additions & 10 deletions haystack_experimental/tools/component_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
logger = logging.getLogger(__name__)


def create_tool_parameters_schema(component: Component) -> Dict[str, Any]:
def _create_tool_parameters_schema(component: Component) -> Dict[str, Any]:
"""
Creates an OpenAI tools schema from a component's run method parameters.
Expand All @@ -25,14 +25,14 @@ def create_tool_parameters_schema(component: Component) -> Dict[str, Any]:
properties = {}
required = []

param_descriptions = get_param_descriptions(component.run)
param_descriptions = _get_param_descriptions(component.run)

for input_name, socket in component.__haystack_input__._sockets_dict.items():
input_type = socket.type
description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

try:
property_schema = create_property_schema(input_type, description)
property_schema = _create_property_schema(input_type, description)
except ValueError as e:
raise ValueError(f"Error processing input '{input_name}': {e}")

Expand All @@ -50,7 +50,7 @@ def create_tool_parameters_schema(component: Component) -> Dict[str, Any]:
return parameters_schema


def get_param_descriptions(method: Callable) -> Dict[str, str]:
def _get_param_descriptions(method: Callable) -> Dict[str, str]:
"""
Extracts parameter descriptions from the method's docstring using docstring_parser.
Expand All @@ -76,7 +76,7 @@ def get_param_descriptions(method: Callable) -> Dict[str, str]:
return param_descriptions


def is_nullable_type(python_type: Any) -> bool:
def _is_nullable_type(python_type: Any) -> bool:
"""
Checks if the type is a Union with NoneType (i.e., Optional).
Expand All @@ -97,7 +97,7 @@ def _create_list_schema(item_type: Any, description: str) -> Dict[str, Any]:
:param description: The description of the list.
:returns: A dictionary representing the list schema.
"""
items_schema = create_property_schema(item_type, "")
items_schema = _create_property_schema(item_type, "")
items_schema.pop("description", None)
return {"type": "array", "description": description, "items": items_schema}

Expand All @@ -115,7 +115,7 @@ def _create_dataclass_schema(python_type: Any, description: str) -> Dict[str, An
for field in fields(cls):
field_description = f"Field '{field.name}' of '{cls.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][field.name] = create_property_schema(field.type, field_description)
schema["properties"][field.name] = _create_property_schema(field.type, field_description)
return schema


Expand All @@ -133,7 +133,7 @@ def _create_pydantic_schema(python_type: Any, description: str) -> Dict[str, Any
for m_name, m_field in python_type.model_fields.items():
field_description = f"Field '{m_name}' of '{python_type.__name__}'."
if isinstance(schema["properties"], dict):
schema["properties"][m_name] = create_property_schema(m_field.annotation, field_description)
schema["properties"][m_name] = _create_property_schema(m_field.annotation, field_description)
if m_field.is_required():
required_fields.append(m_name)

Expand All @@ -154,7 +154,7 @@ def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, A
return {"type": type_mapping.get(python_type, "string"), "description": description}


def create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
def _create_property_schema(python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
"""
Creates a property schema for a given Python type, recursively if necessary.
Expand All @@ -163,7 +163,7 @@ def create_property_schema(python_type: Any, description: str, default: Any = No
:param default: The default value of the property.
:returns: A dictionary representing the property schema.
"""
nullable = is_nullable_type(python_type)
nullable = _is_nullable_type(python_type)
if nullable:
non_none_types = [t for t in get_args(python_type) if t is not type(None)]
python_type = non_none_types[0] if non_none_types else str
Expand Down

0 comments on commit b35c098

Please sign in to comment.