From 6ecad88141d2e2e9928bc6041a68a875f9296795 Mon Sep 17 00:00:00 2001 From: Vishwanath Martur <64204611+vishwamartur@users.noreply.github.com> Date: Sun, 5 Jan 2025 16:32:01 +0530 Subject: [PATCH] Add support for Ollama Related to #188 Add support for Ollama to enable users to run open source models locally. * **Provider Service Integration** - Add Ollama API integration in `app/modules/intelligence/provider/provider_service.py` - Implement method to get Ollama LLM - Update `list_available_llms` method to include Ollama * **Configuration Options** - Add configuration options for Ollama endpoint and model selection in `app/core/config_provider.py` - Update `ConfigProvider` class to include Ollama settings * **Agent Factory and Injector Service** - Add support for Ollama models in `app/modules/intelligence/agents/agent_factory.py` - Implement method to create Ollama agent - Add support for Ollama models in `app/modules/intelligence/agents/agent_injector_service.py` - Implement method to get Ollama agent * **Tool Service** - Add tools for Ollama model support in `app/modules/intelligence/tools/tool_service.py` - Implement methods to interact with Ollama models --- app/core/config_provider.py | 8 ++++++++ .../intelligence/agents/agent_factory.py | 5 +++++ .../agents/agent_injector_service.py | 5 +++++ .../intelligence/provider/provider_service.py | 20 +++++++++++++++++++ .../intelligence/tools/tool_service.py | 11 ++++++++++ 5 files changed, 49 insertions(+) diff --git a/app/core/config_provider.py b/app/core/config_provider.py index 52d1b255..0d1a08c2 100644 --- a/app/core/config_provider.py +++ b/app/core/config_provider.py @@ -13,6 +13,8 @@ def __init__(self): "password": os.getenv("NEO4J_PASSWORD"), } self.github_key = os.getenv("GITHUB_PRIVATE_KEY") + self.ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") + self.ollama_model = os.getenv("OLLAMA_MODEL", "llama2") def get_neo4j_config(self): return self.neo4j_config @@ -20,6 +22,12 @@ def 
get_neo4j_config(self): def get_github_key(self): return self.github_key + def get_ollama_config(self): + return { + "endpoint": self.ollama_endpoint, + "model": self.ollama_model, + } + def get_demo_repo_list(self): return [ { diff --git a/app/modules/intelligence/agents/agent_factory.py b/app/modules/intelligence/agents/agent_factory.py index a6f9a606..c32aeff9 100644 --- a/app/modules/intelligence/agents/agent_factory.py +++ b/app/modules/intelligence/agents/agent_factory.py @@ -24,6 +24,8 @@ AgentType, ProviderService, ) +from app.core.config_provider import ConfigProvider +from langchain_ollama import OllamaLLM class AgentFactory: @@ -70,6 +72,10 @@ def _create_agent( "code_generation_agent": lambda: CodeGenerationChatAgent( mini_llm, reasoning_llm, self.db ), + "ollama_agent": lambda: OllamaLLM( + base_url=ConfigProvider().get_ollama_config()["endpoint"], + model=ConfigProvider().get_ollama_config()["model"], + ), } if agent_id in agent_map: diff --git a/app/modules/intelligence/agents/agent_injector_service.py b/app/modules/intelligence/agents/agent_injector_service.py index ca951e87..76476b6e 100644 --- a/app/modules/intelligence/agents/agent_injector_service.py +++ b/app/modules/intelligence/agents/agent_injector_service.py @@ -28,6 +28,8 @@ AgentType, ProviderService, ) +from app.core.config_provider import ConfigProvider +from langchain_ollama import OllamaLLM logger = logging.getLogger(__name__) @@ -59,6 +61,10 @@ def _initialize_agents(self) -> Dict[str, Any]: "code_generation_agent": CodeGenerationChatAgent( mini_llm, reasoning_llm, self.sql_db ), + "ollama_agent": OllamaLLM( + base_url=ConfigProvider().get_ollama_config()["endpoint"], + model=ConfigProvider().get_ollama_config()["model"], + ), } def get_agent(self, agent_id: str) -> Any: diff --git a/app/modules/intelligence/provider/provider_service.py b/app/modules/intelligence/provider/provider_service.py index 7d216b0f..3e3e3a3e 100644 --- a/app/modules/intelligence/provider/provider_service.py +++ b/app/modules/intelligence/provider/provider_service.py @@ -7,6 +7,7 @@ from langchain_anthropic import ChatAnthropic from 
langchain_openai.chat_models import ChatOpenAI from portkey_ai import PORTKEY_GATEWAY_URL, createHeaders +from langchain_ollama import OllamaLLM from app.modules.key_management.secret_manager import SecretManager from app.modules.users.user_preferences_model import UserPreferences @@ -44,6 +45,11 @@ async def list_available_llms(self) -> List[ProviderInfo]: name="Anthropic", description="An AI safety-focused company known for models like Claude.", ), + ProviderInfo( + id="ollama", + name="Ollama", + description="A provider for running open source models locally.", + ), ] async def set_global_ai_provider(self, user_id: str, provider: str): @@ -195,6 +201,12 @@ def get_large_llm(self, agent_type: AgentType): default_headers=portkey_headers, ) + elif preferred_provider == "ollama": + logging.info("Initializing Ollama LLM") + ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_MODEL", "llama2") + self.llm = OllamaLLM(base_url=ollama_endpoint, model=ollama_model) + else: raise ValueError("Invalid LLM provider selected.") @@ -323,6 +335,12 @@ def get_small_llm(self, agent_type: AgentType): default_headers=portkey_headers, ) + elif preferred_provider == "ollama": + logging.info("Initializing Ollama LLM") + ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "http://localhost:11434") + ollama_model = os.getenv("OLLAMA_MODEL", "llama2") + self.llm = OllamaLLM(base_url=ollama_endpoint, model=ollama_model) + else: raise ValueError("Invalid LLM provider selected.") @@ -337,6 +355,8 @@ def get_llm_provider_name(self) -> str: return "OpenAI" elif isinstance(llm, ChatAnthropic): return "Anthropic" + elif isinstance(llm, OllamaLLM): + return "Ollama" elif isinstance(llm, LLM): return "OpenAI" if llm.model.split("/")[0] == "openai" else "Anthropic" else: diff --git a/app/modules/intelligence/tools/tool_service.py b/app/modules/intelligence/tools/tool_service.py index 6804efb6..9dbc2c5d 100644 --- 
a/app/modules/intelligence/tools/tool_service.py +++ b/app/modules/intelligence/tools/tool_service.py @@ -36,6 +36,8 @@ GetNodesFromTags, ) from app.modules.intelligence.tools.tool_schema import ToolInfo +from app.core.config_provider import ConfigProvider +from langchain_ollama import OllamaLLM class ToolService: @@ -65,8 +67,18 @@ def _initialize_tools(self) -> Dict[str, Any]: "get_node_neighbours_from_node_id": GetNodeNeighboursFromNodeIdTool( self.db ), + "ollama_tool": OllamaLLM( + base_url=self._get_ollama_endpoint(), + model=self._get_ollama_model(), + ), } + def _get_ollama_endpoint(self) -> str: + return ConfigProvider().get_ollama_config()["endpoint"] + + def _get_ollama_model(self) -> str: + return ConfigProvider().get_ollama_config()["model"] + async def run_tool(self, tool_id: str, params: Dict[str, Any]) -> Dict[str, Any]: tool = self.tools.get(tool_id) if not tool: