Skip to content

Commit

Permalink
Merge branch 'release/v2.0' into fix/code_manager
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri authored Feb 5, 2024
2 parents 76897ef + 7763716 commit d993185
Show file tree
Hide file tree
Showing 15 changed files with 100 additions and 39 deletions.
7 changes: 6 additions & 1 deletion pandasai/agent/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
memory_size: Optional[int] = 10,
pipeline: Optional[Type[GenerateChatPipeline]] = None,
vectorstore: Optional[VectorStore] = None,
description: str = None,
):
"""
Args:
Expand All @@ -59,6 +60,7 @@ def __init__(
self.last_result = None
self.last_code_generated = None
self.last_code_executed = None
self.agent_info = description

self.conversation_id = uuid.uuid4()

Expand All @@ -69,7 +71,10 @@ def __init__(
# Instantiate the context
config = self.get_config(config)
self.context = PipelineContext(
dfs=dfs, config=config, memory=Memory(memory_size), vectorstore=vectorstore
dfs=dfs,
config=config,
memory=Memory(memory_size, agent_info=description),
vectorstore=vectorstore,
)

# Instantiate the logger
Expand Down
2 changes: 1 addition & 1 deletion pandasai/connectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def to_string(
self,
extras={
"index": index,
"type": "sql" if is_direct_sql else "pandas",
"type": "sql" if is_direct_sql else "pd.DataFrame",
"is_direct_sql": is_direct_sql,
},
type_=serializer,
Expand Down
11 changes: 10 additions & 1 deletion pandasai/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ class Memory:

_messages: list
_memory_size: int
_agent_info: str

def __init__(self, memory_size: int = 1):
def __init__(self, memory_size: int = 1, agent_info: Union[str, None] = None):
self._messages = []
self._memory_size = memory_size
self._agent_info = agent_info

def add(self, message: str, is_user: bool):
self._messages.append({"message": message, "is_user": is_user})
Expand Down Expand Up @@ -65,9 +67,16 @@ def get_last_message(self) -> str:
messages = self.get_messages(self._memory_size)
return "" if len(messages) == 0 else messages[-1]

def get_system_prompt(self) -> str:
return self._agent_info

def clear(self):
self._messages = []

@property
def size(self):
return self._memory_size

@property
def agent_info(self):
return self._agent_info
54 changes: 36 additions & 18 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class CustomLLM(BaseOpenAI):
from abc import abstractmethod
from typing import Any, Dict, Optional, Union, Mapping, Tuple

from pandasai.helpers.memory import Memory

from ..exceptions import (
APIKeyNotFoundError,
MethodNotImplementedError,
Expand Down Expand Up @@ -137,7 +139,7 @@ def _extract_tag_text(self, response: str, tag: str) -> str:
return None

@abstractmethod
def call(self, instruction: BasePrompt, suffix: str = "") -> str:
def call(self, instruction: BasePrompt, memory: Memory = None) -> str:
"""
Execute the LLM with given prompt.
Expand All @@ -151,7 +153,7 @@ def call(self, instruction: BasePrompt, suffix: str = "") -> str:
"""
raise MethodNotImplementedError("Call method has not been implemented")

def generate_code(self, instruction: BasePrompt) -> str:
def generate_code(self, instruction: BasePrompt, memory: Memory) -> str:
"""
Generate the code based on the instruction and the given prompt.
Expand All @@ -162,7 +164,7 @@ def generate_code(self, instruction: BasePrompt) -> str:
str: A string of Python code.
"""
response = self.call(instruction, suffix="")
response = self.call(instruction, memory)
return self._extract_code(response)


Expand Down Expand Up @@ -269,7 +271,7 @@ def _client_params(self) -> Dict[str, any]:
"http_client": self.http_client,
}

def completion(self, prompt: str) -> str:
def completion(self, prompt: str, memory: Memory) -> str:
"""
Query the completion API
Expand All @@ -280,6 +282,11 @@ def completion(self, prompt: str) -> str:
str: LLM response.
"""
prompt = (
memory.get_system_prompt() + "\n" + prompt
if memory and memory.agent_info
else prompt
)
params = {**self._invocation_params, "prompt": prompt}

if self.stop is not None:
Expand All @@ -292,7 +299,7 @@ def completion(self, prompt: str) -> str:

return response.choices[0].text

def chat_completion(self, value: str) -> str:
def chat_completion(self, value: str, memory: Memory) -> str:
"""
Query the chat completion API
Expand All @@ -303,14 +310,25 @@ def chat_completion(self, value: str) -> str:
str: LLM response.
"""
params = {
**self._invocation_params,
"messages": [
messages = []
if memory and memory.agent_info:
messages.append(
{
"role": "system",
"content": value,
"content": memory.get_system_prompt(),
}
],
)

messages.append(
{
"role": "user",
"content": value,
},
)

params = {
**self._invocation_params,
"messages": messages,
}

if self.stop is not None:
Expand All @@ -323,7 +341,7 @@ def chat_completion(self, value: str) -> str:

return response.choices[0].message.content

def call(self, instruction: BasePrompt, suffix: str = ""):
def call(self, instruction: BasePrompt, memory: Memory = None):
"""
Call the OpenAI LLM.
Expand All @@ -337,12 +355,12 @@ def call(self, instruction: BasePrompt, suffix: str = ""):
Returns:
str: Response
"""
self.last_prompt = instruction.to_string() + suffix
self.last_prompt = instruction.to_string()

return (
self.chat_completion(self.last_prompt)
self.chat_completion(self.last_prompt, memory)
if self._is_chat_model
else self.completion(self.last_prompt)
else self.completion(self.last_prompt, memory)
)


Expand Down Expand Up @@ -395,7 +413,7 @@ def _validate(self):
raise ValueError("max_output_tokens must be greater than zero")

@abstractmethod
def _generate_text(self, prompt: str) -> str:
def _generate_text(self, prompt: str, memory: Memory) -> str:
"""
Generates text for prompt, specific to implementation.
Expand All @@ -408,7 +426,7 @@ def _generate_text(self, prompt: str) -> str:
"""
raise MethodNotImplementedError("method has not been implemented")

def call(self, instruction: BasePrompt, suffix: str = "") -> str:
def call(self, instruction: BasePrompt, memory: Memory = None) -> str:
"""
Call the Google LLM.
Expand All @@ -420,5 +438,5 @@ def call(self, instruction: BasePrompt, suffix: str = "") -> str:
str: LLM response.
"""
self.last_prompt = instruction.to_string() + suffix
return self._generate_text(self.last_prompt)
self.last_prompt = instruction.to_string()
return self._generate_text(self.last_prompt, memory)
6 changes: 4 additions & 2 deletions pandasai/llm/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from typing import Optional

from pandasai.helpers.memory import Memory

from ..prompts.base import BasePrompt
from .base import LLM

Expand All @@ -15,8 +17,8 @@ def __init__(self, output: Optional[str] = None):
if output is not None:
self._output = output

def call(self, instruction: BasePrompt, suffix: str = "") -> str:
self.last_prompt = instruction.to_string() + suffix
def call(self, instruction: BasePrompt, memory: Memory = None) -> str:
self.last_prompt = instruction.to_string()
return self._output

@property
Expand Down
8 changes: 7 additions & 1 deletion pandasai/llm/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
>>> from pandasai.llm.google_palm import GooglePalm
"""
from pandasai.helpers.memory import Memory
from .base import BaseGoogle
from typing import Any
from ..helpers.optional import import_dependency
Expand Down Expand Up @@ -69,7 +70,7 @@ def _validate(self):
if not self.model:
raise ValueError("model is required.")

def _generate_text(self, prompt: str) -> str:
def _generate_text(self, prompt: str, memory: Memory = None) -> str:
"""
Generates text for prompt.
Expand All @@ -81,6 +82,11 @@ def _generate_text(self, prompt: str) -> str:
"""
self._validate()
prompt = (
memory.get_system_prompt() + "\n" + prompt
if memory and memory.agent_info
else prompt
)
completion = self.google_palm.generate_text(
model=self.model,
prompt=prompt,
Expand Down
15 changes: 12 additions & 3 deletions pandasai/llm/google_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"""
from typing import Optional

from pandasai.helpers.memory import Memory
from .base import BaseGoogle
from ..exceptions import UnsupportedModelError
from ..helpers.optional import import_dependency
Expand Down Expand Up @@ -95,7 +97,7 @@ def _validate(self):
if not self.model:
raise ValueError("model is required.")

def _generate_text(self, prompt: str) -> str:
def _generate_text(self, prompt: str, memory: Memory = None) -> str:
"""
Generates text for prompt.
Expand All @@ -114,6 +116,12 @@ def _generate_text(self, prompt: str) -> str:
)
from vertexai.preview.generative_models import GenerativeModel

updated_prompt = (
memory.get_system_prompt() + "\n" + prompt
if memory and memory.agent_info
else prompt
)

if self.model in self._supported_code_models:
code_generation = CodeGenerationModel.from_pretrained(self.model)

Expand All @@ -126,16 +134,17 @@ def _generate_text(self, prompt: str) -> str:
text_generation = TextGenerationModel.from_pretrained(self.model)

completion = text_generation.predict(
prompt=prompt,
prompt=updated_prompt,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
)
elif self.model in self._supported_generative_models:
model = GenerativeModel(self.model)

responses = model.generate_content(
[prompt],
[updated_prompt],
generation_config={
"max_output_tokens": self.max_output_tokens,
"temperature": self.temperature,
Expand Down
12 changes: 10 additions & 2 deletions pandasai/llm/huggingface_text_gen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, Any, List, Dict

from pandasai.helpers.memory import Memory

from .base import LLM
from ..helpers import load_dotenv
from ..prompts.base import BasePrompt
Expand Down Expand Up @@ -75,8 +77,14 @@ def _default_params(self) -> Dict[str, Any]:
"seed": self.seed,
}

def call(self, instruction: BasePrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix
def call(self, instruction: BasePrompt, memory: Memory = "") -> str:
prompt = instruction.to_string()

prompt = (
memory.get_system_prompt() + "\n" + prompt
if memory and memory.agent_info
else prompt
)

params = self._default_params
if self.streaming:
Expand Down
10 changes: 8 additions & 2 deletions pandasai/llm/langchain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pandasai.helpers.memory import Memory
from pandasai.prompts.base import BasePrompt
from .base import LLM

Expand All @@ -22,8 +23,13 @@ class LangchainLLM(LLM):
def __init__(self, langchain_llm):
self.langchain_llm = langchain_llm

def call(self, instruction: BasePrompt, suffix: str = "") -> str:
prompt = instruction.to_string() + suffix
def call(self, instruction: BasePrompt, memory: Memory = None) -> str:
prompt = instruction.to_string()
prompt = (
memory.get_system_prompt() + "\n" + prompt
if memory and memory.agent_info
else prompt
)
return self.langchain_llm.predict(prompt)

@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion pandasai/pipelines/chat/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def execute(self, input: Any, **kwargs) -> Any:
pipeline_context: PipelineContext = kwargs.get("context")
logger: Logger = kwargs.get("logger")

code = pipeline_context.config.llm.generate_code(input)
code = pipeline_context.config.llm.generate_code(input, pipeline_context.memory)

pipeline_context.add("last_code_generated", code)
logger.log(
Expand Down
2 changes: 1 addition & 1 deletion tests/llms/test_google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,5 @@ def test_call(self, mocker, prompt):
generativeai, "generate_text", return_value=expected_response
)

result = llm.call(instruction=prompt, suffix="!")
result = llm.call(instruction=prompt)
assert result == expected_text
4 changes: 1 addition & 3 deletions tests/llms/test_langchain_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,4 @@ def test_langchain_llm_type(self, langchain_llm):
def test_langchain_model_call(self, langchain_llm, prompt):
langchain_wrapper = LangchainLLM(langchain_llm)

assert (
langchain_wrapper.call(instruction=prompt, suffix="!") == "Custom response"
)
assert langchain_wrapper.call(instruction=prompt) == "Custom response"
2 changes: 1 addition & 1 deletion tests/llms/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_call_supported_chat_model(self, mocker, prompt):
result = openai.call(instruction=prompt)
assert result == "response"

def test_call_finetuned_model(self, mocker, prompt):
def test_call_with_system_prompt(self, mocker, prompt):
openai = OpenAI(
api_token="test", model="ft:gpt-3.5-turbo:my-org:custom_suffix:id"
)
Expand Down
Loading

0 comments on commit d993185

Please sign in to comment.