
Streaming and Agent Routing using langgraph #213

Merged: 6 commits, Dec 20, 2024
Changes from 1 commit
111 changes: 111 additions & 0 deletions app/modules/intelligence/agents/agents/callback_handler.py
@@ -0,0 +1,111 @@
import os
from datetime import datetime
import json
from typing import Any, Dict, List, Optional, Tuple, Union
from crewai.agents.parser import AgentAction

class FileCallbackHandler:
    def __init__(self, filename: str = "agent_execution_log.md"):
        """Initialize the file callback handler.

        Args:
            filename (str): The markdown file to write the logs to
        """
        self.filename = filename
        # Create or clear the file initially
        with open(self.filename, 'w', encoding='utf-8') as f:
            f.write(f"# Agent Execution Log\nStarted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

    def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], AgentAction]) -> None:
        """Callback function to handle agent execution steps.

        Args:
            step_output: Output from the agent's execution step. Can be:
                - string
                - list of (action, observation) tuples
                - AgentAction from CrewAI
        """
        with open(self.filename, 'a', encoding='utf-8') as f:
            f.write(f"\n## Step - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("---\n")

            # Handle AgentAction output
            if isinstance(step_output, AgentAction):
                # Write thought section
                if hasattr(step_output, 'thought') and step_output.thought:
                    f.write("### Thought\n")
                    f.write(f"{step_output.thought}\n\n")

                # Write tool section
                if hasattr(step_output, 'tool'):
                    f.write("### Action\n")
                    f.write(f"**Tool:** {step_output.tool}\n")

                # if hasattr(step_output, 'tool_input'):
                #     try:
                #         # Try to parse and pretty print JSON input
                #         tool_input = json.loads(step_output.tool_input)
                #         formatted_input = json.dumps(tool_input, indent=2)
                #         f.write(f"**Input:**\n```json\n{formatted_input}\n```\n")
                #     except (json.JSONDecodeError, TypeError):
                #         # Fallback to raw string if not JSON
                #         f.write(f"**Input:**\n```\n{step_output.tool_input}\n```\n")

                # # Write result section
                # if hasattr(step_output, 'result'):
                #     f.write("\n### Result\n")
                #     try:
                #         # Try to parse and pretty print JSON result
                #         result = json.loads(step_output.result)
                #         formatted_result = json.dumps(result, indent=2)
                #         f.write(f"```json\n{formatted_result}\n```\n")
                #     except (json.JSONDecodeError, TypeError):
                #         # Fallback to raw string if not JSON
                #         f.write(f"```\n{step_output.result}\n```\n")

                f.write("\n")
                return

            # Handle single string output
            if isinstance(step_output, str):
                f.write(step_output + "\n")
                return

            # Handle list of (action, observation) tuples
            for step in step_output:
                if not isinstance(step, tuple):
                    f.write(str(step) + "\n")
                    continue

                action, observation = step

                # Handle action section
                f.write("### Action\n")
                if isinstance(action, dict):
                    if "tool" in action:
                        f.write(f"**Tool:** {action['tool']}\n")
                    if "tool_input" in action:
                        f.write(f"**Input:**\n```\n{action['tool_input']}\n```\n")
                    if "log" in action:
                        f.write(f"**Log:** {action['log']}\n")
                    if "Action" in action:
                        f.write(f"**Action Type:** {action['Action']}\n")
                else:
                    f.write(f"{str(action)}\n")

                # Handle observation section
                f.write("\n### Observation\n")
                if isinstance(observation, str):
                    # Handle special formatting for search-like results
                    lines = observation.split('\n')
                    for line in lines:
                        if line.startswith(('Title:', 'Link:', 'Snippet:')):
                            key, value = line.split(':', 1)
                            f.write(f"**{key.strip()}:**{value}\n")
                        elif line.startswith('-'):
                            f.write(line + "\n")
                        else:
                            f.write(line + "\n")
                else:
                    f.write(str(observation) + "\n")

                f.write("\n")
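For reference, a minimal sketch of how this handler could be wired into a CrewAI agent, based on the commented-out `step_callback` line in `rag_agent.py` below. The role/goal/backstory values here are illustrative placeholders, not taken from this PR:

```python
from crewai import Agent

from app.modules.intelligence.agents.agents.callback_handler import FileCallbackHandler

# Log every intermediate agent step to a markdown file.
handler = FileCallbackHandler("rag_agent_execution.md")

agent = Agent(
    role="Codebase QnA agent",                     # placeholder value
    goal="Answer questions about the repository",  # placeholder value
    backstory="Explores code and cites sources.",  # placeholder value
    step_callback=handler,  # CrewAI invokes this with each step's output
)
```

CrewAI may hand the callback a plain string, a list of `(action, observation)` tuples, or an `AgentAction`, which is why `__call__` branches on all three types.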
38 changes: 34 additions & 4 deletions app/modules/intelligence/agents/agents/rag_agent.py
@@ -1,12 +1,16 @@
import asyncio
import os
from typing import Any, AsyncGenerator, Dict, List

import aiofiles
from contextlib import redirect_stdout
import agentops
from crewai import Agent, Crew, Process, Task
from pydantic import BaseModel, Field

from app.modules.code_provider.code_provider_service import CodeProviderService
from app.modules.conversations.message.message_schema import NodeContext
from app.modules.intelligence.agents.agents.callback_handler import FileCallbackHandler
from app.modules.intelligence.provider.provider_service import (
AgentType,
ProviderService,
@@ -71,6 +75,7 @@ def __init__(self, sql_db, llm, mini_llm, user_id):
        self.llm = llm
        self.mini_llm = mini_llm
        self.user_id = user_id
        # self.callback_handler = FileCallbackHandler("rag_agent_execution.md")

    async def create_agents(self):
        query_agent = Agent(
@@ -101,6 +106,7 @@ async def create_agents(self):
            verbose=True,
            llm=self.llm,
            max_iter=self.max_iter,
            # step_callback=self.callback_handler,
        )

        return query_agent
@@ -267,15 +273,39 @@ async def kickoff_rag_agent(
    llm,
    mini_llm,
    user_id: str,
) -> AsyncGenerator[str, None]:
    provider_service = ProviderService(sql_db, user_id)
    crew_ai_llm = provider_service.get_large_llm(agent_type=AgentType.CREWAI)
    crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentType.CREWAI)
    rag_agent = RAGAgent(sql_db, crew_ai_llm, crew_ai_mini_llm, user_id)
    file_structure = await CodeProviderService(sql_db).get_project_structure_async(
        project_id
    )
    read_fd, write_fd = os.pipe()

    async def kickoff():
        with os.fdopen(write_fd, "w", buffering=1) as write_file:
            with redirect_stdout(write_file):
                await rag_agent.run(
                    query, project_id, chat_history, node_ids, file_structure
                )

    asyncio.create_task(kickoff())

    # Yield CrewAgent logs as they are written to the pipe
    final_answer_streaming = False
    async with aiofiles.open(read_fd, mode='r') as read_file:
        async for line in read_file:
            if not line:
                break
            if final_answer_streaming:
                # Strip the trailing ANSI reset sequence ('\x1b[00m' plus newline, 6 chars)
                if line.endswith('\x1b[00m\n'):
                    yield line[:-6]
                else:
                    yield line
            if "## Final Answer:" in line:
                final_answer_streaming = True
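The streaming above works by redirecting CrewAI's stdout into the write end of an `os.pipe()` from a background task, while the caller asynchronously reads lines from the read end as they appear. Below is a self-contained sketch of the same pattern; `stream_printed_output` is a hypothetical helper name, not part of this PR:

```python
import asyncio
import os
from contextlib import redirect_stdout

import aiofiles


async def stream_printed_output(work):
    """Yield lines printed by `work()` as they are produced."""
    read_fd, write_fd = os.pipe()

    async def producer():
        # Line-buffered writer; closing it on exit signals EOF to the reader.
        with os.fdopen(write_fd, "w", buffering=1) as write_file:
            with redirect_stdout(write_file):
                await work()

    task = asyncio.create_task(producer())
    # aiofiles accepts a raw file descriptor, like the built-in open().
    async with aiofiles.open(read_fd, mode="r") as read_file:
        async for line in read_file:
            yield line
    await task


async def main():
    async def work():
        for i in range(3):
            print(f"step {i}")
            await asyncio.sleep(0)

    collected = [line async for line in stream_printed_output(work)]
    print("captured:", collected)


asyncio.run(main())
```

One caveat of this approach: `redirect_stdout` swaps `sys.stdout` for the whole process, so anything another coroutine prints during the window is captured into the same pipe.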
123 changes: 76 additions & 47 deletions app/modules/intelligence/agents/chat_agents/qna_chat_agent.py
@@ -2,8 +2,14 @@
import logging
import time
from functools import lru_cache
from typing import Any, AsyncGenerator, Dict, List, Optional
from typing import Annotated
Contributor
🛠️ Refactor suggestion

Remove unused imports to improve clarity.

The static analysis hints correctly highlight that 'typing.Any', 'typing.Optional', and 'typing.Annotated' are not used in this file. Consider removing them.

Here is a suggested diff:

-from typing import Any, AsyncGenerator, Dict, List, Optional
-from typing import Annotated
+from typing import AsyncGenerator, Dict, List
🧰 Tools: 🪛 Ruff (0.8.2): F401: typing.Any, typing.Optional, and typing.Annotated imported but unused.
from langgraph.types import StreamWriter

from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain.schema import HumanMessage, SystemMessage
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import (
@@ -81,7 +87,57 @@ async def _classify_query(self, query: str, history: List[HumanMessage]):

        return response.classification

    class State(TypedDict):
        query: str
        project_id: str
        user_id: str
        conversation_id: str
        node_ids: List[NodeContext]

    async def _stream_rag_agent(self, state: State, writer: StreamWriter):
        async for chunk in self.execute(
            state["query"],
            state["project_id"],
            state["user_id"],
            state["conversation_id"],
            state["node_ids"],
        ):
            writer(chunk)

    def _create_graph(self):
        graph_builder = StateGraph(QNAChatAgent.State)

        graph_builder.add_node(
            "rag_agent",
            self._stream_rag_agent,
        )
        graph_builder.add_edge(START, "rag_agent")
        graph_builder.add_edge("rag_agent", END)
        return graph_builder.compile()

    async def run(
        self,
        query: str,
        project_id: str,
        user_id: str,
        conversation_id: str,
        node_ids: List[NodeContext],
    ):
        state = {
            "query": query,
            "project_id": project_id,
            "user_id": user_id,
            "conversation_id": conversation_id,
            "node_ids": node_ids,
        }
        graph = self._create_graph()
        async for chunk in graph.astream(state, stream_mode="custom"):
            yield chunk

    async def execute(
        self,
        query: str,
        project_id: str,
@@ -117,8 +173,7 @@ async def run(
        tool_results = []
        citations = []
        if classification == ClassificationResult.AGENT_REQUIRED:
            async for chunk in kickoff_rag_agent(
                query,
                project_id,
                [
@@ -131,54 +186,28 @@ async def run(
                self.llm,
                self.mini_llm,
                user_id,
            ):
                content = str(chunk)

                self.history_manager.add_message_chunk(
                    conversation_id,
                    content,
                    MessageType.AI_GENERATED,
                    citations=citations,
                )

                yield json.dumps(
                    {
                        "citations": citations,
                        "message": content,
                    }
                )

            self.history_manager.flush_message_buffer(
                conversation_id, MessageType.AI_GENERATED
            )

        if classification != ClassificationResult.AGENT_REQUIRED:
            inputs = {
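The `run`/`_create_graph`/`_stream_rag_agent` trio above follows langgraph's custom streaming pattern: a node declares a `StreamWriter` parameter, pushes chunks through it, and the caller consumes them via `astream(..., stream_mode="custom")`. A stripped-down, self-contained sketch of that pattern, assuming (as this PR's node does) that a node may return no state update:

```python
import asyncio

from typing_extensions import TypedDict

from langgraph.graph import StateGraph, START, END
from langgraph.types import StreamWriter


class State(TypedDict):
    query: str


async def rag_agent(state: State, writer: StreamWriter):
    # Stand-in for the real agent: emit chunks as they become available.
    for token in ["Answer ", "for: ", state["query"]]:
        writer(token)  # each call surfaces immediately in the custom stream


builder = StateGraph(State)
builder.add_node("rag_agent", rag_agent)
builder.add_edge(START, "rag_agent")
builder.add_edge("rag_agent", END)
graph = builder.compile()


async def main():
    async for chunk in graph.astream(
        {"query": "how is streaming wired?"}, stream_mode="custom"
    ):
        print(chunk, end="", flush=True)


asyncio.run(main())
```

langgraph injects the writer based on the `StreamWriter` annotation, so the node itself stays a plain async function; everything passed to `writer(...)` is yielded by `astream` in order.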