refactor: removed env var management functions, improved tests
maxmekiska committed Dec 21, 2023
1 parent 75af56e commit c389004
Showing 7 changed files with 103 additions and 84 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -69,4 +69,7 @@

### 0.0.12

- added support for Google's `gemini-pro` LLM
- added support for Google's `gemini-pro` LLM
- removed the env var helper functions `reset_openai_key` and `reset_google_key`; please set `OPENAI_API_KEY` and `GOOGLE_API_KEY` on your system, or manage the variables in code, as shown below:
  - `os.environ["OPENAI_API_KEY"] = "..."`
  - `os.environ["GOOGLE_API_KEY"] = "..."`
5 changes: 0 additions & 5 deletions fukkatsu/__init__.py
@@ -4,8 +4,6 @@
import functools
import traceback

from fukkatsu.llm.googlegate import reset_google_key, set_google_key
from fukkatsu.llm.openaigate import reset_openai_key, set_openai_key
from fukkatsu.memory import SHORT_TERM_MEMORY
from fukkatsu.observer.tracker import track
from fukkatsu.utils.helper import (check_and_install_libraries,
@@ -16,9 +14,6 @@
return_source_code, sampler)
from fukkatsu.utils.synthesize import defibrillate, enhance, stalker, twin

set_openai_key()
set_google_key()


def resurrect(
    lives: int = 1,
1 change: 1 addition & 0 deletions fukkatsu/llm/__init__.py
@@ -1 +1,2 @@
from fukkatsu.llm.googlegate import *
from fukkatsu.llm.openaigate import *
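Since both gate modules are now star-imported here, their request functions can be pulled straight from the subpackage (assuming neither module restricts exports via `__all__`); a small usage sketch:

```python
# The star imports above re-export the gate functions, so callers can do:
from fukkatsu.llm import request_google_model, request_openai_model
```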
19 changes: 0 additions & 19 deletions fukkatsu/llm/googlegate.py
@@ -1,4 +1,3 @@
import os
from dataclasses import dataclass
from typing import Optional

@@ -18,24 +17,6 @@ class GoogleGenerateContentConfig:
    top_k: Optional[int]


def set_google_key():
track.warning("Setting GOOGLE_API_KEY")
try:
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
track.warning("GOOGLE_API_KEY found in environment variables.")
except:
track.error("GOOGLE_API_KEY not found in environment variables.")


def reset_google_key(key: str):
if type(key) != str:
track.error("Invalid Key format. GOOGLE_API_KEY not overwritten.")
raise Exception("Invalid Key format. GOOGLE_API_KEY not overwritten.")
else:
genai.configure(api_key=key)
track.warning("GOOGLE_API_KEY overwritten.")


def request_google_model(
    set_prompt: str,
    model: str = "gemini-pro",
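With `set_google_key` and `reset_google_key` gone, configuring the Google client is left to the caller; a minimal sketch that mirrors what the removed helpers did, assuming `google-generativeai` is installed:

```python
import os

import google.generativeai as genai

# Equivalent of the removed set_google_key: configure from the environment.
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))

# Equivalent of the removed reset_google_key: overwrite with an explicit key.
# genai.configure(api_key="another-key")  # hypothetical override value
```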
19 changes: 0 additions & 19 deletions fukkatsu/llm/openaigate.py
@@ -1,4 +1,3 @@
import os
from dataclasses import dataclass
from typing import Optional

@@ -16,24 +15,6 @@ class OpenaiChatCompletionConfig:
    stop: Optional[str]


def set_openai_key():
track.warning("Setting OPENAI_API_KEY")
try:
openai.api_key = os.environ.get("OPENAI_API_KEY")
track.warning("OPENAI_API_KEY found in environment variables.")
except:
track.error("OPENAI_API_KEY not found in environment variables.")


def reset_openai_key(key: str):
if type(key) != str:
track.error("Invalid Key format. OPENAI_API_KEY not overwritten.")
raise Exception("Invalid Key format. OPENAI_API_KEY not overwritten.")
else:
openai.api_key = key
track.warning("OPENAI_API_KEY overwritten.")


def request_openai_model(
    set_prompt: str,
    model: str = "gpt-3.5-turbo",
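The OpenAI gate is analogous; a minimal sketch replicating the removed helpers with the module-level attribute the old code set:

```python
import os

import openai

# Equivalent of the removed set_openai_key: read the key from the environment.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Equivalent of the removed reset_openai_key: overwrite it explicitly.
# openai.api_key = "another-key"  # hypothetical override value
```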
50 changes: 50 additions & 0 deletions tests/test_llm/test_googlegate.py
@@ -0,0 +1,50 @@
from unittest.mock import Mock, patch

import pytest

from fukkatsu.llm.googlegate import request_google_model


@pytest.fixture
def mock_generative_model():
with patch("fukkatsu.llm.googlegate.genai.GenerativeModel") as mock_model:
yield mock_model


@pytest.fixture
def mock_track_warning():
with patch("fukkatsu.observer.tracker.track.warning") as mock_warning:
yield mock_warning


def test_request_google_model(mock_generative_model, mock_track_warning):
set_prompt = "Test prompt"
model = "gemini-pro"
candidate_count = 1
stop_sequences = None
max_output_tokens = 1024
temperature = 0.1
top_p = None
top_k = None

mock_response = Mock(text="Test response")
mock_generative_model.return_value.generate_content.return_value = mock_response

result = request_google_model(
set_prompt=set_prompt,
model=model,
candidate_count=candidate_count,
stop_sequences=stop_sequences,
max_output_tokens=max_output_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
)

mock_track_warning.assert_called_once_with(
f"API REQUEST to {model} - Temperature: {temperature} - Max Tokens: {max_output_tokens} - candidate_count: {candidate_count} - Stop: {stop_sequences}"
)

mock_generative_model.assert_called_once_with(model)

assert result == "Test response"
88 changes: 48 additions & 40 deletions tests/test_llm/test_openaigate.py
@@ -1,49 +1,57 @@
import io
import logging
import os
import sys
from unittest.mock import patch
from unittest.mock import Mock, patch

import openai
import pytest

from fukkatsu.llm.openaigate import reset_openai_key, set_openai_key
from fukkatsu.observer.tracker import track
from fukkatsu.llm.openaigate import request_openai_model


@pytest.fixture
def captured_output():
    captured_output = io.StringIO()
    sys.stdout = captured_output
    yield captured_output
    sys.stdout = sys.__stdout__
def mock_openai_create():

    with patch("fukkatsu.llm.openaigate.openai.ChatCompletion.create") as mock_create:
        yield mock_create

def test_set_openai_key_with_api_key():
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
        set_openai_key()
        assert openai.api_key == "test_key"
        assert "OPENAI_API_KEY" in os.environ


def test_set_openai_key_without_api_key(captured_output):
    handler = logging.StreamHandler(captured_output)
    track.addHandler(handler)
    with patch("os.environ.get") as import_module_mock:
        import_module_mock.side_effect = Exception
        set_openai_key()
    output = captured_output.getvalue().strip()
    assert "OPENAI_API_KEY not found" in output


def test_overwrite_openai_key():
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test_key"}):
        reset_openai_key("new_key")
        assert openai.api_key == "new_key"


def test_overwrite_openai_key_error():
    with pytest.raises(
        Exception, match="Invalid Key format. OPENAI_API_KEY not overwritten."
    ):
        reset_openai_key(23)
@pytest.fixture
def mock_track_warning():
    with patch("fukkatsu.observer.tracker.track.warning") as mock_warning:
        yield mock_warning


def test_request_openai_model(mock_openai_create, mock_track_warning):
    set_prompt = "Test prompt"
    model = "gpt-3.5-turbo"
    temperature = 0.1
    max_tokens = 1024
    n = 1
    stop = None

    mock_openai_response = {"choices": [{"message": {"content": "Test response"}}]}

    mock_openai_create.return_value = mock_openai_response

    result = request_openai_model(
        set_prompt=set_prompt,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        n=n,
        stop=stop,
    )

    mock_track_warning.assert_called_once_with(
        f"API REQUEST to {model} - Temperature: {temperature} - Max Tokens: {max_tokens} - N: {n} - Stop: {stop}"
    )

    mock_openai_create.assert_called_once_with(
        model=model,
        messages=[
            {"role": "system", "content": set_prompt},
        ],
        max_tokens=max_tokens,
        n=n,
        stop=stop,
        temperature=temperature,
    )

    assert result == "Test response"
