diff --git a/aide/backend/backend_openai.py b/aide/backend/backend_openai.py index db76938..2b00f8d 100644 --- a/aide/backend/backend_openai.py +++ b/aide/backend/backend_openai.py @@ -1,5 +1,3 @@ -"""Backend for OpenAI API.""" - import json import logging import time @@ -19,23 +17,6 @@ openai.InternalServerError, ) -# (docs) https://platform.openai.com/docs/guides/function-calling/supported-models -SUPPORTED_FUNCTION_CALL_MODELS = { - "gpt-4o", - "gpt-4o-2024-08-06", - "gpt-4o-2024-05-13", - "gpt-4o-mini", - "gpt-4o-mini-2024-07-18", - "gpt-4-turbo", - "gpt-4-turbo-2024-04-09", - "gpt-4-turbo-preview", - "gpt-4-0125-preview", - "gpt-4-1106-preview", - "gpt-3.5-turbo", - "gpt-3.5-turbo-0125", - "gpt-3.5-turbo-1106", -} - @once def _setup_openai_client(): @@ -43,11 +24,6 @@ def _setup_openai_client(): _client = openai.OpenAI(max_retries=0) -def is_function_call_supported(model_name: str) -> bool: - """Return True if the model supports function calling.""" - return model_name in SUPPORTED_FUNCTION_CALL_MODELS - - def query( system_message: str | None, user_message: str | None, @@ -56,64 +32,88 @@ def query( ) -> tuple[OutputType, float, int, int, dict]: """ Query the OpenAI API, optionally with function calling. - Function calling support is only checked for feedback/review operations. + If the model doesn't support function calling, gracefully degrade to text generation. """ _setup_openai_client() filtered_kwargs: dict = select_values(notnone, model_kwargs) - model_name = filtered_kwargs.get("model", "") - logger.debug(f"OpenAI query called with model='{model_name}'") + # Convert system/user messages to the format required by the client messages = opt_messages_to_list(system_message, user_message) + # If function calling is requested, attach the function spec if func_spec is not None: - # Only check function call support for feedback/search operations - if func_spec.name == "submit_review": - if not is_function_call_supported(model_name): - logger.warning( - f"Review function calling was requested, but model '{model_name}' " - "does not support function calling. Falling back to plain text generation." - ) - filtered_kwargs.pop("tools", None) - filtered_kwargs.pop("tool_choice", None) - else: - filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict] - filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict + filtered_kwargs["tools"] = [func_spec.as_openai_tool_dict] + filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict + completion = None t0 = time.time() - completion = backoff_create( - _client.chat.completions.create, - OPENAI_TIMEOUT_EXCEPTIONS, - messages=messages, - **filtered_kwargs, - ) - req_time = time.time() - t0 + # Attempt the API call + try: + completion = backoff_create( + _client.chat.completions.create, + OPENAI_TIMEOUT_EXCEPTIONS, + messages=messages, + **filtered_kwargs, + ) + except openai.error.InvalidRequestError as e: + # Check whether the error indicates that function calling is not supported + # Different language may appear here depending on the OpenAI error messages, + # so adjust the substring check as necessary. + if "function calling" in str(e).lower() or "tools" in str(e).lower(): + logger.warning( + "Function calling was attempted but is not supported by this model. " + "Falling back to plain text generation." + ) + # Remove function-calling parameters and retry + filtered_kwargs.pop("tools", None) + filtered_kwargs.pop("tool_choice", None) + + # Retry without function calling + completion = backoff_create( + _client.chat.completions.create, + OPENAI_TIMEOUT_EXCEPTIONS, + messages=messages, + **filtered_kwargs, + ) + else: + # If it's some other error, re-raise + raise + + req_time = time.time() - t0 choice = completion.choices[0] + # Decide how to parse the response if func_spec is None or "tools" not in filtered_kwargs: + # No function calling was ultimately used output = choice.message.content else: + # Attempt to extract tool calls tool_calls = getattr(choice.message, "tool_calls", None) - if not tool_calls: logger.warning( - f"No function call used despite function spec. Fallback to text. " + "No function call was used despite function spec. Fallback to text.\n" f"Message content: {choice.message.content}" ) output = choice.message.content else: first_call = tool_calls[0] - assert first_call.function.name == func_spec.name, ( - f"Function name mismatch: expected {func_spec.name}, " - f"got {first_call.function.name}" - ) - try: - output = json.loads(first_call.function.arguments) - except json.JSONDecodeError as e: - logger.error( - f"Error decoding function arguments:\n{first_call.function.arguments}" + # Optional: verify that the function name matches + if first_call.function.name != func_spec.name: + logger.warning( + f"Function name mismatch: expected {func_spec.name}, " + f"got {first_call.function.name}. Fallback to text." ) - raise e + output = choice.message.content + else: + try: + output = json.loads(first_call.function.arguments) + except json.JSONDecodeError as ex: + logger.error( + "Error decoding function arguments:\n" + f"{first_call.function.arguments}" + ) + raise ex in_tokens = completion.usage.prompt_tokens out_tokens = completion.usage.completion_tokens diff --git a/aide/utils/tree_export.py b/aide/utils/tree_export.py index 918c889..a4bfb85 100644 --- a/aide/utils/tree_export.py +++ b/aide/utils/tree_export.py @@ -38,6 +38,19 @@ def normalize_layout(layout: np.ndarray): return layout +def strip_code_markers(code: str) -> str: + """Remove markdown code block markers if present.""" + code = code.strip() + if code.startswith("```"): + # Remove opening backticks and optional language identifier + first_newline = code.find("\n") + if first_newline != -1: + code = code[first_newline:].strip() + if code.endswith("```"): + code = code[:-3].strip() + return code + + def cfg_to_tree_struct(cfg, jou: Journal): edges = list(get_edges(jou)) layout = normalize_layout(generate_layout(len(jou), edges)) @@ -52,7 +65,7 @@ def cfg_to_tree_struct(cfg, jou: Journal): edges=edges, layout=layout.tolist(), plan=[textwrap.fill(n.plan, width=80) for n in jou.nodes], - code=[n.code for n in jou], + code=[strip_code_markers(n.code) for n in jou], term_out=[n.term_out for n in jou], analysis=[n.analysis for n in jou], exp_name=cfg.exp_name,