Skip to content

Commit

Permalink
Extract openai API calls and retry at lowest level (Significant-Gravi…
Browse files Browse the repository at this point in the history
…tas#3696)

* Extract OpenAI API calls and retry at lowest level

* Forgot a test

* Gotta fix my local docker config so I can let pre-commit hooks run, ugh

* fix: merge artifact

* Fix linting

* Update memory.vector.utils

* feat: make sure resp exists

* fix: raise error message if created

* feat: rename file

* fix: partial test fix

* fix: update comments

* fix: linting

* fix: remove broken test

* fix: require a model to exist

* fix: BaseError issue

* fix: runtime error

* Fix mock response in test_make_agent

* add 429 as errors to retry

---------

Co-authored-by: k-boikov <64261260+k-boikov@users.noreply.github.com>
Co-authored-by: Nicholas Tindle <nick@ntindle.com>
Co-authored-by: Reinier van der Leer <github@pwuts.nl>
Co-authored-by: Nicholas Tindle <nicktindle@outlook.com>
Co-authored-by: Luke K (pr-0f3t) <2609441+lc0rp@users.noreply.github.com>
Co-authored-by: Merwane Hamadi <merwanehamadi@gmail.com>
  • Loading branch information
7 people authored Jun 14, 2023
1 parent 49d1a5a commit 6e6e7fc
Show file tree
Hide file tree
Showing 10 changed files with 400 additions and 245 deletions.
3 changes: 3 additions & 0 deletions autogpt/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def start_agent(name: str, task: str, prompt: str, agent: Agent, model=None) ->
first_message = f"""You are {name}. Respond with: "Acknowledged"."""
agent_intro = f"{voice_name} here, Reporting for duty!"

if model is None:
model = config.smart_llm_model

# Create agent
if agent.config.speak_mode:
say_text(agent_intro, 1)
Expand Down
49 changes: 1 addition & 48 deletions autogpt/llm/api_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import openai
from openai import Model

from autogpt.config import Config
from autogpt.llm.base import MessageDict
from autogpt.llm.modelsinfo import COSTS
from autogpt.logs import logger
from autogpt.singleton import Singleton
Expand All @@ -27,52 +25,7 @@ def reset(self):
self.total_budget = 0.0
self.models = None

def create_chat_completion(
    self,
    messages: list[MessageDict],
    model: str | None = None,
    temperature: float | None = None,
    max_tokens: int | None = None,
    deployment_id=None,
) -> str:
    """
    Create a chat completion and update the cost.
    Args:
        messages (list): The list of messages to send to the API.
        model (str): The model to use for the API call.
        temperature (float): The temperature to use for the API call.
            Falls back to the configured default when None.
        max_tokens (int): The maximum number of tokens for the API call.
        deployment_id: Azure deployment id; when set, the request is routed
            to that deployment instead of the plain model endpoint.
    Returns:
        str: The AI's response.
    """
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature
    # Azure deployments require deployment_id in addition to the model name.
    if deployment_id is not None:
        response = openai.ChatCompletion.create(
            deployment_id=deployment_id,
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            api_key=cfg.openai_api_key,
        )
    else:
        response = openai.ChatCompletion.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            api_key=cfg.openai_api_key,
        )
    # Only meter successful responses; error payloads carry no usage data.
    if not hasattr(response, "error"):
        logger.debug(f"Response: {response}")
        prompt_tokens = response.usage.prompt_tokens
        completion_tokens = response.usage.completion_tokens
        self.update_cost(prompt_tokens, completion_tokens, model)
    return response

def update_cost(self, prompt_tokens, completion_tokens, model: str):
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
Update the total cost, prompt tokens, and completion tokens.
Expand Down
3 changes: 3 additions & 0 deletions autogpt/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
MessageRole = Literal["system", "user", "assistant"]
MessageType = Literal["ai_response", "action_result"]

TText = list[int]
"""Token array representing tokenized text"""


class MessageDict(TypedDict):
role: MessageRole
Expand Down
178 changes: 177 additions & 1 deletion autogpt/llm/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
from autogpt.llm.base import ChatModelInfo, EmbeddingModelInfo, TextModelInfo
import functools
import time
from typing import List
from unittest.mock import patch

import openai
import openai.api_resources.abstract.engine_api_resource as engine_api_resource
from colorama import Fore, Style
from openai.error import APIError, RateLimitError, Timeout
from openai.openai_object import OpenAIObject

from autogpt.llm.api_manager import ApiManager
from autogpt.llm.base import (
ChatModelInfo,
EmbeddingModelInfo,
MessageDict,
TextModelInfo,
TText,
)
from autogpt.logs import logger

OPEN_AI_CHAT_MODELS = {
info.name: info
Expand Down Expand Up @@ -72,3 +91,160 @@
**OPEN_AI_TEXT_MODELS,
**OPEN_AI_EMBEDDING_MODELS,
}


def meter_api(func):
    """Adds ApiManager metering to functions which make OpenAI API calls.

    Wraps `func` so that every OpenAIObject built from a raw API response is
    inspected, and any reported `usage` is fed into the ApiManager singleton
    to keep the running cost totals up to date.

    Args:
        func: A callable that performs an OpenAI API request.

    Returns:
        The wrapped callable, with metering applied as a side effect.
    """
    api_manager = ApiManager()

    openai_obj_processor = openai.util.convert_to_openai_object

    def update_usage_with_response(response: OpenAIObject):
        try:
            usage = response.usage
            logger.debug(f"Reported usage from call to model {response.model}: {usage}")
            api_manager.update_cost(
                response.usage.prompt_tokens,
                # Embedding responses report no completion_tokens field.
                response.usage.completion_tokens if "completion_tokens" in usage else 0,
                response.model,
            )
        except Exception as err:
            # Metering is best-effort; never let cost tracking break the call.
            logger.warn(f"Failed to update API costs: {err.__class__.__name__}: {err}")

    def metering_wrapper(*args, **kwargs):
        openai_obj = openai_obj_processor(*args, **kwargs)
        if isinstance(openai_obj, OpenAIObject) and "usage" in openai_obj:
            update_usage_with_response(openai_obj)
        return openai_obj

    # functools.wraps preserves the wrapped function's name/docstring,
    # consistent with retry_api's _wrapped below.
    @functools.wraps(func)
    def metered_func(*args, **kwargs):
        # Intercept the response-parsing step inside openai's engine API
        # resource so each raw response passes through metering_wrapper.
        with patch.object(
            engine_api_resource.util,
            "convert_to_openai_object",
            side_effect=metering_wrapper,
        ):
            return func(*args, **kwargs)

    return metered_func


def retry_api(
    num_retries: int = 10,
    backoff_base: float = 2.0,
    warn_user: bool = True,
):
    """Decorator factory that retries an OpenAI API call on transient errors.

    Retries on RateLimitError, and on APIError/Timeout with HTTP status
    502 or 429, sleeping with exponential backoff between attempts.

    Args:
        num_retries int: Number of retries. Defaults to 10.
        backoff_base float: Base for exponential backoff. Defaults to 2.
        warn_user bool: Whether to warn the user. Defaults to True.
    """
    retry_limit_msg = f"{Fore.RED}Error: " f"Reached rate limit, passing...{Fore.RESET}"
    api_key_error_msg = (
        f"Please double check that you have setup a "
        f"{Fore.CYAN + Style.BRIGHT}PAID{Style.RESET_ALL} OpenAI API Account. You can "
        f"read more here: {Fore.CYAN}https://docs.agpt.co/setup/#getting-an-api-key{Fore.RESET}"
    )
    backoff_msg = (
        f"{Fore.RED}Error: API Bad gateway. Waiting {{backoff}} seconds...{Fore.RESET}"
    )

    def _wrapper(func):
        @functools.wraps(func)
        def _wrapped(*args, **kwargs):
            warned = not warn_user
            max_attempts = num_retries + 1  # the initial call plus num_retries retries
            attempt = 0
            while attempt < max_attempts:
                attempt += 1
                last_attempt = attempt == max_attempts
                try:
                    return func(*args, **kwargs)

                except RateLimitError:
                    if last_attempt:
                        raise

                    logger.debug(retry_limit_msg)
                    if not warned:
                        # Rate limiting usually means a free-tier key.
                        logger.double_check(api_key_error_msg)
                        warned = True

                except (APIError, Timeout) as err:
                    # Only gateway errors and rate-limit statuses are retryable.
                    if (err.http_status not in [502, 429]) or last_attempt:
                        raise

                # Exponential backoff before the next attempt.
                delay = backoff_base ** (attempt + 2)
                logger.debug(backoff_msg.format(backoff=delay))
                time.sleep(delay)

        return _wrapped

    return _wrapper


@meter_api
@retry_api()
def create_chat_completion(
    messages: List[MessageDict],
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create a chat completion using the OpenAI API
    Args:
        messages: A list of messages to feed to the chatbot.
        kwargs: Other arguments to pass to the OpenAI API chat completion call.
    Returns:
        OpenAIObject: The ChatCompletion response from OpenAI
    """
    response: OpenAIObject = openai.ChatCompletion.create(
        messages=messages,
        **kwargs,
    )
    # Error payloads are returned untouched; only log successful responses.
    if not hasattr(response, "error"):
        logger.debug(f"Response: {response}")
    return response


@meter_api
@retry_api()
def create_text_completion(
    prompt: str,
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create a text completion using the OpenAI API
    Args:
        prompt: A text prompt to feed to the LLM
        kwargs: Other arguments to pass to the OpenAI API text completion call.
    Returns:
        OpenAIObject: The Completion response from OpenAI
    """
    response = openai.Completion.create(prompt=prompt, **kwargs)
    return response


@meter_api
@retry_api()
def create_embedding(
    input: str | TText | List[str] | List[TText],
    *_,
    **kwargs,
) -> OpenAIObject:
    """Create an embedding using the OpenAI API
    Args:
        input: The text to embed.
        kwargs: Other arguments to pass to the OpenAI API embedding call.
    Returns:
        OpenAIObject: The Embedding response from OpenAI
    """
    response = openai.Embedding.create(input=input, **kwargs)
    return response
Loading

0 comments on commit 6e6e7fc

Please sign in to comment.