feat: Add comprehensive rate limit handling across API providers #161

Open · wants to merge 10 commits into main
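This PR routes Semantic Scholar and LLM provider calls through a shared limiter imported as `rate_limiter` from `ai_scientist/rate_limit.py`. That module is not shown in this view, so the snippet below is only a minimal sketch of the decorator interface the changed files rely on, assuming a sliding-window limiter keyed by provider or model name; the class body, window length, call budget, and key-derivation rule are illustrative assumptions, not the PR's actual implementation.

```python
# Hypothetical sketch only: the real ai_scientist/rate_limit.py is not part of this diff view.
import functools
import threading
import time
from collections import defaultdict, deque


class RateLimiter:
    """Sliding-window limiter keyed by provider/model name (illustrative)."""

    def __init__(self, max_calls_per_minute=60, window_seconds=60.0):
        self.max_calls = max_calls_per_minute
        self.window = window_seconds
        self._calls = defaultdict(deque)  # key -> timestamps of recent calls
        self._lock = threading.Lock()

    def _wait_for_slot(self, key):
        while True:
            with self._lock:
                now = time.monotonic()
                stamps = self._calls[key]
                # Discard calls that have aged out of the window.
                while stamps and now - stamps[0] > self.window:
                    stamps.popleft()
                if len(stamps) < self.max_calls:
                    stamps.append(now)
                    return
                wait = self.window - (now - stamps[0])
            time.sleep(max(wait, 0.1))

    def handle_rate_limit(self, key_or_fn):
        """Decorator taking a fixed key (e.g. "semantic_scholar") or a callable.

        The diff also passes callables over other argument shapes; this sketch
        only covers a callable applied to the first positional argument.
        """
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                key = key_or_fn(args[0]) if callable(key_or_fn) else key_or_fn
                self._wait_for_slot(key)
                return fn(*args, **kwargs)
            return wrapper
        return decorator


rate_limiter = RateLimiter()
```

The diffs below apply this decorator both with fixed keys (`"semantic_scholar"`) and with callables that derive a key from the wrapped call's arguments.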
77 changes: 71 additions & 6 deletions README.md
@@ -78,13 +78,53 @@ pip install -r requirements.txt

### Supported Models and API Keys

We support a wide variety of models, including open-weight and API-only models. In general, we recommend using only frontier models above the capability of the original GPT-4. To see a full list of supported models, see [here](https://github.com/SakanaAI/AI-Scientist/blob/main/ai_scientist/llm.py).
We support a wide variety of models, including open-weight and API-only models. In general, we recommend using only frontier models above the capability of the original GPT-4. Below is a comprehensive list of supported models and their variants.

#### OpenAI API (GPT-4o, GPT-4o-mini, o1 models)
## Available Models

By default, this uses the `OPENAI_API_KEY` environment variable.
AI-Scientist supports multiple model providers and variants:

#### Anthropic API (Claude Sonnet 3.5)
### Claude Models
- Claude 3.5 Sonnet (via Anthropic API)
- Claude 3.5 Sonnet (via Amazon Bedrock)
- Claude 3.5 Sonnet (via Vertex AI)

### GPT Models
- GPT-4o and variants (via OpenAI API)
- GPT-4o-mini and variants (via OpenAI API)
- o1 models and variants (via OpenAI API)

### LLaMa Models
- LLaMa 3.3 70B (via OpenRouter API)
- LLaMa 3.3 70B Local (via Ollama)
- LLaMa 3.2 1B Local (via Ollama, for resource-constrained environments)
- LLaMa 3.1 8B Local (via Ollama, optimized for segmented templates)

### Additional Models
- Gemini Pro (via Google Cloud)
- Grok-1 (via xAI)
- DeepSeek Coder V2 (via DeepSeek API)

## Model Performance and Template Compatibility

### Performance Tiers
- Tier 1 (Full Capability): LLaMa 3.3, GPT-4o, Claude 3.5
- Tier 2 (Standard): LLaMa 3.1, GPT-3.5
- Tier 3 (Resource-Constrained): LLaMa 3.2 1B

### Template Formats
AI-Scientist supports two template editing modes (see the selection sketch after this list):
- **Diff Mode**: Default for high-capability models (Tier 1)
- **Whole Mode**: Optimized for resource-constrained models (Tier 2 & 3)
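
The mode is chosen per model. A minimal sketch of the selection rule, mirroring the `edit_format` assignment this PR adds to `ai_scientist/llm.py` (the helper name is illustrative):

```python
# Sketch of the edit-format selection applied in ai_scientist/llm.py by this PR.
WHOLE_MODE_MODELS = {"llama3.1:8b", "llama3.2:1b"}  # resource-constrained local models


def choose_edit_format(model_name: str) -> str:
    """Return "whole" for models that struggle with diff-style edits, else "diff"."""
    return "whole" if model_name in WHOLE_MODE_MODELS else "diff"


assert choose_edit_format("llama3.2:1b") == "whole"
assert choose_edit_format("gpt-4o-2024-08-06") == "diff"
```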

### Template Segmentation
For improved compatibility with resource-constrained models (see the splitting sketch after this list):
- Segmented templates split papers into manageable sections
- Recommended for LLaMa 3.1 8B and LLaMa 3.2 1B
- Helps prevent edit mode termination issues
- Improves reliability for paper generation tasks
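
Nothing in this diff implements the segmentation itself, so the following is only a rough sketch of the idea, assuming a LaTeX template split on `\section` boundaries; the function name and splitting rule are illustrative assumptions:

```python
import re


def segment_template(tex_source: str) -> list[str]:
    """Split a LaTeX paper template into per-section chunks (illustrative sketch).

    Smaller chunks keep each edit request within what a 1B-8B model can
    reliably rewrite in "whole" mode.
    """
    # Keep the preamble (everything before the first \section) as its own chunk.
    parts = re.split(r"(?=\\section\{)", tex_source)
    return [part for part in parts if part.strip()]
```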

For detailed configuration of each model type, see the sections below.

By default, this uses the `ANTHROPIC_API_KEY` environment variable.

@@ -121,9 +161,34 @@ export VERTEXAI_PROJECT="PROJECT_ID" # for Aider/LiteLLM call
#### DeepSeek API (deepseek-chat, deepseek-reasoner)
By default, this uses the `DEEPSEEK_API_KEY` environment variable.

#### OpenRouter API (Llama3.1)
#### OpenRouter API (LLaMa Models)

By default, this uses the `OPENROUTER_API_KEY` environment variable. Supported models (see the client sketch after this list):
- LLaMa 3.3 70B: High-performance model suitable for complex research tasks
- LLaMa 3.1: Mid-tier model for general research tasks
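
OpenRouter exposes an OpenAI-compatible endpoint, which is how the `create_client` change in this PR connects to it. A minimal sketch using the `meta-llama/llama-3.3-70b-instruct` route name from the diff (the prompt is illustrative):

```python
import os

import openai

# Sketch of calling LLaMa 3.3 70B through OpenRouter's OpenAI-compatible API,
# matching the base_url and environment variable used by create_client in this PR.
client = openai.OpenAI(
    api_key=os.environ["OPENROUTER_API_KEY"],
    base_url="https://openrouter.ai/api/v1",
)
response = client.chat.completions.create(
    model="meta-llama/llama-3.3-70b-instruct",
    messages=[{"role": "user", "content": "Suggest one ablation for a diffusion paper."}],
    temperature=0.75,
)
print(response.choices[0].message.content)
```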

#### Local Models via Ollama

For local model execution without API keys, AI-Scientist supports running models through Ollama:

1. Install Ollama:
```bash
curl https://ollama.ai/install.sh | sh
```

2. Pull the LLaMa model:
```bash
ollama pull llama2
```

3. Start the Ollama server:
```bash
ollama serve
```

4. Use the local model by specifying "llama3.3-70b-local" as the model identifier in your experiments.

By default, this uses the `OPENROUTER_API_KEY` environment variable.
Note: Local model performance may vary based on your system's resources. The Ollama server provides an OpenAI-compatible endpoint at `http://localhost:11434/v1`.
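
Because the endpoint is OpenAI-compatible, the same client pattern used elsewhere in this PR also works locally. A minimal sketch (the placeholder API key is an assumption; Ollama does not validate it):

```python
import openai

# Sketch: point the OpenAI client at the local Ollama server.
# Ollama ignores the API key, so any placeholder string works (assumption).
client = openai.OpenAI(
    api_key="ollama",
    base_url="http://localhost:11434/v1",
)
response = client.chat.completions.create(
    model="llama2",  # the tag pulled in step 2; this PR maps "llama3.3-70b-local" to it
    messages=[{"role": "user", "content": "List three baselines for an RL experiment."}],
)
print(response.choices[0].message.content)
```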

#### Google Gemini
We support Google Gemini models (e.g., "gemini-1.5-flash", "gemini-1.5-pro") via the [google-generativeai](https://pypi.org/project/google-generativeai) Python library. By default, it uses the environment variable:
85 changes: 66 additions & 19 deletions ai_scientist/generate_ideas.py
@@ -8,6 +8,7 @@
import requests

from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
from ai_scientist.rate_limit import rate_limiter

S2_API_KEY = os.getenv("S2_API_KEY")

Expand Down Expand Up @@ -130,9 +131,18 @@ def generate_ideas(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

# Iteratively improve task.
if num_reflections > 1:
@@ -148,11 +158,18 @@
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

if "I am done" in text:
print(f"Idea generation converged after {j + 2} iterations.")
@@ -229,9 +246,18 @@ def generate_next_idea(
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

# Iteratively improve task.
if num_reflections > 1:
@@ -247,11 +273,18 @@
msg_history=msg_history,
)
## PARSE OUTPUT
json_output = extract_json_between_markers(text)
assert (
json_output is not None
), "Failed to extract JSON from LLM output"
print(json_output)
try:
json_output = extract_json_between_markers(text)
if json_output is None:
print("Failed to extract JSON from LLM output")
continue
print(json_output)
except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

if "I am done" in text:
print(
@@ -280,9 +313,13 @@ def on_backoff(details):


@backoff.on_exception(
backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
backoff.expo,
requests.exceptions.HTTPError,
on_backoff=on_backoff
)
@rate_limiter.handle_rate_limit("semantic_scholar") # Add rate limiting for Semantic Scholar
def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]:

if not query:
return None
if engine == "semanticscholar":
@@ -454,6 +491,7 @@ def check_idea_novelty(
break

## PARSE OUTPUT

json_output = extract_json_between_markers(text)
assert json_output is not None, "Failed to extract JSON from LLM output"

@@ -475,8 +513,17 @@
cites=paper["citationCount"],
abstract=paper["abstract"],
)
)
papers_str = "\n\n".join(paper_strings)
papers_str = "\n\n".join(paper_strings)

except ValueError as e:
print(f"Error extracting JSON: {e}")
continue
except KeyError as e:
print(f"Missing required field in JSON: {e}")
continue
except Exception as e:
print(f"Unexpected error while extracting JSON: {e}")
continue

except Exception as e:
print(f"Error: {e}")
57 changes: 50 additions & 7 deletions ai_scientist/llm.py
@@ -5,6 +5,7 @@
import anthropic
import backoff
import openai
from ai_scientist.rate_limit import rate_limiter
import google.generativeai as genai
from google.generativeai.types import GenerationConfig

@@ -20,6 +21,13 @@
"gpt-4o-2024-08-06",
"o1-preview-2024-09-12",
"o1-mini-2024-09-12",
"deepseek-r1",
"llama3.3-70b",
"llama3.3-70b-local",
"llama3.2:1b",
"llama3.1:8b", # New viable option with segmented templates
"gemini-pro",
"grok-3",
"o1-2024-12-17",
# OpenRouter models
"llama3.1-405b",
@@ -44,8 +52,31 @@
"gemini-1.5-pro",
]

class Model:
def __init__(self, model_name, system_message="You are a helpful AI assistant."):
self.model_name = model_name
self.system_message = system_message
self.client, self.client_model = create_client(model_name)
self.msg_history = []
# Determine edit format based on model capabilities
self.edit_format = "whole" if model_name in ["llama3.1:8b", "llama3.2:1b"] else "diff"

@rate_limiter.handle_rate_limit(lambda self: self.model_name)
def get_response(self, msg, temperature=0.75, print_debug=False):
content, self.msg_history = get_response_from_llm(
msg=msg,
client=self.client,
model=self.model_name,
system_message=self.system_message,
print_debug=print_debug,
msg_history=self.msg_history,
temperature=temperature,
edit_format=self.edit_format # Pass edit format to get_response_from_llm
)
return content
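# Illustrative usage of the Model wrapper above (not part of this diff):
#     model = Model("llama3.1:8b")          # edit_format resolves to "whole"
#     reply = model.get_response("Draft a related-work paragraph.")
# Rate limiting is applied transparently by the handle_rate_limit decorator.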

# Get N responses from a single message, used for ensembling.
@rate_limiter.handle_rate_limit(lambda args: args[2])
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
def get_batch_responses_from_llm(
msg,
@@ -74,7 +105,7 @@ def get_batch_responses_from_llm(
],
temperature=temperature,
max_tokens=MAX_NUM_TOKENS,
n=n_responses,
n=n_responses, # Fix parameter position
stop=None,
seed=0,
)
@@ -92,7 +123,7 @@ def get_batch_responses_from_llm(
],
temperature=temperature,
max_tokens=MAX_NUM_TOKENS,
n=n_responses,
n_responses,
stop=None,
)
content = [r.message.content for r in response.choices]
@@ -126,6 +157,7 @@ def get_batch_responses_from_llm(
return content, new_msg_history


@rate_limiter.handle_rate_limit(lambda args: args[2])
@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
def get_response_from_llm(
msg,
@@ -135,7 +167,12 @@
print_debug=False,
msg_history=None,
temperature=0.75,
edit_format="diff" # Default to diff mode for stronger models
):
# Use "whole" mode for weaker models that benefit from segmented templates
if model in ["llama3.1:8b", "llama3.2:3b", "llama3.2:1b"]:
edit_format = "whole"

if msg_history is None:
msg_history = []

@@ -190,7 +227,7 @@
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif model in ["o1-preview-2024-09-12", "o1-mini-2024-09-12"]:
elif model == "gemini-pro":
new_msg_history = msg_history + [{"role": "user", "content": msg}]
response = client.chat.completions.create(
model=model,
@@ -203,7 +240,7 @@
n=1,
seed=0,
)
content = response.choices[0].message.content
content = response.text
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
@@ -220,8 +257,14 @@
)
content = response.choices[0].message.content
new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
elif model in ["deepseek-chat", "deepseek-coder"]:
elif model in ["llama3.3-70b", "llama3.3-70b-local", "llama3.2:1b", "llama3.1:8b", "deepseek-r1"]:
new_msg_history = msg_history + [{"role": "user", "content": msg}]
model_name = {
"llama3.3-70b": "meta-llama/llama-3.3-70b-instruct",
"llama3.3-70b-local": "llama2",
"llama3.2:1b": "llama3.2:1b",
"llama3.1:8b": "llama3.1:8b"
}[model]
response = client.chat.completions.create(
model=model,
messages=[
@@ -330,8 +373,8 @@ def create_client(model):
api_key=os.environ["DEEPSEEK_API_KEY"],
base_url="https://api.deepseek.com"
), model
elif model == "llama3.1-405b":
print(f"Using OpenAI API with {model}.")
elif model == "llama3.3-70b":
print(f"Using OpenRouter API with {model}.")
return openai.OpenAI(
api_key=os.environ["OPENROUTER_API_KEY"],
base_url="https://openrouter.ai/api/v1"