Skip to content

Commit

Permalink
Add support for Ollama
Browse files Browse the repository at this point in the history
Related to potpie-ai#188

Add support for Ollama to enable users to run open source models locally.

* **Provider Service Integration**
  - Add Ollama API integration in `app/modules/intelligence/provider/provider_service.py`
  - Implement method to get Ollama LLM
  - Update `list_available_llms` method to include Ollama

* **Configuration Options**
  - Add configuration options for Ollama endpoint and model selection in `app/core/config_provider.py`
  - Update `ConfigProvider` class to include Ollama settings

* **Agent Factory and Injector Service**
  - Add support for Ollama models in `app/modules/intelligence/agents/agent_factory.py`
  - Implement method to create Ollama agent
  - Add support for Ollama models in `app/modules/intelligence/agents/agent_injector_service.py`
  - Implement method to get Ollama agent

* **Tool Service**
  - Add tools for Ollama model support in `app/modules/intelligence/tools/tool_service.py`
  - Implement methods to interact with Ollama models
  • Loading branch information
vishwamartur committed Jan 5, 2025
1 parent 55eb585 commit 6ecad88
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
8 changes: 8 additions & 0 deletions app/core/config_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,21 @@ def __init__(self):
"password": os.getenv("NEO4J_PASSWORD"),
}
self.github_key = os.getenv("GITHUB_PRIVATE_KEY")
self.ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
self.ollama_model = os.getenv("OLLAMA_MODEL", "llama2")

def get_neo4j_config(self):
    """Return the Neo4j connection settings captured at construction time."""
    neo4j_settings = self.neo4j_config
    return neo4j_settings

def get_github_key(self):
    """Return the GitHub private key loaded from the environment at init."""
    private_key = self.github_key
    return private_key

def get_ollama_config(self):
    """Return the Ollama settings (endpoint URL and model name) as a dict."""
    return dict(endpoint=self.ollama_endpoint, model=self.ollama_model)

def get_demo_repo_list(self):
return [
{
Expand Down
5 changes: 5 additions & 0 deletions app/modules/intelligence/agents/agent_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
AgentType,
ProviderService,
)
from langchain_ollama import Ollama


class AgentFactory:
Expand Down Expand Up @@ -70,6 +71,10 @@ def _create_agent(
"code_generation_agent": lambda: CodeGenerationChatAgent(
mini_llm, reasoning_llm, self.db
),
"ollama_agent": lambda: Ollama(
base_url=self.provider_service.get_ollama_endpoint(),
model=self.provider_service.get_ollama_model(),
),
}

if agent_id in agent_map:
Expand Down
5 changes: 5 additions & 0 deletions app/modules/intelligence/agents/agent_injector_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
AgentType,
ProviderService,
)
from langchain_ollama import Ollama

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -59,6 +60,10 @@ def _initialize_agents(self) -> Dict[str, Any]:
"code_generation_agent": CodeGenerationChatAgent(
mini_llm, reasoning_llm, self.sql_db
),
"ollama_agent": Ollama(
base_url=self.provider_service.get_ollama_endpoint(),
model=self.provider_service.get_ollama_model(),
),
}

def get_agent(self, agent_id: str) -> Any:
Expand Down
20 changes: 20 additions & 0 deletions app/modules/intelligence/provider/provider_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from langchain_anthropic import ChatAnthropic
from langchain_openai.chat_models import ChatOpenAI
from portkey_ai import PORTKEY_GATEWAY_URL, createHeaders
from langchain_ollama import Ollama

from app.modules.key_management.secret_manager import SecretManager
from app.modules.users.user_preferences_model import UserPreferences
Expand Down Expand Up @@ -44,6 +45,11 @@ async def list_available_llms(self) -> List[ProviderInfo]:
name="Anthropic",
description="An AI safety-focused company known for models like Claude.",
),
ProviderInfo(
id="ollama",
name="Ollama",
description="A provider for running open source models locally.",
),
]

async def set_global_ai_provider(self, user_id: str, provider: str):
Expand Down Expand Up @@ -195,6 +201,12 @@ def get_large_llm(self, agent_type: AgentType):
default_headers=portkey_headers,
)

elif preferred_provider == "ollama":
logging.info("Initializing Ollama LLM")
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
ollama_model = os.getenv("OLLAMA_MODEL", "llama2")
self.llm = Ollama(base_url=ollama_endpoint, model=ollama_model)

else:
raise ValueError("Invalid LLM provider selected.")

Expand Down Expand Up @@ -323,6 +335,12 @@ def get_small_llm(self, agent_type: AgentType):
default_headers=portkey_headers,
)

elif preferred_provider == "ollama":
logging.info("Initializing Ollama LLM")
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434")
ollama_model = os.getenv("OLLAMA_MODEL", "llama2")
self.llm = Ollama(base_url=ollama_endpoint, model=ollama_model)

else:
raise ValueError("Invalid LLM provider selected.")

Expand All @@ -337,6 +355,8 @@ def get_llm_provider_name(self) -> str:
return "OpenAI"
elif isinstance(llm, ChatAnthropic):
return "Anthropic"
elif isinstance(llm, Ollama):
return "Ollama"
elif isinstance(llm, LLM):
return "OpenAI" if llm.model.split("/")[0] == "openai" else "Anthropic"
else:
Expand Down
11 changes: 11 additions & 0 deletions app/modules/intelligence/tools/tool_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
GetNodesFromTags,
)
from app.modules.intelligence.tools.tool_schema import ToolInfo
from langchain_ollama import Ollama


class ToolService:
Expand Down Expand Up @@ -65,8 +66,18 @@ def _initialize_tools(self) -> Dict[str, Any]:
"get_node_neighbours_from_node_id": GetNodeNeighboursFromNodeIdTool(
self.db
),
"ollama_tool": Ollama(
base_url=self._get_ollama_endpoint(),
model=self._get_ollama_model(),
),
}

def _get_ollama_endpoint(self) -> str:
return self.db.query(ConfigProvider).first().get_ollama_config()["endpoint"]

def _get_ollama_model(self) -> str:
return self.db.query(ConfigProvider).first().get_ollama_config()["model"]

async def run_tool(self, tool_id: str, params: Dict[str, Any]) -> Dict[str, Any]:
tool = self.tools.get(tool_id)
if not tool:
Expand Down

0 comments on commit 6ecad88

Please sign in to comment.