Skip to content

Commit

Permalink
GH-26 add meeting.py with Meeting and Transcript
Browse files Browse the repository at this point in the history
Add client factory
  • Loading branch information
lynxrv21 committed Nov 2, 2023
1 parent 21407fb commit 3401a4d
Show file tree
Hide file tree
Showing 17 changed files with 253 additions and 128 deletions.
7 changes: 5 additions & 2 deletions .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
matrix:
python-version: ["3.11"]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand All @@ -20,6 +20,9 @@ jobs:
python -m pip install --upgrade pip
pip install poetry
poetry install
- name: Analysing the code with pylint
- name: Analysing the code
run: |
make style
- name: Run tests
run: |
make test
6 changes: 3 additions & 3 deletions dream_team_gpt/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from textwrap import dedent
from typing import Callable

from dream_team_gpt.clients.base import AIClient
from dream_team_gpt.constants import NO_COMMENT

DEFAULT_SYSTEM_PROMPT = dedent(
Expand All @@ -16,14 +16,14 @@
class Agent:
def __init__(
self,
client: AIClient,
client_factory: Callable,
name: str,
user_prompt: str,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
):
self.name = name

self.client = client
self.client = client_factory()
self.system_prompt = system_prompt
self.user_prompt = user_prompt
self.client.common_instructions = system_prompt
Expand Down
16 changes: 8 additions & 8 deletions dream_team_gpt/agents/chairman.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from typing import Callable

from loguru import logger

from dream_team_gpt.agents.agent import Agent
from dream_team_gpt.agents.sme import SME
from dream_team_gpt.clients.base import AIClient


class Chairman(Agent):
def __init__(self, client_factory: Callable, executives: list[SME], name: str = "Chairman"):
    """Build the chairman agent that moderates the meeting.

    Args:
        client_factory: zero-argument callable producing the AI client
            (passed through to ``Agent.__init__``).
        executives: the SME participants the chairman chooses between.
        name: display name used in the transcript; defaults to "Chairman".
    """
    # Construct the user_prompt string with details of the executives
    self.user_prompt = self.update_user_prompt(executives)

    # No placeholders here, so a plain literal (the original f-string prefix was a no-op).
    system_prompt = "Answer with only the name and nothing else."

    # Call the superclass constructor with the constructed user_prompt
    super().__init__(client_factory, name, self.user_prompt, system_prompt)

    self.executives = executives

def update_user_prompt(self, SMEs: list[SME]) -> str:
@staticmethod
def update_user_prompt(SMEs: list[SME]) -> str:
frequency_info_list = []
for sme in SMEs:
frequency_info_list.append(
Expand All @@ -31,12 +33,10 @@ def update_user_prompt(self, SMEs: list[SME]) -> str:
f"Participants:\n{''.join(frequency_info_list)} "
)

def decide_if_meeting_over(self, transcript: str) -> bool:
    """Decide whether the meeting should end.

    Currently a stub: always returns False, i.e. the meeting never
    self-terminates — the caller must stop the loop some other way.
    """
    return False

def decide_next_speaker(self, transcript_list: list[str]) -> SME:
transcript = " ".join(transcript_list)

def decide_next_speaker(self, transcript: str) -> SME:
while True:
next_speaker = self.query_gpt(transcript).strip().rstrip(".")
logger.info(f"Chairman called speaker: {next_speaker}")
Expand Down
6 changes: 3 additions & 3 deletions dream_team_gpt/agents/idea_refiner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textwrap import dedent
from typing import Callable

from dream_team_gpt.agents.agent import Agent
from dream_team_gpt.clients import AIClient

REFINER_PROMPT = dedent(
"""\
Expand All @@ -14,9 +14,9 @@


class IdeaRefiner(Agent):
def __init__(self, client_factory: Callable, name: str = "Refiner"):
    """Build the idea-refining agent.

    Args:
        client_factory: zero-argument callable producing the AI client.
        name: display name for this agent; defaults to "Refiner".
    """
    # Call the superclass constructor with the module-level refiner prompt
    super().__init__(client_factory, name, REFINER_PROMPT)

def refine_idea(self, idea: str) -> str:
    """Send *idea* to the underlying client and return the model's refined version."""
    return self.query_gpt(idea)
9 changes: 4 additions & 5 deletions dream_team_gpt/agents/sme.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from textwrap import dedent
from typing import Callable

from dream_team_gpt.agents.agent import Agent
from dream_team_gpt.clients.base import AIClient

USER_PROMPT_TEMPLATE = dedent(
"""\
Expand All @@ -18,18 +18,17 @@


class SME(Agent):
def __init__(self, client_factory: Callable, name: str, expertise: str, concerns: list[str]):
    """Build a subject-matter-expert agent.

    Args:
        client_factory: zero-argument callable producing the AI client.
        name: the SME's display name (also interpolated into its prompt).
        expertise: the SME's area of expertise.
        concerns: topics this SME cares about; joined into the prompt.
    """
    # Construct the user_prompt string from the module-level template
    user_prompt = USER_PROMPT_TEMPLATE.format(
        name=name, expertise=expertise, concerns=", ".join(concerns)
    )

    # Call the superclass constructor with the constructed user_prompt
    super().__init__(client_factory, name, user_prompt)
    self.expertise = expertise
    self.concerns = concerns
    # How many times this SME has spoken so far in the meeting.
    self.spoken_count = 0

def opinion(self, transcript: str) -> str:
    """Return this SME's opinion on the meeting so far, given the full transcript."""
    return self.query_gpt(transcript)
5 changes: 3 additions & 2 deletions dream_team_gpt/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base import AIClient
from .config import AIClientConfig
from .get_client import AIClientType, GPTClient, get_ai_client
from .config import AIClientConfig, AIClientType
from .get_client import GPTClient, ai_client_factory, get_ai_client
from .gpt_client import Models
10 changes: 9 additions & 1 deletion dream_team_gpt/clients/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from dataclasses import dataclass
from enum import Enum

from .gpt_client import Models


class AIClientType(str, Enum):
    """Supported AI backend types; the str mixin makes members compare equal to their string values."""

    ChatGPT = "ChatGPT"


@dataclass
class AIClientConfig:
    """Configuration needed to construct an AI client.

    Attributes:
        client_type: which backend to instantiate (see AIClientType).
        api_key: API key for the backend service.
        model: specific model to use, or None for the client's default.
    """

    client_type: AIClientType
    api_key: str
    model: Models | None
56 changes: 45 additions & 11 deletions dream_team_gpt/clients/get_client.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,50 @@
from enum import Enum
from dataclasses import dataclass
from typing import Any, Callable

from dream_team_gpt.clients.base import AIClient
from dream_team_gpt.clients.config import AIClientConfig
from dream_team_gpt.clients.gpt_client import GPTClient
from .base import AIClient
from .config import AIClientConfig, AIClientType
from .gpt_client import GPTClient, Models


class AIClientType(str, Enum):
ChatGPT = "ChatGPT"


def get_ai_client(config: AIClientConfig) -> AIClient:
    """Instantiate the concrete AIClient described by *config*.

    Raises:
        ValueError: if ``config.client_type`` is not a supported type.
    """
    if config.client_type == AIClientType.ChatGPT:
        return GPTClient(config.api_key)
    raise ValueError(f"Unknown AI client type: {config.client_type}")


def ai_client_factory(config: AIClientConfig) -> Callable[[], AIClient]:
    """Return a zero-argument factory that builds a client for *config*.

    Note: the original returned ``lambda _: get_ai_client(config)``, which
    required one throwaway positional argument — but Agent.__init__ invokes
    the factory as ``client_factory()`` with no arguments, which raised
    TypeError. The factory is now genuinely zero-argument.
    """
    return lambda: get_ai_client(config)


@dataclass
class AIClientFactory:
    """Callable factory for AIClient.

    Usage:
        factory = AIClientFactory(config=AIClientConfig(...))
        Agent(factory)          # Agent calls factory() to build its client
    or tweak the config first:
        factory.config.client_type = <AIClientType>
        factory.config.model = <Models>
        Agent(factory)
    or override config params at call time:
        factory(client_type=<AIClientType>, model=<Models>)
    """

    config: AIClientConfig

    def __call__(
        self, client_type: AIClientType | None = None, model: Models | None = None
    ) -> AIClient:
        # Apply any call-time overrides to the stored config.
        if client_type is not None:
            self.config.client_type = client_type
        if model is not None:
            self.config.model = model

        # NOTE(fix): the original returned ``lambda _: get_ai_client(...)``.
        # Agent.__init__ does ``self.client = client_factory()`` and then uses
        # the result as an AIClient, so the extra lambda layer (which also
        # demanded a throwaway argument) made the factory unusable. Build and
        # return the client directly instead.
        return get_ai_client(self.config)
80 changes: 38 additions & 42 deletions dream_team_gpt/clients/gpt_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,22 @@
import time

from loguru import logger
import backoff
import openai

from .base import AIClient

MAX_RETRIES = 6 # Number of retries
RETRY_DELAY = 10 # Delay between retries in seconds


class Models(str, Enum):
    """OpenAI chat model identifiers; the str mixin lets members be passed directly as the API ``model`` argument."""

    GPT3 = "gpt-3.5-turbo"
    GPT4 = "gpt-4"


class GPTClient(AIClient):
def __init__(self, api_key: str, model: str = Models.GPT3.value):
def __init__(self, api_key: str, model: str = Models.GPT3):
openai.api_key = api_key
self._system_instructions = None
self._user_prompt = None
Expand All @@ -28,63 +32,55 @@ def __init__(self, api_key: str, model: str = Models.GPT3.value):
logger.info(f"Temperature: {self.temperature}")

@property
def system_instructions(self) -> str:
    """System-role instructions sent as the first message of every query."""
    return self._system_instructions

@system_instructions.setter
def system_instructions(self, value: str) -> None:
    # NOTE(fix): the original logged ``self._system_instructions`` (the
    # *previous* value) under a "Setting ..." message; log the incoming
    # value so the message matches what is actually being set.
    logger.debug(f"Setting system instructions: {value}")
    self._system_instructions = value

@property
def user_prompt(self) -> str:
    """User-role prompt sent with every query."""
    return self._user_prompt

@user_prompt.setter
def user_prompt(self, value: str) -> None:
    # Same fix as above: log the new value, not the stale one.
    logger.debug(f"Setting user prompt: {value}")
    self._user_prompt = value

def query(self, transcript: str) -> str:
    """Validate that both prompts are set, then delegate to the retrying ``_query``.

    Args:
        transcript: the running meeting transcript, passed through to the model.

    Raises:
        RuntimeError: if system instructions or the user prompt are unset.
    """
    # ``not x`` also rejects an empty string, not only None — the error
    # message says "is None" but the real contract is "must be non-empty".
    if not self._system_instructions:
        logger.error("self._system_instructions is None. Aborting the query.")
        raise RuntimeError("self._system_instructions is None, cannot proceed with query.")
    if not self._user_prompt:
        logger.error("self._user_prompt is None. Aborting the query.")
        raise RuntimeError("self._user_prompt is None, cannot proceed with query.")

    # Rate-limit retries are handled by the backoff decorator on _query.
    return self._query(transcript)

@backoff.on_exception(
    backoff.constant, openai.error.RateLimitError, max_tries=MAX_RETRIES, interval=RETRY_DELAY
)
def _query(self, transcript: str) -> str:
    """Send one chat-completion request and return the stripped reply text.

    Retries only on ``openai.error.RateLimitError`` (constant interval, see
    decorator constants); any other exception propagates to the caller.
    """
    start_time = time.time()
    # Conversation layout: system instructions, this agent's standing user
    # prompt, then the running meeting transcript as assistant context.
    messages = [
        {"role": "system", "content": self._system_instructions},
        {"role": "user", "content": self._user_prompt},
        {"role": "assistant", "content": transcript},
    ]
    # Un-escape newlines so the logged JSON payload is human-readable.
    logger.info(json.dumps(messages, indent=4).replace("\\n", "\n"))

    response = openai.ChatCompletion.create(
        model=self.model,
        temperature=self.temperature,
        messages=messages,
    )

    elapsed_time = time.time() - start_time

    # Log the time taken and token usage
    logger.info(f"GPT query took {elapsed_time:.2f} seconds")
    logger.info(f"Tokens used in the request: {response['usage']}")

    return response.choices[0].message.content.strip()
Loading

0 comments on commit 3401a4d

Please sign in to comment.