Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Oracle23ai as a vector datasource #5342

Merged
merged 24 commits into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/api-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jobs:
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
docker/docker-compose.chroma.yaml
docker/docker-compose.oracle.yaml
services: |
weaviate
qdrant
Expand All @@ -77,6 +78,7 @@ jobs:
pgvecto-rs
pgvector
chroma
oracle

- name: Test Vector Stores
run: dev/pytest/pytest_vdb.sh
Expand Down Expand Up @@ -145,6 +147,7 @@ jobs:
docker/docker-compose.pgvecto-rs.yaml
docker/docker-compose.pgvector.yaml
docker/docker-compose.chroma.yaml
docker/docker-compose.oracle.yaml
services: |
weaviate
qdrant
Expand All @@ -154,6 +157,7 @@ jobs:
pgvecto-rs
pgvector
chroma
oracle

- name: Test Vector Stores
run: poetry run -C api bash dev/pytest/pytest_vdb.sh
2 changes: 2 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from configs.middleware.vdb.chroma_configs import ChromaConfigs
from configs.middleware.vdb.milvus_configs import MilvusConfigs
from configs.middleware.vdb.opensearch_configs import OpenSearchConfigs
from configs.middleware.vdb.oracle_configs import OracleConfigs
from configs.middleware.vdb.pgvector_configs import PGVectorConfigs
from configs.middleware.vdb.pgvectors_configs import PGVectoRSConfigs
from configs.middleware.vdb.qdrant_configs import QdrantConfigs
Expand Down Expand Up @@ -61,5 +62,6 @@ class MiddlewareConfigs(
TencentVectorDBConfigs,
TiDBVectorConfigs,
WeaviateConfigs,
OracleConfigs,
):
pass
34 changes: 34 additions & 0 deletions api/configs/middleware/vdb/oracle_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Optional

from pydantic import BaseModel, Field, PositiveInt


class OracleConfigs(BaseModel):
"""
ORACLE configs
"""

ORACLE_HOST: Optional[str] = Field(
description='ORACLE host',
default=None,
)

ORACLE_PORT: Optional[PositiveInt] = Field(
description='ORACLE port',
default=None,
)

ORACLE_USER: Optional[str] = Field(
description='ORACLE user',
default=None,
)

ORACLE_PASSWORD: Optional[str] = Field(
description='ORACLE password',
default=None,
)

ORACLE_DATABASE: Optional[str] = Field(
description='ORACLE database',
default=None,
)
4 changes: 2 additions & 2 deletions api/controllers/console/datasets/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ class DatasetRetrievalSettingApi(Resource):
def get(self):
vector_type = current_app.config['VECTOR_STORE']
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH
Expand All @@ -522,7 +522,7 @@ class DatasetRetrievalSettingMockApi(Resource):
@account_initialization_required
def get(self, vector_type):
match vector_type:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCEN:
case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.ORACLE:
return {
'retrieval_method': [
RetrievalMethod.SEMANTIC_SEARCH
Expand Down
Empty file.
239 changes: 239 additions & 0 deletions api/core/rag/datasource/vdb/oracle/oraclevector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
import array
import json
import uuid
from contextlib import contextmanager
from typing import Any

import numpy
import oracledb
from flask import current_app
from pydantic import BaseModel, model_validator

from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset

oracledb.defaults.fetch_lobs = False


class OracleVectorConfig(BaseModel):
host: str
port: int
user: str
password: str
database: str

@model_validator(mode='before')
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config ORACLE_HOST is required")
if not values["port"]:
raise ValueError("config ORACLE_PORT is required")
if not values["user"]:
raise ValueError("config ORACLE_USER is required")
if not values["password"]:
raise ValueError("config ORACLE_PASSWORD is required")
if not values["database"]:
raise ValueError("config ORACLE_DB is required")
return values


SQL_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS {table_name} (
id varchar2(100)
,text CLOB NOT NULL
,meta JSON
,embedding vector NOT NULL
)
"""


class OracleVector(BaseVector):
def __init__(self, collection_name: str, config: OracleVectorConfig):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"

def get_type(self) -> str:
return VectorType.ORACLE

def numpy_converter_in(self, value):
if value.dtype == numpy.float64:
dtype = "d"
elif value.dtype == numpy.float32:
dtype = "f"
else:
dtype = "b"
return array.array(dtype, value)

def input_type_handler(self, cursor, value, arraysize):
if isinstance(value, numpy.ndarray):
return cursor.var(
oracledb.DB_TYPE_VECTOR,
arraysize=arraysize,
inconverter=self.numpy_converter_in,
)

def numpy_converter_out(self, value):
if value.typecode == "b":
dtype = numpy.int8
elif value.typecode == "f":
dtype = numpy.float32
else:
dtype = numpy.float64
return numpy.array(value, copy=False, dtype=dtype)

def output_type_handler(self, cursor, metadata):
if metadata.type_code is oracledb.DB_TYPE_VECTOR:
return cursor.var(
metadata.type_code,
arraysize=cursor.arraysize,
outconverter=self.numpy_converter_out,
)
def _create_connection_pool(self, config: OracleVectorConfig):
return oracledb.create_pool(user=config.user, password=config.password, dsn="{}:{}/{}".format(config.host, config.port, config.database), min=1, max=50, increment=1)


@contextmanager
def _get_cursor(self):
conn = self.pool.acquire()
conn.inputtypehandler = self.input_type_handler
conn.outputtypehandler = self.output_type_handler
cur = conn.cursor()
try:
yield cur
finally:
cur.close()
conn.commit()
conn.close()

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
return self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
values = []
pks = []
for i, doc in enumerate(documents):
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
pks.append(doc_id)
values.append(
(
doc_id,
doc.page_content,
json.dumps(doc.metadata),
#array.array("f", embeddings[i]),
numpy.array(embeddings[i]),
)
)
#print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
with self._get_cursor() as cur:
cur.executemany(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values)
return pks

def text_exists(self, id: str) -> bool:
with self._get_cursor() as cur:
cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
return cur.fetchone() is not None

def get_by_ids(self, ids: list[str]) -> list[Document]:
with self._get_cursor() as cur:
cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
docs = []
for record in cur:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs
#def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []
# for record in cur:
# idss.append(record[0])
# return idss

#def delete_by_document_id(self, document_id: str):
# ids = self.get_ids_by_metadata_field('doc_id', document_id)
# if len(ids)>0:
# with self._get_cursor() as cur:
# cur.execute(f"delete FROM {self.table_name} d WHERE d.meta.doc_id in '%s'" % ("','".join(ids),))


def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))

def delete_by_metadata_field(self, key: str, value: str) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.

:param query_vector: The input vector to search for similar items.
:param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 5)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name} ORDER BY distance fetch first {top_k} rows only" ,[numpy.array(query_vector)]
)
docs = []
score_threshold = kwargs.get("score_threshold") if kwargs.get("score_threshold") else 0.0
for record in cur:
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# do not support bm25 search
return []

def delete(self) -> None:
with self._get_cursor() as cur:
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")

def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
lock_name = f"{cache_key}_lock"
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return

with self._get_cursor() as cur:
cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
redis_client.set(collection_exist_cache_key, 1, ex=3600)


class OracleVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> OracleVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(
self.gen_index_struct_dict(VectorType.ORACLE, collection_name))

config = current_app.config
return OracleVector(
collection_name=collection_name,
config=OracleVectorConfig(
host=config.get("ORACLE_HOST"),
port=config.get("ORACLE_PORT"),
user=config.get("ORACLE_USER"),
password=config.get("ORACLE_PASSWORD"),
database=config.get("ORACLE_DATABASE"),
),
)
3 changes: 3 additions & 0 deletions api/core/rag/datasource/vdb/vector_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]:
case VectorType.TENCENT:
from core.rag.datasource.vdb.tencent.tencent_vector import TencentVectorFactory
return TencentVectorFactory
case VectorType.ORACLE:
from core.rag.datasource.vdb.oracle.oraclevector import OracleVectorFactory
return OracleVectorFactory
case VectorType.OPENSEARCH:
from core.rag.datasource.vdb.opensearch.opensearch_vector import OpenSearchVectorFactory
return OpenSearchVectorFactory
Expand Down
1 change: 1 addition & 0 deletions api/core/rag/datasource/vdb/vector_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class VectorType(str, Enum):
WEAVIATE = 'weaviate'
OPENSEARCH = 'opensearch'
TENCENT = 'tencent'
ORACLE = 'oracle'
Loading