Skip to content

Commit

Permalink
Merge pull request #236 from potpie-ai/deepseek
Browse files Browse the repository at this point in the history
Deepseek R1 integration for Agent Interactions
  • Loading branch information
dhirenmathur authored Jan 30, 2025
2 parents c4e5a83 + 24d5f44 commit 27c6b5b
Show file tree
Hide file tree
Showing 22 changed files with 395 additions and 375 deletions.
2 changes: 1 addition & 1 deletion GETTING_STARTED.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Potpie uses Google Secret Manager to securely manage API keys. If you created a
Follow these steps to set up the Secret Manager and Application Default Credentials (ADC) for Potpie:
1. Install gcloud CLI. Follow the official installation guide:
https://cloud.google.com/sdk/docs/install

After installation, initialize gcloud CLI:
```bash
gcloud init
Expand Down
45 changes: 25 additions & 20 deletions app/api/router.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,36 @@
from fastapi import Depends, HTTPException, Header
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from typing import Optional, List
from datetime import datetime
from typing import List, Optional

from fastapi import Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from app.core.database import get_db
from app.modules.auth.api_key_service import APIKeyService
from app.modules.conversations.conversation.conversation_controller import ConversationController
from app.modules.conversations.message.message_schema import MessageRequest
from app.modules.parsing.graph_construction.parsing_controller import ParsingController
from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest
from app.modules.utils.APIRouter import APIRouter
from app.modules.conversations.conversation.conversation_controller import (
ConversationController,
)
from app.modules.conversations.conversation.conversation_schema import (
ConversationStatus,
CreateConversationRequest,
CreateConversationResponse,
ConversationStatus,
)
from app.modules.conversations.message.message_schema import MessageRequest
from app.modules.parsing.graph_construction.parsing_controller import ParsingController
from app.modules.parsing.graph_construction.parsing_schema import ParsingRequest
from app.modules.utils.APIRouter import APIRouter

router = APIRouter()


class SimpleConversationRequest(BaseModel):
project_ids: List[str]
agent_ids: List[str]


async def get_api_key_user(
x_api_key: Optional[str] = Header(None),
db: Session = Depends(get_db)
x_api_key: Optional[str] = Header(None), db: Session = Depends(get_db)
) -> dict:
"""Dependency to validate API key and get user info."""
if not x_api_key:
Expand All @@ -35,17 +39,18 @@ async def get_api_key_user(
detail="API key is required",
headers={"WWW-Authenticate": "ApiKey"},
)

user = await APIKeyService.validate_api_key(x_api_key, db)
if not user:
raise HTTPException(
status_code=401,
detail="Invalid API key",
headers={"WWW-Authenticate": "ApiKey"},
)

return user


@router.post("/conversations/", response_model=CreateConversationResponse)
async def create_conversation(
conversation: SimpleConversationRequest,
Expand All @@ -59,12 +64,13 @@ async def create_conversation(
title=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
status=ConversationStatus.ARCHIVED,
project_ids=conversation.project_ids,
agent_ids=conversation.agent_ids
agent_ids=conversation.agent_ids,
)

controller = ConversationController(db, user_id, None)
return await controller.create_conversation(full_request)


@router.post("/parse")
async def parse_directory(
repo_details: ParsingRequest,
Expand All @@ -73,6 +79,7 @@ async def parse_directory(
):
return await ParsingController.parse_directory(repo_details, db, user)


@router.get("/parsing-status/{project_id}")
async def get_parsing_status(
project_id: str,
Expand All @@ -81,6 +88,7 @@ async def get_parsing_status(
):
return await ParsingController.fetch_parsing_status(project_id, db, user)


@router.post("/conversations/{conversation_id}/message/")
async def post_message(
conversation_id: str,
Expand All @@ -89,10 +97,7 @@ async def post_message(
user=Depends(get_api_key_user),
):
if message.content == "" or message.content is None or message.content.isspace():
raise HTTPException(
status_code=400,
detail="Message content cannot be empty"
)
raise HTTPException(status_code=400, detail="Message content cannot be empty")

user_id = user["user_id"]
# Note: email is no longer available with API key auth
Expand Down
2 changes: 1 addition & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.api.router import router as potpie_api_router
from app.core.base_model import Base
from app.core.database import SessionLocal, engine
from app.core.models import * # noqa #necessary for models to not give import errors
Expand All @@ -21,7 +22,6 @@
from app.modules.intelligence.provider.provider_router import router as provider_router
from app.modules.intelligence.tools.tool_router import router as tool_router
from app.modules.key_management.secret_manager import router as secret_manager_router
from app.api.router import router as potpie_api_router
from app.modules.parsing.graph_construction.parsing_router import (
router as parsing_router,
)
Expand Down
53 changes: 31 additions & 22 deletions app/modules/auth/api_key_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import hashlib
import os
import secrets
import hashlib
from typing import Optional

from fastapi import HTTPException
from sqlalchemy.orm import Session
from google.cloud import secretmanager
from app.modules.users.user_preferences_model import UserPreferences
from sqlalchemy import text
from sqlalchemy.orm import Session

from app.modules.users.user_preferences_model import UserPreferences


class APIKeyService:
SECRET_PREFIX = "sk-"
Expand All @@ -22,8 +25,7 @@ def get_client_and_project():
project_id = os.environ.get("GCP_PROJECT")
if not project_id:
raise HTTPException(
status_code=500,
detail="GCP_PROJECT environment variable is not set"
status_code=500, detail="GCP_PROJECT environment variable is not set"
)

try:
Expand All @@ -32,7 +34,7 @@ def get_client_and_project():
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to initialize Secret Manager client: {str(e)}"
detail=f"Failed to initialize Secret Manager client: {str(e)}",
)

@staticmethod
Expand All @@ -53,12 +55,14 @@ async def create_api_key(user_id: str, db: Session) -> str:
hashed_key = APIKeyService.hash_api_key(api_key)

# Store hashed key in user preferences
user_pref = db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
if not user_pref :
user_pref = (
db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
)
if not user_pref:
user_pref = UserPreferences(user_id=user_id, preferences={})
db.add(user_pref)
if "api_key_hash" not in user_pref.preferences:
pref = user_pref.preferences.copy()
pref = user_pref.preferences.copy()
pref["api_key_hash"] = hashed_key
user_pref.preferences = pref
db.commit()
Expand Down Expand Up @@ -86,7 +90,9 @@ async def create_api_key(user_id: str, db: Session) -> str:
if "api_key_hash" in user_pref.preferences:
del user_pref.preferences["api_key_hash"]
db.commit()
raise HTTPException(status_code=500, detail=f"Failed to store API key: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to store API key: {str(e)}"
)

return api_key

Expand All @@ -97,24 +103,26 @@ async def validate_api_key(api_key: str, db: Session) -> Optional[dict]:
return None

hashed_key = APIKeyService.hash_api_key(api_key)

# Find user with matching hashed key using PostgreSQL JSONB operator
user_pref = db.query(UserPreferences).filter(
text("preferences->>'api_key_hash' = :hashed_key")
).params(hashed_key=hashed_key).first()
user_pref = (
db.query(UserPreferences)
.filter(text("preferences->>'api_key_hash' = :hashed_key"))
.params(hashed_key=hashed_key)
.first()
)

if not user_pref:
return None

return {
"user_id": user_pref.user_id,
"auth_type": "api_key"
}
return {"user_id": user_pref.user_id, "auth_type": "api_key"}

@staticmethod
async def revoke_api_key(user_id: str, db: Session) -> bool:
"""Revoke a user's API key."""
user_pref = db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
user_pref = (
db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
)
if not user_pref:
return False

Expand All @@ -141,7 +149,9 @@ async def revoke_api_key(user_id: str, db: Session) -> bool:
@staticmethod
async def get_api_key(user_id: str, db: Session) -> Optional[str]:
"""Retrieve the existing API key for a user."""
user_pref = db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
user_pref = (
db.query(UserPreferences).filter(UserPreferences.user_id == user_id).first()
)
if not user_pref or "api_key_hash" not in user_pref.preferences:
return None

Expand All @@ -157,6 +167,5 @@ async def get_api_key(user_id: str, db: Session) -> Optional[str]:
return response.payload.data.decode("UTF-8")
except Exception as e:
raise HTTPException(
status_code=500,
detail=f"Failed to retrieve API key: {str(e)}"
status_code=500, detail=f"Failed to retrieve API key: {str(e)}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ async def post_message(
raise HTTPException(status_code=500, detail=str(e))

async def regenerate_last_message(
self, conversation_id: str, node_ids: List[NodeContext] = [], stream: bool = True
self,
conversation_id: str,
node_ids: List[NodeContext] = [],
stream: bool = True,
) -> AsyncGenerator[str, None]:
try:
async for chunk in self.service.regenerate_last_message(
Expand Down
27 changes: 18 additions & 9 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,24 @@ async def classifier_node(self, state: State) -> Command:
"""Classifies the query and routes to appropriate agent"""
if not state.get("query"):
return Command(update={"response": "No query provided"}, goto=END)
agent_list = {agent.id:agent.status for agent in self.available_agents}
agent_list = {agent.id: agent.status for agent in self.available_agents}

# Do not route for custom agents
if (
state["agent_id"] in agent_list
and agent_list[state["agent_id"]] != "SYSTEM"
):
return Command(
update={"agent_id": state["agent_id"]}, goto=state["agent_id"]
)

#Do not route for custom agents
if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM":
return Command(update={"agent_id": state["agent_id"]}, goto=state["agent_id"])

# Classification using LLM with enhanced prompt
prompt = self.classifier_prompt.format(
query=state["query"],
agent_id=state["agent_id"],
agent_descriptions=self.agent_descriptions,
)

response = await self.llm.ainvoke(prompt)
response = response.content.strip("`")
try:
Expand All @@ -181,8 +186,8 @@ async def classifier_node(self, state: State) -> Command:
update={"agent_id": state["agent_id"]}, goto=state["agent_id"]
)
logger.info(
f"Streaming AI response for conversation {state['conversation_id']} for user {state['user_id']} using agent {agent_id}"
)
f"Streaming AI response for conversation {state['conversation_id']} for user {state['user_id']} using agent {agent_id}"
)
return Command(update={"agent_id": agent_id}, goto=agent_id)

async def agent_node(self, state: State, writer: StreamWriter):
Expand Down Expand Up @@ -550,7 +555,11 @@ async def _update_conversation_title(self, conversation_id: str, new_title: str)
self.sql_db.commit()

async def regenerate_last_message(
self, conversation_id: str, user_id: str, node_ids: List[NodeContext] = [], stream: bool = True
self,
conversation_id: str,
user_id: str,
node_ids: List[NodeContext] = [],
stream: bool = True,
) -> AsyncGenerator[str, None]:
try:
access_level = await self.check_conversation_access(
Expand Down
4 changes: 3 additions & 1 deletion app/modules/conversations/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ async def regenerate_last_message(
user_id = user["user_id"]
user_email = user["email"]
controller = ConversationController(db, user_id, user_email)
message_stream = controller.regenerate_last_message(conversation_id, request.node_ids, stream)
message_stream = controller.regenerate_last_message(
conversation_id, request.node_ids, stream
)
if stream:
return StreamingResponse(message_stream, media_type="text/event-stream")
else:
Expand Down
5 changes: 1 addition & 4 deletions app/modules/intelligence/agents/agents/unit_test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ async def create_agents(self):
role="Test Plan and Unit Test Expert",
goal="Create test plans and write unit tests based on user requirements",
backstory="You are a seasoned AI test engineer specializing in creating robust test plans and unit tests. You aim to assist users effectively in generating and refining test plans and unit tests, ensuring they are comprehensive and tailored to the user's project requirements.",
tools=[
self.get_code_from_probable_node_name,
self.get_code_from_node_id
],
tools=[self.get_code_from_probable_node_name, self.get_code_from_node_id],
allow_delegation=False,
verbose=True,
llm=self.llm,
Expand Down
9 changes: 9 additions & 0 deletions app/modules/intelligence/provider/provider_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ async def set_global_ai_provider(
status_code=500, detail=f"Error setting AI provider: {str(e)}"
)

async def get_global_ai_provider(self, user_id: str):
try:
provider = await self.service.get_global_ai_provider(user_id)
return {"provider": provider}
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Error getting AI provider: {str(e)}"
)

async def get_preferred_llm(self, user_id: str) -> GetProviderResponse:
try:
preferred_llm, model_type = await self.service.get_preferred_llm(user_id)
Expand Down
10 changes: 10 additions & 0 deletions app/modules/intelligence/provider/provider_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ async def set_global_ai_provider(
user["user_id"], provider_request
)

@staticmethod
@router.get("/get-global-ai-provider/")
async def get_global_ai_provider(
db: Session = Depends(get_db),
user=Depends(AuthService.check_auth),
):
user_id = user["user_id"]
controller = ProviderController(db, user_id)
return await controller.get_global_ai_provider(user_id)

@staticmethod
@router.get("/get-preferred-llm/", response_model=GetProviderResponse)
async def get_preferred_llm(
Expand Down
Loading

0 comments on commit 27c6b5b

Please sign in to comment.