-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Part 1 - Add RAG Embedding Interface (#628)
**Reason for Change**: This series of PR will integrate llamaindex RAG service for Kaito. This first PR introduces the Embedding Model Interface and two classes that implement it `huggingface_local` and `huggingface_remote`. These allows users to specify both local and remote HuggingFace models for text embeddings.
- Loading branch information
1 parent
152e683
commit 1d99028
Showing
6 changed files
with
59 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class BaseEmbeddingModel(ABC): | ||
@abstractmethod | ||
def get_text_embedding(self, text: str): | ||
"""Returns the text embedding for a given input string.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_embedding_dimension(self) -> int: | ||
"""Returns the embedding dimension for the model.""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
from llama_index.embeddings.huggingface import HuggingFaceEmbedding | ||
|
||
from .base import BaseEmbeddingModel | ||
|
||
|
||
class LocalHuggingFaceEmbedding(BaseEmbeddingModel): | ||
def __init__(self, model_name: str): | ||
self.model = HuggingFaceEmbedding(model_name=model_name) # TODO: Ensure/test loads on GPU (when available) | ||
|
||
def get_text_embedding(self, text: str): | ||
"""Returns the text embedding for a given input string.""" | ||
return self.model.get_text_embedding(text) | ||
|
||
def get_embedding_dimension(self) -> int: | ||
"""Infers the embedding dimension by making a local call to get the embedding of a dummy text.""" | ||
dummy_input = "This is a dummy sentence." | ||
embedding = self.get_text_embedding(dummy_input) | ||
|
||
return len(embedding) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from llama_index.embeddings.huggingface_api import \ | ||
HuggingFaceInferenceAPIEmbedding | ||
|
||
from .base import BaseEmbeddingModel | ||
|
||
|
||
class RemoteHuggingFaceEmbedding(BaseEmbeddingModel): | ||
def __init__(self, model_name: str, api_key: str): | ||
self.model = HuggingFaceInferenceAPIEmbedding(model_name=model_name, token=api_key) | ||
|
||
def get_text_embedding(self, text: str): | ||
"""Returns the text embedding for a given input string.""" | ||
return self.model.get_text_embedding(text) | ||
|
||
def get_embedding_dimension(self) -> int: | ||
"""Infers the embedding dimension by making a remote call to get the embedding of a dummy text.""" | ||
dummy_input = "This is a dummy sentence." | ||
embedding = self.get_text_embedding(dummy_input) | ||
|
||
return len(embedding) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# RAG Library Requirements | ||
llama-index | ||
llama-index-embeddings-huggingface | ||
fastapi | ||
faiss-cpu | ||
llama-index-vector-stores-faiss | ||
uvicorn |