Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvement: Centralised prompt management #207

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ cli/dist
cli/.venv
venv/
.env
*.json
.momentum
db
cli/momentum_cli/.momentum
Expand All @@ -22,3 +21,4 @@ cli/momentum_cli/.momentum
projects/
# Ignore PyCharm config
.idea
service-account.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
Expand All @@ -19,7 +18,8 @@


def upgrade() -> None:
    """Apply the migration: add a nullable ``repo_path`` column to ``projects``.

    Raw SQL is used (rather than ``op.add_column``) — keep as-is to match the
    original migration's behavior.
    """
    # NOTE(review): the diff residue repeated this statement; running the
    # ALTER TABLE twice would fail with a duplicate-column error, so it is
    # executed exactly once here.
    op.execute("ALTER TABLE projects ADD COLUMN repo_path TEXT DEFAULT NULL")


def downgrade() -> None:
    """Revert the migration by dropping the ``projects.repo_path`` column."""
    op.drop_column("projects", "repo_path")
8 changes: 6 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
router as conversations_router,
)
from app.modules.intelligence.agents.agents_router import router as agent_router
from app.modules.intelligence.llm_provider.llm_provider_router import (
router as llm_provider_router,
)
from app.modules.intelligence.prompts.prompt_router import router as prompt_router
from app.modules.intelligence.prompts.system_prompt_setup import SystemPromptSetup
from app.modules.intelligence.provider.provider_router import router as provider_router
from app.modules.intelligence.tools.tool_router import router as tool_router
from app.modules.key_management.secret_manager import router as secret_manager_router
from app.modules.parsing.graph_construction.parsing_router import (
Expand Down Expand Up @@ -110,7 +112,9 @@ def include_routers(self):
self.app.include_router(search_router, prefix="/api/v1", tags=["Search"])
self.app.include_router(github_router, prefix="/api/v1", tags=["Github"])
self.app.include_router(agent_router, prefix="/api/v1", tags=["Agents"])
self.app.include_router(provider_router, prefix="/api/v1", tags=["Providers"])
self.app.include_router(
llm_provider_router, prefix="/api/v1", tags=["Providers"]
)
self.app.include_router(tool_router, prefix="/api/v1", tags=["Tools"])
if os.getenv("isDevelopmentMode") != "enabled":
self.app.include_router(
Expand Down
5 changes: 2 additions & 3 deletions app/modules/code_provider/github/github_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,18 @@ async def get_user_repos(
):
user_repo_list = await GithubController(db).get_user_repos(user=user)
user_repo_list["repositories"].extend(config_provider.get_demo_repo_list())

# Remove duplicates while preserving order
seen = set()
deduped_repos = []
for repo in reversed(user_repo_list["repositories"]):
# Create tuple of values to use as hash key
repo_key = repo["full_name"]


if repo_key not in seen:
seen.add(repo_key)
deduped_repos.append(repo)

user_repo_list["repositories"] = deduped_repos
return user_repo_list

Expand Down
24 changes: 12 additions & 12 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
from app.modules.intelligence.memory.chat_history_service import ChatHistoryService
from app.modules.intelligence.provider.provider_service import (
AgentType,
ProviderService,
from app.modules.intelligence.llm_provider.llm_provider_service import (
LLMProviderService,
)
from app.modules.intelligence.memory.chat_history_service import ChatHistoryService
from app.modules.intelligence.prompts_provider.agent_types import AgentLLMType
from app.modules.projects.projects_service import ProjectService
from app.modules.users.user_service import UserService
from app.modules.utils.posthog_helper import PostHogClient
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(
user_email: str,
project_service: ProjectService,
history_manager: ChatHistoryService,
provider_service: ProviderService,
llm_provider_service: LLMProviderService,
agent_injector_service: AgentInjectorService,
custom_agent_service: CustomAgentsService,
):
Expand All @@ -83,24 +83,24 @@ def __init__(
self.user_email = user_email
self.project_service = project_service
self.history_manager = history_manager
self.provider_service = provider_service
self.llm_provider_service = llm_provider_service
self.agent_injector_service = agent_injector_service
self.custom_agent_service = custom_agent_service

@classmethod
def create(cls, db: Session, user_id: str, user_email: str):
    """Factory: build a fully wired ``ConversationService``.

    Constructs each collaborator (project, chat-history, LLM-provider,
    agent-injector and custom-agent services) from the given DB session and
    user identity, then delegates to ``__init__``.

    Args:
        db: active SQLAlchemy session shared by the collaborators.
        user_id: id of the user the service acts on behalf of.
        user_email: email of that user.

    Returns:
        A new ``ConversationService`` instance.
    """
    project_service = ProjectService(db)
    history_manager = ChatHistoryService(db)
    llm_provider_service = LLMProviderService(db, user_id)
    # The injector needs the LLM provider service to build its agents.
    agent_injector_service = AgentInjectorService(db, llm_provider_service, user_id)
    custom_agent_service = CustomAgentsService()
    # NOTE(review): the scraped diff contained both the pre-PR
    # ``ProviderService`` lines and the post-PR ``LLMProviderService`` lines
    # (and passed both to ``cls``, a 9-arg call to an 8-param __init__).
    # Only the post-PR wiring is kept.
    return cls(
        db,
        user_id,
        user_email,
        project_service,
        history_manager,
        llm_provider_service,
        agent_injector_service,
        custom_agent_service,
    )
Expand Down Expand Up @@ -206,7 +206,7 @@ def _create_conversation_record(
logger.info(
f"Project id : {conversation.project_ids[0]} Created new conversation with ID: {conversation_id}, title: {title}, user_id: {user_id}, agent_id: {conversation.agent_ids[0]}"
)
provider_name = self.provider_service.get_llm_provider_name()
provider_name = self.llm_provider_service.get_llm_provider_name()
PostHogClient().send_event(
user_id,
"create Conversation Event",
Expand Down Expand Up @@ -261,7 +261,7 @@ async def store_message(
conversation_id, message_type, user_id
)
logger.info(f"Stored message in conversation {conversation_id}")
provider_name = self.provider_service.get_llm_provider_name()
provider_name = self.llm_provider_service.get_llm_provider_name()

PostHogClient().send_event(
user_id,
Expand Down Expand Up @@ -335,7 +335,7 @@ async def _generate_title(
) -> str:
agent_type = conversation.agent_ids[0]

llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN)
llm = self.llm_provider_service.get_small_llm(agent_type=AgentLLMType.LANGCHAIN)
prompt = ChatPromptTemplate.from_template(
"Given an agent type '{agent_type}' and an initial message '{message}', "
"generate a concise and relevant title for a conversation. "
Expand Down
16 changes: 9 additions & 7 deletions app/modules/intelligence/agents/agent_injector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,28 @@
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
from app.modules.intelligence.provider.provider_service import (
AgentType,
ProviderService,
from app.modules.intelligence.llm_provider.llm_provider_service import (
LLMProviderService,
)
from app.modules.intelligence.prompts_provider.agent_types import AgentLLMType

logger = logging.getLogger(__name__)


class AgentInjectorService:
def __init__(self, db: Session, provider_service: LLMProviderService, user_id: str):
    """Wire the injector with its DB session, LLM provider and user id.

    Agents are eagerly built once here (``self.agents``) and reused by
    ``get_agent``.

    Args:
        db: active SQLAlchemy session handed to each constructed agent.
        provider_service: LLM provider used to obtain small/large LLMs.
        user_id: id of the requesting user (used for custom agents).
    """
    # NOTE(review): the scraped diff kept the removed pre-PR ``def __init__``
    # signature directly above the new one — a def with no body. Only the
    # post-PR signature survives here.
    self.sql_db = db
    self.provider_service = provider_service
    self.custom_agent_service = CustomAgentsService()
    # Built after db/provider are assigned: _initialize_agents reads them.
    self.agents = self._initialize_agents()
    self.user_id = user_id

def _initialize_agents(self) -> Dict[str, Any]:
mini_llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN)
mini_llm = self.provider_service.get_small_llm(
agent_type=AgentLLMType.LANGCHAIN
)
reasoning_llm = self.provider_service.get_large_llm(
agent_type=AgentType.LANGCHAIN
agent_type=AgentLLMType.LANGCHAIN
)
return {
"debugging_agent": DebuggingChatAgent(mini_llm, reasoning_llm, self.sql_db),
Expand All @@ -66,7 +68,7 @@ def get_agent(self, agent_id: str) -> Any:
return self.agents[agent_id]
else:
reasoning_llm = self.provider_service.get_large_llm(
agent_type=AgentType.LANGCHAIN
agent_type=AgentLLMType.LANGCHAIN
)
return CustomAgent(
llm=reasoning_llm,
Expand Down
73 changes: 22 additions & 51 deletions app/modules/intelligence/agents/agents/blast_radius_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from pydantic import BaseModel, Field

from app.modules.conversations.message.message_schema import NodeContext
from app.modules.intelligence.provider.provider_service import (
AgentType,
ProviderService,
from app.modules.intelligence.llm_provider.llm_provider_service import (
LLMProviderService,
)
from app.modules.intelligence.prompts_provider.agent_prompts import AgentPromptsProvider
from app.modules.intelligence.prompts_provider.agent_types import AgentLLMType
from app.modules.intelligence.tools.change_detection.change_detection_tool import (
ChangeDetectionResponse,
get_change_detection_tool,
Expand All @@ -33,10 +34,13 @@ def __init__(self, sql_db, user_id, llm):
)

async def create_agents(self):
agent_prompt = AgentPromptsProvider.get_agent_prompt(
agent_id="blast_radius_agent", agent_type=AgentLLMType.CREWAI
)
blast_radius_agent = Agent(
role="Blast Radius Agent",
goal="Explain the blast radius of the changes made in the code.",
backstory="You are an expert in understanding the impact of code changes on the codebase.",
role=agent_prompt["role"],
goal=agent_prompt["goal"],
backstory=agent_prompt["backstory"],
allow_delegation=False,
verbose=True,
llm=self.llm,
Expand All @@ -60,50 +64,17 @@ async def create_tasks(
query: str,
blast_radius_agent,
):
task_prompt = AgentPromptsProvider.get_task_prompt(
task_id="analyze_changes_task",
agent_type=AgentLLMType.CREWAI,
project_id=project_id,
query=query,
ChangeDetectionResponse=ChangeDetectionResponse,
BlastRadiusAgentResponse=self.BlastRadiusAgentResponse,
)

analyze_changes_task = Task(
description=f"""Fetch the changes in the current branch for project {project_id} using the get code changes tool.
The response of the fetch changes tool is in the following format:
{ChangeDetectionResponse.model_json_schema()}
In the response, the patches contain the file patches for the changes.
The changes contain the list of changes with the updated and entry point code. Entry point corresponds to the API/Consumer upstream of the function that the change was made in.
The citations contain the list of file names referenced in the changed code and entry point code.

You also have access the the query knowledge graph tool to answer natural language questions about the codebase during the analysis.
Based on the response from the get code changes tool, formulate queries to ask details about specific changed code elements.
1. Frame your query for the knowledge graph tool:
- Identify key concepts, code elements, and implied relationships from the changed code.
- Consider the context from the users query: {query}.
- Determine the intent and key technical terms.
- Transform into keyword phrases that might match docstrings:
* Use concise, functionality-based phrases (e.g., "creates document MongoDB collection").
* Focus on verb-based keywords (e.g., "create", "define", "calculate").
* Include docstring-related keywords like "parameters", "returns", "raises" when relevant.
* Preserve key technical terms from the original query.
* Generate multiple keyword variations to increase matching chances.
* Be specific in keywords to improve match accuracy.
* Ensure the query includes relevant details and follows a similar structure to enhance similarity search results.

2. Execute your formulated query using the knowledge graph tool.

Analyze the changes fetched and explain their impact on the codebase. Consider the following:
1. Which functions or classes have been directly modified?
2. What are the potential side effects of these changes?
3. Are there any dependencies that might be affected?
4. How might these changes impact the overall system behavior?
5. Based on the entry point code, determine which APIs or consumers etc are impacted by the changes.

Refer to the {query} for any specific instructions and follow them.

Based on the analysis, provide a structured inference of the blast radius:
1. Summarize the direct changes
2. List potential indirect effects
3. Identify any critical areas that require careful testing
4. Suggest any necessary refactoring or additional changes to mitigate risks
6. If the changes are impacting multiple APIs/Consumers, then say so.


Ensure that your output ALWAYS follows the structure outlined in the following pydantic model:
{self.BlastRadiusAgentResponse.model_json_schema()}""",
description=task_prompt,
expected_output=f"Comprehensive impact analysis of the code changes on the codebase and answers to the users query about them. Ensure that your output ALWAYS follows the structure outlined in the following pydantic model : {self.BlastRadiusAgentResponse.model_json_schema()}",
agent=blast_radius_agent,
tools=[
Expand Down Expand Up @@ -140,8 +111,8 @@ async def run(
async def kickoff_blast_radius_agent(
    query: str, project_id: str, node_ids: List[NodeContext], sql_db, user_id, llm
) -> Dict[str, str]:
    """Entry point: run a blast-radius analysis for the project's changes.

    Builds a small CrewAI LLM from the user's provider settings, runs the
    ``BlastRadiusAgent`` over the given nodes/query, and returns its result.

    Args:
        query: the user's natural-language question about the changes.
        project_id: project whose current-branch changes are analyzed.
        node_ids: code nodes giving the agent extra context.
        sql_db: database session passed through to the agent.
        user_id: id of the requesting user.
        llm: unused here; kept for interface compatibility with callers.
    """
    # NOTE(review): the scraped diff contained both the removed
    # ``ProviderService``/``AgentType`` pair and the added
    # ``LLMProviderService``/``AgentLLMType`` pair; only the latter is kept.
    provider_service = LLMProviderService(sql_db, user_id)
    crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentLLMType.CREWAI)
    blast_radius_agent = BlastRadiusAgent(sql_db, user_id, crew_ai_mini_llm)
    result = await blast_radius_agent.run(project_id, node_ids, query)
    return result
Loading
Loading