🐛 better handling for function calling models (#43) #44

Merged · 1 commit · Jan 30, 2025
112 changes: 56 additions & 56 deletions aide/backend/backend_openai.py
@@ -19,35 +19,13 @@
openai.InternalServerError,
)

# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
SUPPORTED_FUNCTION_CALL_MODELS = {
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
}


@once
def _setup_openai_client():
global _client
_client = openai.OpenAI(max_retries=0)


def is_function_call_supported(model_name: str) -> bool:
"""Return True if the model supports function calling."""
return model_name in SUPPORTED_FUNCTION_CALL_MODELS


def query(
system_message: str | None,
user_message: str | None,
@@ -56,64 +34,86 @@ def query(
) -> tuple[OutputType, float, int, int, dict]:
"""
Query the OpenAI API, optionally with function calling.
Function calling support is only checked for feedback/review operations.
If the model doesn't support function calling, gracefully degrade to text generation.
"""
_setup_openai_client()
filtered_kwargs: dict = select_values(notnone, model_kwargs)
model_name = filtered_kwargs.get("model", "")
logger.debug(f"OpenAI query called with model='{model_name}'")

# Convert system/user messages to the format required by the client
messages = opt_messages_to_list(system_message, user_message)

# If function calling is requested, attach the function spec
if func_spec is not None:
# Only check function call support for feedback/search operations
if func_spec.name == "submit_review":
if not is_function_call_supported(model_name):
logger.warning(
f"Review function calling was requested, but model '{model_name}' "
"does not support function calling. Falling back to plain text generation."
)
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)
else:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

completion = None
t0 = time.time()
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
req_time = time.time() - t0

# Attempt the API call
try:
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
    except openai.BadRequestError as e:  # openai-python >= 1.0 raises BadRequestError for invalid requests
# Check whether the error indicates that function calling is not supported
if "function calling" in str(e).lower() or "tools" in str(e).lower():
logger.warning(
"Function calling was attempted but is not supported by this model. "
"Falling back to plain text generation."
)
# Remove function-calling parameters and retry
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)

# Retry without function calling
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
else:
# If it's some other error, re-raise
raise

req_time = time.time() - t0
choice = completion.choices[0]

# Decide how to parse the response
if func_spec is None or "tools" not in filtered_kwargs:
# No function calling was ultimately used
output = choice.message.content
else:
# Attempt to extract tool calls
tool_calls = getattr(choice.message, "tool_calls", None)

if not tool_calls:
logger.warning(
f"No function call used despite function spec. Fallback to text. "
"No function call was used despite function spec. Fallback to text.\n"
f"Message content: {choice.message.content}"
)
output = choice.message.content
else:
first_call = tool_calls[0]
assert first_call.function.name == func_spec.name, (
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}"
)
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding function arguments:\n{first_call.function.arguments}"
# Optional: verify that the function name matches
if first_call.function.name != func_spec.name:
logger.warning(
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}. Fallback to text."
)
raise e
output = choice.message.content
else:
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as ex:
logger.error(
"Error decoding function arguments:\n"
f"{first_call.function.arguments}"
)
raise ex

in_tokens = completion.usage.prompt_tokens
out_tokens = completion.usage.completion_tokens
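For readers skimming the diff: the change above replaces the hardcoded SUPPORTED_FUNCTION_CALL_MODELS allowlist with an attempt-and-fall-back strategy. Below is a minimal standalone sketch of that pattern; the helper name chat_with_optional_tools is illustrative and not part of this PR, and it assumes the openai Python SDK v1.x, where a rejected request raises openai.BadRequestError.

# Minimal sketch of the attempt-and-fall-back pattern (illustrative, not the PR's code).
# Assumes openai-python >= 1.0 and a valid OPENAI_API_KEY in the environment.
import openai

client = openai.OpenAI(max_retries=0)

def chat_with_optional_tools(messages: list[dict], **kwargs):
    """Try a chat completion with tools; if the model rejects them, retry without."""
    try:
        return client.chat.completions.create(messages=messages, **kwargs)
    except openai.BadRequestError as e:
        # Only fall back when the error looks like a function-calling limitation.
        if "function calling" in str(e).lower() or "tools" in str(e).lower():
            kwargs.pop("tools", None)
            kwargs.pop("tool_choice", None)
            return client.chat.completions.create(messages=messages, **kwargs)
        raise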
15 changes: 14 additions & 1 deletion aide/utils/tree_export.py
@@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray):
return layout


def strip_code_markers(code: str) -> str:
"""Remove markdown code block markers if present."""
code = code.strip()
if code.startswith("```"):
# Remove opening backticks and optional language identifier
first_newline = code.find("\n")
if first_newline != -1:
code = code[first_newline:].strip()
if code.endswith("```"):
code = code[:-3].strip()
return code


def cfg_to_tree_struct(cfg, jou: Journal):
edges = list(get_edges(jou))
layout = normalize_layout(generate_layout(len(jou), edges))
@@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal):
edges=edges,
layout=layout.tolist(),
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
code=[n.code for n in jou],
code=[strip_code_markers(n.code) for n in jou],
term_out=[n.term_out for n in jou],
analysis=[n.analysis for n in jou],
exp_name=cfg.exp_name,
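The strip_code_markers helper added in this file is small enough to illustrate with a quick usage sketch; the sample strings below are made up for demonstration and are not from the PR.

# Hypothetical usage of strip_code_markers; sample inputs are made up.
from aide.utils.tree_export import strip_code_markers

fenced = "```python\nprint('hello')\n```"
print(strip_code_markers(fenced))   # -> print('hello')

plain = "print('hello')"
print(strip_code_markers(plain))    # -> print('hello') (unchanged)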