Skip to content

Commit

Permalink
Revert to commit d472467: remove langchain (potpie-ai#176)
Browse files Browse the repository at this point in the history
  • Loading branch information
dhirenmathur authored Nov 13, 2024
1 parent c8751f7 commit ba80b71
Show file tree
Hide file tree
Showing 47 changed files with 289 additions and 1,074 deletions.
6 changes: 2 additions & 4 deletions .env.template
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ENV= development
OPENAI_API_KEY=
OPENAI_MODEL_REASONING=
# POSTGRES_SERVER=postgresql://postgres:mysecretpassword@host.docker.internal:5432/momentum #for use with wsgl
# POSTGRES_SERVER=postgresql://postgres:mysecretpassword@host.docker.internal:5432/momentum Use this when using WSL
POSTGRES_SERVER=postgresql://postgres:mysecretpassword@localhost:5432/momentum
MONGO_URI= mongodb://127.0.0.1:27017
MONGODB_DB_NAME= momentum
Expand All @@ -26,6 +26,4 @@ EMAIL_FROM_ADDRESS=
RESEND_API_KEY=
ANTHROPIC_API_KEY=
POSTHOG_API_KEY=
POSTHOG_HOST=
POTPIE_PLUS_BASE_URL=http://localhost:8080
POTPIE_PLUS_HMAC_KEY=123
POSTHOG_HOST=
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("conversations", "visibility")
op.execute("DROP TYPE visibility")
# ### end Alembic commands ###
# ### end Alembic commands ###
2 changes: 0 additions & 2 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from app.modules.intelligence.prompts.prompt_router import router as prompt_router
from app.modules.intelligence.prompts.system_prompt_setup import SystemPromptSetup
from app.modules.intelligence.provider.provider_router import router as provider_router
from app.modules.intelligence.tools.tool_router import router as tool_router
from app.modules.key_management.secret_manager import router as secret_manager_router
from app.modules.parsing.graph_construction.parsing_router import (
router as parsing_router,
Expand Down Expand Up @@ -97,7 +96,6 @@ def include_routers(self):
self.app.include_router(agent_router, prefix="/api/v1", tags=["Agents"])

self.app.include_router(provider_router, prefix="/api/v1", tags=["Providers"])
self.app.include_router(tool_router, prefix="/api/v1", tags=["Tools"])

def add_health_check(self):
@self.app.get("/health", tags=["Health"])
Expand Down
42 changes: 1 addition & 41 deletions app/modules/auth/auth_service.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import hashlib
import hmac
import json
import logging
import os
from typing import Union

from dotenv import load_dotenv
import requests
from fastapi import Depends, HTTPException, Request, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from firebase_admin import auth

load_dotenv(override=True)

class AuthService:
def login(self, email, password):
log_prefix = "AuthService::login:"
Expand Down Expand Up @@ -69,41 +64,6 @@ async def check_auth(
)
res.headers["WWW-Authenticate"] = 'Bearer realm="auth_required"'
return decoded_token

@staticmethod
def generate_hmac_signature(message: str) -> str:
"""Generate HMAC signature for a message string"""
hmac_key = AuthService.get_hmac_secret_key()
if not hmac_key:
raise ValueError("HMAC secret key not configured")
hmac_obj = hmac.new(
key=hmac_key,
msg=message.encode("utf-8"),
digestmod=hashlib.sha256
)
return hmac_obj.hexdigest()

@staticmethod
def verify_hmac_signature(payload_body: Union[str, dict], hmac_signature: str) -> bool:
"""Verify HMAC signature matches the payload"""
hmac_key = AuthService.get_hmac_secret_key()
if not hmac_key:
raise ValueError("HMAC secret key not configured")
payload_str = payload_body if isinstance(payload_body, str) else json.dumps(payload_body, sort_keys=True)
expected_signature = hmac.new(
key=hmac_key,
msg=payload_str.encode("utf-8"),
digestmod=hashlib.sha256
).hexdigest()
return hmac.compare_digest(hmac_signature, expected_signature)

@staticmethod
def get_hmac_secret_key() -> bytes:
"""Get HMAC secret key from environment"""
key = os.getenv("POTPIE_PLUS_HMAC_KEY", "")
if not key:
return b""
return key.encode("utf-8")


auth_handler = AuthService()
58 changes: 25 additions & 33 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@
)
from app.modules.github.github_service import GithubService
from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
)
from app.modules.intelligence.memory.chat_history_service import ChatHistoryService
from app.modules.intelligence.provider.provider_service import ProviderService
from app.modules.projects.projects_service import ProjectService
Expand Down Expand Up @@ -73,7 +70,6 @@ def __init__(
history_manager: ChatHistoryService,
provider_service: ProviderService,
agent_injector_service: AgentInjectorService,
custom_agent_service: CustomAgentsService,
):
self.sql_db = db
self.user_id = user_id
Expand All @@ -82,15 +78,13 @@ def __init__(
self.history_manager = history_manager
self.provider_service = provider_service
self.agent_injector_service = agent_injector_service
self.custom_agent_service = custom_agent_service

@classmethod
def create(cls, db: Session, user_id: str, user_email: str):
project_service = ProjectService(db)
history_manager = ChatHistoryService(db)
provider_service = ProviderService(db, user_id)
agent_injector_service = AgentInjectorService(db, provider_service, user_id)
custom_agent_service = CustomAgentsService()
agent_injector_service = AgentInjectorService(db, provider_service)
return cls(
db,
user_id,
Expand All @@ -99,7 +93,6 @@ def create(cls, db: Session, user_id: str, user_email: str):
history_manager,
provider_service,
agent_injector_service,
custom_agent_service,
)

async def check_conversation_access(
Expand Down Expand Up @@ -142,7 +135,6 @@ async def create_conversation(
) -> tuple[str, str]:
try:
if not self.agent_injector_service.validate_agent_id(
user_id,
conversation.agent_ids[0]
):
raise ConversationServiceError(
Expand Down Expand Up @@ -282,16 +274,25 @@ async def store_message(
)
await self._update_conversation_title(conversation_id, new_title)

project_id = (
repo_id = (
conversation.project_ids[0] if conversation.project_ids else None
)
if not project_id:
if not repo_id:
raise ConversationServiceError(
"No project associated with this conversation"
)

async for chunk in self._generate_and_stream_ai_response(
message.content, conversation_id, user_id, message.node_ids
agent = self.agent_injector_service.get_agent(conversation.agent_ids[0])
if not agent:
raise ConversationServiceError(
f"Invalid agent_id: {conversation.agent_ids[0]}"
)

logger.info(
f"Running agent for repo_id: {repo_id} conversation_id: {conversation_id}"
)
async for chunk in agent.run(
message.content, repo_id, user_id, conversation.id, message.node_ids
):
yield chunk

Expand Down Expand Up @@ -445,32 +446,23 @@ async def _generate_and_stream_ai_response(
raise ConversationNotFoundError(
f"Conversation with id {conversation_id} not found"
)

agent_id = conversation.agent_ids[0]
project_id = conversation.project_ids[0] if conversation.project_ids else None
agent = self.agent_injector_service.get_agent(conversation.agent_ids[0])
if not agent:
raise ConversationServiceError(
f"Invalid agent_id: {conversation.agent_ids[0]}"
)

try:
agent = self.agent_injector_service.get_agent(agent_id)

logger.info(
f"conversation_id: {conversation_id} Running agent {agent_id} with query: {query}"
f"conversation_id: {conversation_id}Running agent {conversation.agent_ids[0]} with query: {query} "
)

if isinstance(agent, CustomAgentsService):
# Custom agent doesn't support streaming, so we'll yield the entire response at once
response = await agent.run(
agent_id, query, project_id, user_id, conversation.id, node_ids
)
yield response
else:
# For other agents that support streaming
async for chunk in agent.run(
query, project_id, user_id, conversation.id, node_ids
):
async for chunk in agent.run(
query, conversation.project_ids[0], user_id, conversation.id, node_ids
):
if chunk:
yield chunk

logger.info(
f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {agent_id}"
f"Generated and streamed AI response for conversation {conversation.id} for user {user_id} using agent {conversation.agent_ids[0]}"
)
except Exception as e:
logger.error(
Expand Down
58 changes: 23 additions & 35 deletions app/modules/intelligence/agents/agent_injector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,62 +3,50 @@

from sqlalchemy.orm import Session

from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import (
CodeChangesChatAgent,
from app.modules.intelligence.agents.chat_agents.code_changes_agent import (
CodeChangesAgent,
)
from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent
from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import (
IntegrationTestChatAgent,
)
from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent
from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent
from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent
from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent
from app.modules.intelligence.agents.custom_agents.custom_agents_service import (
CustomAgentsService,
from app.modules.intelligence.agents.chat_agents.debugging_agent import DebuggingAgent
from app.modules.intelligence.agents.chat_agents.integration_test_agent import (
IntegrationTestAgent,
)
from app.modules.intelligence.agents.chat_agents.lld_agent import LLDAgent
from app.modules.intelligence.agents.chat_agents.qna_agent import QNAAgent
from app.modules.intelligence.agents.chat_agents.unit_test_agent import UnitTestAgent
from app.modules.intelligence.provider.provider_service import ProviderService

logger = logging.getLogger(__name__)


class AgentInjectorService:
def __init__(self, db: Session, provider_service: ProviderService, user_id: str):
def __init__(self, db: Session, provider_service: ProviderService):
self.sql_db = db
self.provider_service = provider_service
self.custom_agent_service = CustomAgentsService()
self.agents = self._initialize_agents()
self.user_id = user_id

def _initialize_agents(self) -> Dict[str, Any]:
mini_llm = self.provider_service.get_small_llm()
reasoning_llm = self.provider_service.get_large_llm()
return {
"debugging_agent": DebuggingChatAgent(mini_llm, reasoning_llm, self.sql_db),
"codebase_qna_agent": QNAChatAgent(mini_llm, reasoning_llm, self.sql_db),
"debugging_agent": DebuggingAgent(mini_llm, reasoning_llm, self.sql_db),
"codebase_qna_agent": QNAAgent(mini_llm, reasoning_llm, self.sql_db),
"unit_test_agent": UnitTestAgent(mini_llm, reasoning_llm, self.sql_db),
"integration_test_agent": IntegrationTestChatAgent(
"integration_test_agent": IntegrationTestAgent(
mini_llm, reasoning_llm, self.sql_db
),
"code_changes_agent": CodeChangesChatAgent(
"code_changes_agent": CodeChangesAgent(
mini_llm, reasoning_llm, self.sql_db
),
"LLD_agent": LLDChatAgent(mini_llm, reasoning_llm, self.sql_db),
"LLD_agent": LLDAgent(mini_llm, reasoning_llm, self.sql_db),
}

def get_agent(self, agent_id: str) -> Any:
if agent_id in self.agents:
return self.agents[agent_id]
else:
reasoning_llm = self.provider_service.get_large_llm()
return CustomAgent(
llm=reasoning_llm,
db=self.sql_db,
agent_id=agent_id,
user_id=self.user_id,
)

def validate_agent_id(self, user_id: str, agent_id: str) -> bool:
return agent_id in self.agents or self.custom_agent_service.validate_agent(
self.sql_db, user_id, agent_id
)
agent = self.agents.get(agent_id)
if not agent:
logger.error(f"Invalid agent_id: {agent_id}")
raise ValueError(f"Invalid agent_id: {agent_id}")
return agent

def validate_agent_id(self, agent_id: str) -> bool:
logger.info(f"Validating agent_id: {agent_id}")
return agent_id in self.agents
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from pydantic import BaseModel, Field

from app.modules.conversations.message.message_schema import NodeContext
from app.modules.intelligence.tools.change_detection.change_detection_tool import (
from app.modules.intelligence.tools.change_detection.change_detection import (
ChangeDetectionResponse,
get_change_detection_tool,
get_blast_radius_tool,
)
from app.modules.intelligence.tools.kg_based_tools.ask_knowledge_graph_queries_tool import (
get_ask_knowledge_graph_queries_tool,
Expand Down Expand Up @@ -103,12 +103,11 @@ async def create_tasks(
expected_output=f"Comprehensive impact analysis of the code changes on the codebase and answers to the users query about them. Ensure that your output ALWAYS follows the structure outlined in the following pydantic model : {self.BlastRadiusAgentResponse.model_json_schema()}",
agent=blast_radius_agent,
tools=[
get_change_detection_tool(self.user_id),
get_blast_radius_tool(self.user_id),
self.get_nodes_from_tags,
self.ask_knowledge_graph_queries,
],
output_pydantic=self.BlastRadiusAgentResponse,
async_execution=True,
)

return analyze_changes_task
Expand All @@ -135,7 +134,7 @@ async def run(
return result


async def kickoff_blast_radius_agent(
async def kickoff_blast_radius_crew(
query: str, project_id: str, node_ids: List[NodeContext], sql_db, user_id, llm
) -> Dict[str, str]:
blast_radius_agent = BlastRadiusAgent(sql_db, user_id, llm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class RAGResponse(BaseModel):
response: List[NodeResponse]


class DebugRAGAgent:
class DebugAgent:
def __init__(self, sql_db, llm, mini_llm, user_id):
self.openai_api_key = os.getenv("OPENAI_API_KEY")
self.max_iter = os.getenv("MAX_ITER", 5)
Expand Down Expand Up @@ -225,7 +225,7 @@ async def run(
return result


async def kickoff_debug_rag_agent(
async def kickoff_debug_crew(
query: str,
project_id: str,
chat_history: List,
Expand All @@ -235,7 +235,7 @@ async def kickoff_debug_rag_agent(
mini_llm,
user_id: str,
) -> str:
debug_agent = DebugRAGAgent(sql_db, llm, mini_llm, user_id)
debug_agent = DebugAgent(sql_db, llm, mini_llm, user_id)
file_structure = await GithubService(sql_db).get_project_structure_async(project_id)
result = await debug_agent.run(
query, project_id, chat_history, node_ids, file_structure
Expand Down
Loading

0 comments on commit ba80b71

Please sign in to comment.