Skip to content

Commit

Permalink
🐛 better handling for function calling errors
Browse files Browse the repository at this point in the history
* remove the supported model lists
* update error handling #43
* update code block display for html
  • Loading branch information
dexhunter committed Jan 30, 2025
1 parent 51d09b7 commit f824c1b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 57 deletions.
112 changes: 56 additions & 56 deletions aide/backend/backend_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,13 @@
openai.InternalServerError,
)

# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models
# Model identifiers treated as supporting function calling.
# is_function_call_supported() checks exact membership in this set, so a model
# name that is not spelled exactly as one of these entries is reported as
# unsupported.
SUPPORTED_FUNCTION_CALL_MODELS = {
"gpt-4o",
"gpt-4o-2024-08-06",
"gpt-4o-2024-05-13",
"gpt-4o-mini",
"gpt-4o-mini-2024-07-18",
"gpt-4-turbo",
"gpt-4-turbo-2024-04-09",
"gpt-4-turbo-preview",
"gpt-4-0125-preview",
"gpt-4-1106-preview",
"gpt-3.5-turbo",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
}


@once
def _setup_openai_client():
    """Create the OpenAI client and store it in the module-level ``_client``.

    The client is built with ``max_retries=0``, i.e. the SDK's own retry
    mechanism is switched off.
    NOTE(review): ``@once`` presumably makes repeated calls no-ops; its
    definition is not visible here -- confirm.
    """
    global _client
    client = openai.OpenAI(max_retries=0)
    _client = client


def is_function_call_supported(model_name: str) -> bool:
    """Tell whether ``model_name`` is listed as a function-calling model.

    The check is an exact membership test against
    ``SUPPORTED_FUNCTION_CALL_MODELS``.
    """
    known_models = SUPPORTED_FUNCTION_CALL_MODELS
    return model_name in known_models


def query(
system_message: str | None,
user_message: str | None,
Expand All @@ -56,64 +34,86 @@ def query(
) -> tuple[OutputType, float, int, int, dict]:
"""
Query the OpenAI API, optionally with function calling.
Function calling support is only checked for feedback/review operations.
If the model doesn't support function calling, gracefully degrade to text generation.
"""
_setup_openai_client()
filtered_kwargs: dict = select_values(notnone, model_kwargs)
model_name = filtered_kwargs.get("model", "")
logger.debug(f"OpenAI query called with model='{model_name}'")

# Convert system/user messages to the format required by the client
messages = opt_messages_to_list(system_message, user_message)

# If function calling is requested, attach the function spec
if func_spec is not None:
# Only check function call support for feedback/search operations
if func_spec.name == "submit_review":
if not is_function_call_supported(model_name):
logger.warning(
f"Review function calling was requested, but model '{model_name}' "
"does not support function calling. Falling back to plain text generation."
)
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)
else:
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict
filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict]
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

completion = None
t0 = time.time()
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
req_time = time.time() - t0

# Attempt the API call
try:
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
except openai.error.InvalidRequestError as e:
# Check whether the error indicates that function calling is not supported
if "function calling" in str(e).lower() or "tools" in str(e).lower():
logger.warning(
"Function calling was attempted but is not supported by this model. "
"Falling back to plain text generation."
)
# Remove function-calling parameters and retry
filtered_kwargs.pop("tools", None)
filtered_kwargs.pop("tool_choice", None)

# Retry without function calling
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
else:
# If it's some other error, re-raise
raise

req_time = time.time() - t0
choice = completion.choices[0]

# Decide how to parse the response
if func_spec is None or "tools" not in filtered_kwargs:
# No function calling was ultimately used
output = choice.message.content
else:
# Attempt to extract tool calls
tool_calls = getattr(choice.message, "tool_calls", None)

if not tool_calls:
logger.warning(
f"No function call used despite function spec. Fallback to text. "
"No function call was used despite function spec. Fallback to text.\n"
f"Message content: {choice.message.content}"
)
output = choice.message.content
else:
first_call = tool_calls[0]
assert first_call.function.name == func_spec.name, (
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}"
)
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as e:
logger.error(
f"Error decoding function arguments:\n{first_call.function.arguments}"
# Optional: verify that the function name matches
if first_call.function.name != func_spec.name:
logger.warning(
f"Function name mismatch: expected {func_spec.name}, "
f"got {first_call.function.name}. Fallback to text."
)
raise e
output = choice.message.content
else:
try:
output = json.loads(first_call.function.arguments)
except json.JSONDecodeError as ex:
logger.error(
"Error decoding function arguments:\n"
f"{first_call.function.arguments}"
)
raise ex

in_tokens = completion.usage.prompt_tokens
out_tokens = completion.usage.completion_tokens
Expand Down
15 changes: 14 additions & 1 deletion aide/utils/tree_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray):
return layout


def strip_code_markers(code: str) -> str:
    """Remove a surrounding markdown code fence from *code* if present.

    Surrounding whitespace is stripped first. If the text then starts with a
    triple-backtick fence, the whole opening line (the backticks plus an
    optional language identifier) is dropped. If the text ends with a
    triple-backtick fence, the closing backticks are dropped.

    A snippet fenced on a single line (opening fence, code and closing fence
    with no newline in between) has both fences removed; previously only the
    closing fence was removed, leaving the opening backticks in the output.

    Args:
        code: Source text that may be wrapped in a markdown code fence.

    Returns:
        The text without fence markers and without surrounding whitespace.
    """
    fence = "```"
    code = code.strip()
    if code.startswith(fence):
        first_newline = code.find("\n")
        if first_newline != -1:
            # Drop the entire opening line: backticks plus optional language tag.
            code = code[first_newline:].strip()
        elif code.endswith(fence) and len(code) >= 2 * len(fence):
            # Single-line fence: there is no separate opening line, so a
            # language tag cannot be told apart from code. Remove only the
            # opening backticks; the closing ones are handled below.
            code = code[len(fence):].strip()
    if code.endswith(fence):
        code = code[: -len(fence)].strip()
    return code


def cfg_to_tree_struct(cfg, jou: Journal):
edges = list(get_edges(jou))
layout = normalize_layout(generate_layout(len(jou), edges))
Expand All @@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal):
edges=edges,
layout=layout.tolist(),
plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes],
code=[n.code for n in jou],
code=[strip_code_markers(n.code) for n in jou],
term_out=[n.term_out for n in jou],
analysis=[n.analysis for n in jou],
exp_name=cfg.exp_name,
Expand Down

0 comments on commit f824c1b

Please sign in to comment.