diff --git a/ragengine/config.py b/ragengine/config.py
new file mode 100644
index 000000000..bda4b46d3
--- /dev/null
+++ b/ragengine/config.py
@@ -0,0 +1,20 @@
+# config.py
+
+# Variables are set via environment variables from the RAGEngine CR
+# and exposed to the pod. For example, InferenceURL is specified in the CR and
+# passed to the pod via env variables.
+
+import os
+
+EMBEDDING_TYPE = os.getenv("EMBEDDING_TYPE", "local")
+EMBEDDING_URL = os.getenv("EMBEDDING_URL")
+
+INFERENCE_URL = os.getenv("INFERENCE_URL", "http://localhost:5000/chat")
+INFERENCE_ACCESS_SECRET = os.getenv("AccessSecret", "default-inference-secret")
+# RESPONSE_FIELD = os.getenv("RESPONSE_FIELD", "result")
+
+MODEL_ID = os.getenv("MODEL_ID", "BAAI/bge-small-en-v1.5")
+VECTOR_DB_TYPE = os.getenv("VECTOR_DB_TYPE", "faiss")
+INDEX_SERVICE_NAME = os.getenv("INDEX_SERVICE_NAME", "default-index-service")
+ACCESS_SECRET = os.getenv("ACCESS_SECRET", "default-access-secret")
+PERSIST_DIR = "storage"
\ No newline at end of file
diff --git a/ragengine/inference/__init__.py b/ragengine/inference/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/ragengine/inference/inference.py b/ragengine/inference/inference.py
new file mode 100644
index 000000000..85d1155ad
--- /dev/null
+++ b/ragengine/inference/inference.py
@@ -0,0 +1,53 @@
+from typing import Any
+from llama_index.core.llms import CustomLLM, CompletionResponse, LLMMetadata, CompletionResponseGen
+from llama_index.llms.openai import OpenAI
+from llama_index.core.llms.callbacks import llm_completion_callback
+import requests
+from config import INFERENCE_URL, INFERENCE_ACCESS_SECRET #, RESPONSE_FIELD
+
+class Inference(CustomLLM):
+    params: dict = {}
+
+    def set_params(self, params: dict) -> None:
+        self.params = params
+
+    def get_param(self, key, default=None):
+        return self.params.get(key, default)
+
+    @llm_completion_callback()
+    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
+        pass
+
+    @llm_completion_callback()
+    def complete(self, prompt: str, **kwargs) -> CompletionResponse:
+        try:
+            if "openai" in INFERENCE_URL:
+                return self._openai_complete(prompt, **kwargs, **self.params)
+            else:
+                return self._custom_api_complete(prompt, **kwargs, **self.params)
+        finally:
+            # Clear params after the completion is done
+            self.params = {}
+
+    def _openai_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        llm = OpenAI(
+            api_key=INFERENCE_ACCESS_SECRET,
+            **kwargs  # Pass all kwargs directly; kwargs may include model, temperature, max_tokens, etc.
+        )
+        return llm.complete(prompt)
+
+    def _custom_api_complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
+        headers = {"Authorization": f"Bearer {INFERENCE_ACCESS_SECRET}"}
+        data = {"prompt": prompt, **kwargs}
+
+        response = requests.post(INFERENCE_URL, json=data, headers=headers)
+        response_data = response.json()
+
+        # Dynamically extract the field from the response based on the specified response_field
+        # completion_text = response_data.get(RESPONSE_FIELD, "No response field found")  # not necessary for now
+        return CompletionResponse(text=str(response_data))
+
+    @property
+    def metadata(self) -> LLMMetadata:
+        """Get LLM metadata."""
+        return LLMMetadata()
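
Aside: the header comment in config.py says these settings arrive as environment variables that the RAGEngine CR exposes to the pod. Below is a minimal local-testing sketch of that contract; it is not part of the diff, the values are placeholders, and it relies only on the env var names config.py actually reads (including the literal "AccessSecret" name used for INFERENCE_ACCESS_SECRET).

# Hypothetical local-testing sketch (not part of this change): set the same
# environment variables the RAGEngine CR would inject, then import config.
import os

os.environ["EMBEDDING_TYPE"] = "local"                    # placeholder value
os.environ["INFERENCE_URL"] = "http://example:5000/chat"  # placeholder URL
os.environ["AccessSecret"] = "example-token"              # exact env var name read for INFERENCE_ACCESS_SECRET

import config  # values are resolved once, at import time

print(config.INFERENCE_URL, config.EMBEDDING_TYPE)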
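
A short usage sketch of the new Inference adapter follows. It is illustrative only: it assumes the non-OpenAI branch (an INFERENCE_URL that does not contain "openai"), a custom endpoint that accepts a JSON body with a prompt field plus extra parameters, and that the ragengine directory is on sys.path so the in-repo imports resolve.

# Hypothetical usage sketch (not part of this change), exercising the custom
# (non-OpenAI) code path of the Inference adapter.
from inference.inference import Inference

llm = Inference()

# Per-request parameters staged via set_params() are merged into the JSON body
# by _custom_api_complete() and cleared once complete() returns.
llm.set_params({"temperature": 0.2, "max_tokens": 256})

response = llm.complete("What does the RAGEngine CR configure?")
print(response.text)  # the raw response body, stringified by the adapter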