
feat(multiturn-conv): update prompts and use api's for multiturn #928

Merged · 5 commits · Feb 7, 2024

Changes from all commits
5 changes: 2 additions & 3 deletions pandasai/helpers/dataframe_serializer.py

@@ -96,7 +96,6 @@ def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
             "name": df.name,
             "description": df.description,
             "type": extras["type"],
-            "data": {},
         }
         # Add DataFrame details to the result
         data = {
@@ -121,9 +120,9 @@ def convert_df_to_json(self, df: pd.DataFrame, extras: dict) -> dict:
 
             data["schema"]["fields"].append(col_info)
 
-        df_info["data"] = data
+        result = df_info | data
 
-        return {df_number_key: df_info}
+        return {df_number_key: result}
 
     def convert_df_to_json_str(self, df: pd.DataFrame, extras: dict) -> str:
         """
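Note: the serializer change replaces the nested `"data"` key with a top-level merge via the dict-union operator. A minimal sketch of the difference in output shape, with illustrative field values (not taken from the PR):

```python
df_info = {"name": "sales", "description": "Monthly sales", "type": "pd.DataFrame"}
data = {"schema": {"fields": [{"name": "amount", "type": "int64"}]}}

# Before: the schema sat one level down, under "data".
nested = {**df_info, "data": data}
print(nested["data"]["schema"])

# After: PEP 584 dict union (Python 3.9+) hoists the schema to the top level;
# on a key collision the right-hand operand wins.
flat = df_info | data
print(flat["schema"])
```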
2 changes: 1 addition & 1 deletion pandasai/helpers/memory.py

@@ -42,7 +42,7 @@ def get_messages(self, limit: int = None) -> list:
         limit = self._memory_size if limit is None else limit
 
         return [
-            f"{'Q' if message['is_user'] else 'A'}: {message['message'] if message['is_user'] else self._truncate(message['message'])}"
+            f"{'### QUERY' if message['is_user'] else '### ANSWER'}\n {message['message'] if message['is_user'] else self._truncate(message['message'])}"
            for message in self._messages[-limit:]
        ]
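This swaps the terse `Q:`/`A:` transcript for markdown-style headers that the new system-message template can embed directly. A quick illustration, assuming the `Memory` constructor keywords exercised in the tests below:

```python
from pandasai.helpers.memory import Memory

mem = Memory(memory_size=10)
mem.add("What is the average of sales?", True)   # user turn
mem.add("df['sales'].mean()", False)            # assistant turn

# Previously rendered as:
#   Q: What is the average of sales?
#   A: df['sales'].mean()
# Now rendered as:
#   ### QUERY
#    What is the average of sales?
#   ### ANSWER
#    df['sales'].mean()
print("\n".join(mem.get_messages()))
```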
43 changes: 38 additions & 5 deletions pandasai/llm/base.py

@@ -20,6 +20,7 @@ class CustomLLM(BaseOpenAI):
 from typing import Any, Dict, Optional, Union, Mapping, Tuple
 
 from pandasai.helpers.memory import Memory
+from pandasai.prompts.generate_system_message import GenerateSystemMessagePrompt
 
 from ..exceptions import (
     APIKeyNotFoundError,
@@ -118,6 +119,30 @@ def _extract_code(self, response: str, separator: str = "```") -> str:
 
         return code
 
+    def prepend_system_prompt(self, prompt: BasePrompt, memory: Memory):
+        """
+        Prepend the system prompt to the chat prompt; useful when the model
+        has no message-based chat history.
+        Args:
+            prompt (BasePrompt): prompt for chat method
+            memory (Memory): user conversation history
+        """
+        return self.get_system_prompt(memory) + prompt if memory else prompt
Contributor review comment · suggestion (llm): The use of the ternary operator without parentheses might lead to unexpected behavior due to operator precedence. It's recommended to add parentheses around the ternary operation for clarity and to ensure the intended logic is executed.
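A minimal sketch of the parenthesized form the reviewer is asking for. The behavior is unchanged, since Python's conditional expression already binds more loosely than `+`, but the grouping becomes explicit:

```python
def prepend_system_prompt(self, prompt, memory):
    # Explicit grouping: concatenate only when a memory object is present.
    return (self.get_system_prompt(memory) + prompt) if memory else prompt
```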


+
+    def get_system_prompt(self, memory: Memory) -> Any:
+        """
+        Generate the system prompt with agent info and previous conversation.
+        """
+        system_prompt = GenerateSystemMessagePrompt(memory=memory)
+        return system_prompt.to_string()
+
+    def get_messages(self, memory: Memory) -> Any:
+        """
+        Return formatted messages.
+        Args:
+            memory (Memory): memory to read the past conversation from
+        """
+        return memory.get_previous_conversation()
+
     def _extract_tag_text(self, response: str, tag: str) -> str:
         """
         Extracts the text between two tags in the response.
@@ -282,11 +307,8 @@ def completion(self, prompt: str, memory: Memory) -> str:
             str: LLM response.
 
         """
-        prompt = (
-            memory.get_system_prompt() + "\n" + prompt
-            if memory and memory.agent_info
-            else prompt
-        )
+        prompt = self.prepend_system_prompt(prompt, memory)
 
         params = {**self._invocation_params, "prompt": prompt}
 
         if self.stop is not None:
@@ -297,6 +319,8 @@ def completion(self, prompt: str, memory: Memory) -> str:
         if openai_handler := openai_callback_var.get():
             openai_handler(response)
 
+        self.last_prompt = prompt
+
         return response.choices[0].text
 
     def chat_completion(self, value: str, memory: Memory) -> str:
@@ -319,6 +343,15 @@ def chat_completion(self, value: str, memory: Memory) -> str:
             }
         )
 
+        for message in memory.all():
+            if message["is_user"]:
+                messages.append({"role": "user", "content": message["message"]})
+            else:
+                messages.append(
+                    {"role": "assistant", "content": message["message"]}
+                )
+
+        # adding current prompt as latest query message
         messages.append(
             {
                 "role": "user",
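With this change, `chat_completion` replays the stored turns as role-tagged chat messages instead of splicing them into prompt text. A self-contained sketch of the message list it builds; the history layout mirrors `memory.all()` in the diff, and the leading system message is assumed:

```python
history = [
    {"is_user": True, "message": "Plot sales by month"},
    {"is_user": False, "message": "import matplotlib.pyplot as plt\n..."},
]

messages = [{"role": "system", "content": "<rendered system prompt>"}]
for m in history:
    messages.append(
        {"role": "user" if m["is_user"] else "assistant", "content": m["message"]}
    )
# The current prompt is appended last, as the newest user query.
messages.append({"role": "user", "content": "Now plot profit by month"})
```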
7 changes: 2 additions & 5 deletions pandasai/llm/google_palm.py

@@ -82,11 +82,7 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
 
         """
         self._validate()
-        prompt = (
-            memory.get_system_prompt() + "\n" + prompt
-            if memory and memory.agent_info
-            else prompt
-        )
+        prompt = self.prepend_system_prompt(prompt, memory)
         completion = self.google_palm.generate_text(
             model=self.model,
             prompt=prompt,
@@ -95,6 +91,7 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
             top_k=self.top_k,
             max_output_tokens=self.max_output_tokens,
         )
+        self.last_prompt = prompt
         return completion.result
 
     @property
41 changes: 31 additions & 10 deletions pandasai/llm/google_vertexai.py

@@ -43,6 +43,7 @@ class GoogleVertexAI(BaseGoogle):
     _supported_generative_models = [
         "gemini-pro",
     ]
+    _supported_code_chat_models = ["codechat-bison@001", "codechat-bison@002"]
 
     def __init__(
         self, project_id: str, location: str, model: Optional[str] = None, **kwargs
@@ -110,19 +111,13 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
         """
         self._validate()
 
-        from vertexai.preview.language_models import (
-            CodeGenerationModel,
-            TextGenerationModel,
-        )
-        from vertexai.preview.generative_models import GenerativeModel
+        updated_prompt = self.prepend_system_prompt(prompt, memory)
 
-        updated_prompt = (
-            memory.get_system_prompt() + "\n" + prompt
-            if memory and memory.agent_info
-            else prompt
-        )
+        self.last_prompt = updated_prompt
 
         if self.model in self._supported_code_models:
+            from vertexai.preview.language_models import CodeGenerationModel
+
             code_generation = CodeGenerationModel.from_pretrained(self.model)
 
             completion = code_generation.predict(
@@ -131,6 +126,8 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
                 max_output_tokens=self.max_output_tokens,
             )
         elif self.model in self._supported_text_models:
+            from vertexai.preview.language_models import TextGenerationModel
+
             text_generation = TextGenerationModel.from_pretrained(self.model)
 
             completion = text_generation.predict(
@@ -141,6 +138,8 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
                 top_p=self.top_p,
                 max_output_tokens=self.max_output_tokens,
             )
         elif self.model in self._supported_generative_models:
+            from vertexai.preview.generative_models import GenerativeModel
+
             model = GenerativeModel(self.model)
 
             responses = model.generate_content(
@@ -154,6 +153,28 @@ def _generate_text(self, prompt: str, memory: Memory = None) -> str:
             )
 
             completion = responses.candidates[0].content.parts[0]
+        elif self.model in self._supported_code_chat_models:
+            from vertexai.language_models import CodeChatModel, ChatMessage
+
+            code_chat_model = CodeChatModel.from_pretrained(self.model)
+            messages = []
+
+            for message in memory.all():
+                if message["is_user"]:
+                    messages.append(
+                        ChatMessage(author="user", content=message["message"])
+                    )
+                else:
+                    messages.append(
+                        ChatMessage(author="model", content=message["message"])
+                    )
+            chat = code_chat_model.start_chat(
+                context=memory.get_system_prompt(), message_history=messages
+            )
+
+            response = chat.send_message(prompt)
+            return response.text
+
         else:
             raise UnsupportedModelError(self.model)
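The new `codechat-bison` branch forwards the conversation through Vertex AI's chat API rather than concatenating it into one prompt string. A standalone sketch of that flow, using the same calls as the diff; the project and location values are placeholders:

```python
import vertexai
from vertexai.language_models import ChatMessage, CodeChatModel

# Placeholder project/location; any valid GCP project works.
vertexai.init(project="my-gcp-project", location="us-central1")

history = [
    ChatMessage(author="user", content="Plot sales by month"),
    ChatMessage(author="model", content="import matplotlib.pyplot as plt\n..."),
]

model = CodeChatModel.from_pretrained("codechat-bison@001")
chat = model.start_chat(
    context="You are a pandas coding assistant.",  # system prompt slot
    message_history=history,
)
print(chat.send_message("Now plot profit by month").text)
```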
7 changes: 2 additions & 5 deletions pandasai/llm/huggingface_text_gen.py

@@ -80,11 +80,7 @@ def _default_params(self) -> Dict[str, Any]:
     def call(self, instruction: BasePrompt, memory: Memory = "") -> str:
         prompt = instruction.to_string()
 
-        prompt = (
-            memory.get_system_prompt() + "\n" + prompt
-            if memory and memory.agent_info
-            else prompt
-        )
+        prompt = self.prepend_system_prompt(prompt, memory)
 
         params = self._default_params
         if self.streaming:
@@ -100,6 +96,7 @@ def call(self, instruction: BasePrompt, memory: Memory = "") -> str:
                 res.generated_text = res.generated_text[
                     : res.generated_text.index(stop_seq)
                 ]
+        self.last_prompt = prompt
         return res.generated_text
 
     @property
7 changes: 2 additions & 5 deletions pandasai/llm/langchain.py

@@ -25,11 +25,8 @@ def __init__(self, langchain_llm):
 
     def call(self, instruction: BasePrompt, memory: Memory = None) -> str:
         prompt = instruction.to_string()
-        prompt = (
-            memory.get_system_prompt() + "\n" + prompt
-            if memory and memory.agent_info
-            else prompt
-        )
+        prompt = self.prepend_system_prompt(prompt, memory)
+        self.last_prompt = prompt
         return self.langchain_llm.predict(prompt)
 
     @staticmethod
5 changes: 5 additions & 0 deletions pandasai/pipelines/chat/code_generator.py

@@ -32,6 +32,11 @@ def execute(self, input: Any, **kwargs) -> Any:
         code = pipeline_context.config.llm.generate_code(input, pipeline_context.memory)
 
         pipeline_context.add("last_code_generated", code)
+        logger.log(
+            f"""Prompt used:
+            {pipeline_context.config.llm.last_prompt}
+            """
+        )
         logger.log(
             f"""Code generated:
             ```
7 changes: 7 additions & 0 deletions pandasai/prompts/generate_system_message.py

@@ -0,0 +1,7 @@
+from .base import BasePrompt
+
+
+class GenerateSystemMessagePrompt(BasePrompt):
+    """Prompt that renders the system message (agent info plus conversation history)."""
+
+    template_path = "generate_system_message.tmpl"
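Hypothetical usage of the new prompt class, assuming `BasePrompt.to_string()` renders `template_path` against the constructor kwargs (which is how `get_system_prompt` in `base.py` uses it):

```python
from pandasai.helpers.memory import Memory
from pandasai.prompts.generate_system_message import GenerateSystemMessagePrompt

mem = Memory(agent_info="You are a data analyst.", memory_size=10)
mem.add("What is the average of sales?", True)
mem.add("df['sales'].mean()", False)

# Renders generate_system_message.tmpl against the memory.
print(GenerateSystemMessagePrompt(memory=mem).to_string())
```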
3 changes: 0 additions & 3 deletions pandasai/prompts/templates/generate_python_code.tmpl

@@ -3,9 +3,6 @@
 {% if context.skills_manager.has_skills() %}
 {{context.skills_manager.prompt_display()}}
 {% endif %}
-
-{{ context.memory.get_previous_conversation() }}
-
 {% if last_code_generated != "" and context.memory.count() > 0 %}
 {{ last_code_generated }}
 {% else %}
2 changes: 0 additions & 2 deletions pandasai/prompts/templates/generate_python_code_with_sql.tmpl

@@ -14,8 +14,6 @@ def execute_sql_query(sql_query: str) -> pd.Dataframe
     """This method connects to the database, executes the sql query and returns the dataframe"""
 </function>
 
-{{ context.memory.get_previous_conversation() }}
-
 {% if last_code_generated != "" and context.memory.count() > 0 %}
 {{ last_code_generated }}
 {% else %}
5 changes: 5 additions & 0 deletions pandasai/prompts/templates/generate_system_message.tmpl

@@ -0,0 +1,5 @@
+{% if memory.agent_info %} {{memory.get_system_prompt()}} {% endif %}
+{% if memory.count() > 1 %}
+### PREVIOUS CONVERSATION
+{{ memory.get_previous_conversation() }}
+{% endif %}
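For a memory seeded with `agent_info="xyz"` and one completed exchange, the template renders to the string asserted in the tests below (leading/trailing whitespace included):

```
 xyz 

### PREVIOUS CONVERSATION
### QUERY
 hello world
### ANSWER
 print("hello world)
```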
3 changes: 2 additions & 1 deletion tests/agent/test_agent.py

@@ -444,7 +444,8 @@ def test_retry_on_error_with_single_df(
 </dataframe>
 
 The user asked the following question:
-Q: Hello world
+### QUERY
+Hello world
 
 You generated this python code:
 result = {'type': 'string', 'value': 'Hello World'}
35 changes: 35 additions & 0 deletions tests/llms/test_base_llm.py

@@ -3,6 +3,7 @@
 import pytest
 
 from pandasai.exceptions import APIKeyNotFoundError
+from pandasai.helpers.memory import Memory
 from pandasai.llm import LLM
 
@@ -64,3 +65,37 @@ def test_extract_code(self):
         """
 
         assert LLM()._extract_code(code) == "print('Hello World')"
+
+    def test_get_system_prompt_empty_memory(self):
+        assert LLM().get_system_prompt(Memory()) == "\n"
+
+    def test_get_system_prompt_memory_with_agent_info(self):
+        mem = Memory(agent_info="xyz")
+        assert LLM().get_system_prompt(mem) == " xyz \n"
+
+    def test_get_system_prompt_memory_with_agent_info_messages(self):
+        mem = Memory(agent_info="xyz", memory_size=10)
+        mem.add("hello world", True)
+        mem.add('print("hello world)', False)
+        mem.add("hello world", True)
+        print(mem.get_messages())
+        assert (
+            LLM().get_system_prompt(mem)
+            == ' xyz \n\n### PREVIOUS CONVERSATION\n### QUERY\n hello world\n### ANSWER\n print("hello world)\n'
+        )
+
+    def test_prepend_system_prompt_with_empty_mem(self):
+        assert LLM().prepend_system_prompt("hello world", Memory()) == "\nhello world"
+
+    def test_prepend_system_prompt_with_non_empty_mem(self):
+        mem = Memory(agent_info="xyz", memory_size=10)
+        mem.add("hello world", True)
+        mem.add('print("hello world)', False)
+        mem.add("hello world", True)
+        assert (
+            LLM().prepend_system_prompt("hello world", mem)
+            == ' xyz \n\n### PREVIOUS CONVERSATION\n### QUERY\n hello world\n### ANSWER\n print("hello world)\nhello world'
+        )
+
+    def test_prepend_system_prompt_with_memory_none(self):
+        assert LLM().prepend_system_prompt("hello world", None) == "hello world"
4 changes: 4 additions & 0 deletions tests/llms/test_google_vertexai.py

@@ -59,3 +59,7 @@ def test_validate_without_model(self, google_vertexai: GoogleVertexAI):
         google_vertexai.model = None
         with pytest.raises(ValueError, match="model is required."):
             google_vertexai._validate()
+
+    def test_validate_with_code_chat_model(self, google_vertexai: GoogleVertexAI):
+        google_vertexai.model = "codechat-bison@001"
+        google_vertexai._validate()