Showing 4 changed files with 256 additions and 0 deletions.
outlines/text/random/__init__.py
@@ -0,0 +1,2 @@
from .openai_rv import openai
from .transformers_rv import transformers
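
For orientation, a minimal usage sketch of what this new package module exposes, assuming `tiktoken` is installed and that `transformers_rv` (not part of this diff) already exists alongside it:

import outlines.text.random as random

# Each factory returns a token-level random variable backed by the named model;
# constructing one only loads the tokenizer, so no API key is needed yet.
chat_rv = random.openai("gpt-3.5-turbo")  # routed to the chat completion API
completion_rv = random.openai("text-davinci-003")  # routed to the completion API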

outlines/text/random/openai_rv.py
@@ -0,0 +1,127 @@
import os
from typing import Callable, Dict, List

import numpy as np


class OpenAIRV:
    """Represents a token random variable defined by an OpenAI model."""

    def __init__(self, model_name: str):
        import tiktoken

        self.model_name = model_name
        self.tokenizer = tiktoken.encoding_for_model(model_name)

        # Completion models ("text-*") take raw prompts and return plain text;
        # chat models ("gpt-*") take a message list and return message objects.
        if "text-" in model_name:
            self.call_api = call_completion_api
            self.format_prompt = lambda x: x
            self.extract_choice = lambda x: x["text"]
        elif "gpt-" in model_name:
            self.call_api = call_chat_completion_api
            self.format_prompt = lambda x: {"role": "user", "content": x[0]}
            self.extract_choice = lambda x: x["message"]["content"]
        else:
            raise NameError(
                f"The requested model {model_name} is not available. Only OpenAI's"
                " completion and chat completion models are supported."
            )

    async def __call__(self, input_ids: List[List[int]], samples: int = 1):
        prompt = self.tokenizer.decode_batch(input_ids)
        response = await self.call_api(
            self.model_name, self.format_prompt(prompt), 1, 1.0, [], {}, samples
        )

        # One choice comes back per (prompt, sample) pair; re-encode the
        # completions and group the token ids by prompt.
        results = [self.extract_choice(choice) for choice in response["choices"]]
        token_ids = np.array(self.tokenizer.encode_batch(results)).reshape(
            len(input_ids), samples, -1
        )

        return token_ids.squeeze()


def openai(model_name: str):
    return OpenAIRV(model_name)


def error_handler(api_call_fn: Callable) -> Callable:
    """Handle OpenAI API errors and missing API key."""

    def call(*args, **kwargs):
        # Import the library lazily. The module-level name `openai` refers to
        # the factory defined above, which has no `error` attribute, so the
        # except clauses below need this local import to resolve correctly.
        import openai

        try:
            os.environ["OPENAI_API_KEY"]
        except KeyError:
            raise OSError(
                "Could not find the `OPENAI_API_KEY` environment variable, which is"
                " necessary to call OpenAI's APIs. Please make sure it is set before"
                " re-running your model."
            )

        try:
            return api_call_fn(*args, **kwargs)
        except (
            openai.error.RateLimitError,
            openai.error.Timeout,
            openai.error.TryAgain,
            openai.error.APIConnectionError,
            openai.error.ServiceUnavailableError,
        ) as e:
            raise OSError(f"Could not connect to the OpenAI API: {e}")
        except (
            openai.error.AuthenticationError,
            openai.error.PermissionError,
            openai.error.InvalidRequestError,
            openai.error.InvalidAPIType,
        ) as e:
            raise e

    return call


@error_handler
async def call_completion_api(
    model: str,
    prompt: str,
    max_tokens: int,
    temperature: float,
    stop_sequences: List[str],
    logit_bias: Dict[str, int],
    num_samples: int,
):
    import openai

    response = await openai.Completion.acreate(
        engine=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
        logit_bias=logit_bias,
        n=int(num_samples),
    )

    return response


@error_handler
async def call_chat_completion_api(
    model: str,
    messages: List[Dict[str, str]],
    max_tokens: int,
    temperature: float,
    stop_sequences: List[str],
    logit_bias: Dict[str, int],
    num_samples: int,
):
    import openai

    response = await openai.ChatCompletion.acreate(
        model=model,
        messages=messages,
        max_tokens=max_tokens,
        temperature=temperature,
        stop=list(stop_sequences) if len(stop_sequences) > 0 else None,
        logit_bias=logit_bias,
        n=int(num_samples),
    )

    return response
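
A sketch of a live call through this module, assuming `OPENAI_API_KEY` is set and the pre-1.0 `openai` package is installed (the `acreate` and `openai.error` interfaces used above were removed in openai 1.0). Note that `__call__` hard-codes `max_tokens=1` and `temperature=1.0`, so each sample is a single-token continuation:

import asyncio

import tiktoken

from outlines.text.random.openai_rv import openai

# Requires OPENAI_API_KEY and openai<1.0; consumes API credits.
rv = openai("text-davinci-003")
tokenizer = tiktoken.encoding_for_model("text-davinci-003")

input_ids = tokenizer.encode_batch(["The capital of France is"])
token_ids = asyncio.run(rv(input_ids, samples=2))  # one token per sample
print(tokenizer.decode_batch(token_ids.reshape(2, -1).tolist()))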

pyproject.toml
@@ -38,6 +38,7 @@ test = [
     "diffusers",
     "pre-commit",
     "pytest",
+    "tiktoken",
     "torch",
     "transformers"
 ]
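
A quick sanity check that the new test dependency resolves, assuming the `test` extra was installed (for instance with `pip install -e ".[test]"`):

import tiktoken

# Round-trip sanity check for the tokenizer the new tests rely on.
enc = tiktoken.encoding_for_model("text-davinci-003")
assert enc.decode(enc.encode("outlines")) == "outlines"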

tests/text/random/test_openai_rv.py (path inferred from the imports below)
@@ -0,0 +1,126 @@
import asyncio
import itertools

import pytest
import tiktoken
from numpy.testing import assert_array_equal

import outlines.text.random as random
from outlines.text.random.openai_rv import OpenAIRV


async def mock_completion_api_call(
    model, prompts, max_tokens, temperature, stop_sequence, logit_bias, num_samples
):
    """Mock completion API call.

    The returned dictionary was copied from the OpenAI API reference
    at https://platform.openai.com/docs/api-reference/completions/create
    on 06/09/2023.

    """
    choices = [
        {"text": f"{p}{s}", "index": 0, "logprobs": None, "finish_reason": "length"}
        for p, s in itertools.product(prompts, range(num_samples))
    ]
    return {
        "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
        "object": "text_completion",
        "created": 1589478378,
        "model": "text-davinci-003",
        "choices": choices,
        "usage": {"prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12},
    }


async def mock_chat_completion_api_call(
    model, message, max_tokens, temperature, stop_sequence, logit_bias, num_samples
):
    """Mock chat completion API call.

    The returned dictionary was copied from the OpenAI API reference
    at https://platform.openai.com/docs/api-reference/chat/create
    on 06/09/2023.

    """
    prompt = message["content"]
    choices = [
        {
            "index": 0,
            "message": {
                "role": "assistant",
                "content": f"{p}{s}",
            },
            "finish_reason": "stop",
        }
        for p, s in itertools.product([prompt], range(num_samples))
    ]
    return {
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1677652288,
        "choices": choices,
        "usage": {"prompt_tokens": 9, "completion_tokens": 12, "total_tokens": 21},
    }


def test_wrong_name():
    # tiktoken has no encoding registered under this name, so
    # `encoding_for_model` raises a KeyError inside `__init__`.
    with pytest.raises(KeyError):
        OpenAIRV("davinci-003")


def test_completion():
    rv = random.openai("text-davinci-003")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    rv.call_api = mock_completion_api_call

    prompt = "A"
    input_ids = tokenizer.encode_batch([prompt])
    result = asyncio.run(rv(input_ids))
    assert result.ndim == 1
    assert_array_equal(result, tokenizer.encode("A0"))

    result = asyncio.run(rv(input_ids, samples=3))
    assert result.shape[0] == 3
    assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"]))


def test_completion_list():
    rv = random.openai("text-davinci-003")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    rv.call_api = mock_completion_api_call

    prompts = ["A", "B"]
    input_ids = tokenizer.encode_batch(prompts)

    result = asyncio.run(rv(input_ids))
    assert result.shape[0] == 2
    assert_array_equal(result.reshape(2, -1), tokenizer.encode_batch(["A0", "B0"]))

    result = asyncio.run(rv(input_ids, samples=3))
    assert result.shape[0] == 2
    assert result.shape[1] == 3
    assert_array_equal(
        result.reshape(6, -1),
        tokenizer.encode_batch(["A0", "A1", "A2", "B0", "B1", "B2"]),
    )


def test_chat_completion():
    rv = random.openai("gpt-3.5-turbo")
    assert isinstance(rv, OpenAIRV)

    tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
    rv.call_api = mock_chat_completion_api_call

    prompt = "A"
    input_ids = tokenizer.encode_batch([prompt])
    result = asyncio.run(rv(input_ids))
    assert_array_equal(result, tokenizer.encode("A0"))

    result = asyncio.run(rv(input_ids, samples=3))
    assert_array_equal(result, tokenizer.encode_batch(["A0", "A1", "A2"]))
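
The same substitution pattern the tests use also works outside pytest. A self-contained sketch, with a hypothetical `echo_api` standing in for the real seven-argument API callers, that checks the (prompts, samples) layout of the result:

import asyncio

import tiktoken

import outlines.text.random as random


async def echo_api(model, prompts, max_tokens, temperature, stop, bias, n):
    # Echo every prompt once per sample, mimicking the completion payload shape.
    return {"choices": [{"text": p} for p in prompts for _ in range(n)]}


rv = random.openai("text-davinci-003")
rv.call_api = echo_api  # swap out the network call, as the tests above do

tokenizer = tiktoken.encoding_for_model("text-davinci-003")
input_ids = tokenizer.encode_batch(["A", "B"])
result = asyncio.run(rv(input_ids, samples=2))
assert result.shape == (2, 2)  # one token per completion; the token axis is squeezed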