Skip to content

Commit

Permalink
Add support for KNN vector similarity search (#513)
Browse files Browse the repository at this point in the history
Co-authored-by: Chayim <chayim@users.noreply.github.com>
  • Loading branch information
Pwuts and chayim authored Jul 12, 2023
1 parent 70f6401 commit 89b6c84
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 9 deletions.
2 changes: 2 additions & 0 deletions aredis_om/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
FindQuery,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
QueryNotSupportedError,
QuerySyntaxError,
Expand Down
2 changes: 2 additions & 0 deletions aredis_om/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
Field,
HashModel,
JsonModel,
VectorFieldOptions,
KNNExpression,
NotFoundError,
RedisModel,
)
195 changes: 186 additions & 9 deletions aredis_om/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,24 @@ def tree(self):
return render_tree(self)


@dataclasses.dataclass
class KNNExpression:
k: int
vector_field: ModelField
reference_vector: bytes

def __str__(self):
return f"KNN $K @{self.vector_field.name} $knn_ref_vector"

@property
def query_params(self) -> Dict[str, Union[str, bytes]]:
return {"K": str(self.k), "knn_ref_vector": self.reference_vector}

@property
def score_field(self) -> str:
return f"__{self.vector_field.name}_score"


ExpressionOrNegated = Union[Expression, NegatedExpression]


Expand Down Expand Up @@ -349,8 +367,9 @@ def __init__(
self,
expressions: Sequence[ExpressionOrNegated],
model: Type["RedisModel"],
knn: Optional[KNNExpression] = None,
offset: int = 0,
limit: int = DEFAULT_PAGE_SIZE,
limit: Optional[int] = None,
page_size: int = DEFAULT_PAGE_SIZE,
sort_fields: Optional[List[str]] = None,
nocontent: bool = False,
Expand All @@ -364,13 +383,16 @@ def __init__(

self.expressions = expressions
self.model = model
self.knn = knn
self.offset = offset
self.limit = limit
self.limit = limit or (self.knn.k if self.knn else DEFAULT_PAGE_SIZE)
self.page_size = page_size
self.nocontent = nocontent

if sort_fields:
self.sort_fields = self.validate_sort_fields(sort_fields)
elif self.knn:
self.sort_fields = [self.knn.score_field]
else:
self.sort_fields = []

Expand Down Expand Up @@ -425,11 +447,26 @@ def query(self):
if self._query:
return self._query
self._query = self.resolve_redisearch_query(self.expression)
if self.knn:
self._query = (
self._query
if self._query.startswith("(") or self._query == "*"
else f"({self._query})"
) + f"=>[{self.knn}]"
return self._query

@property
def query_params(self):
params: List[Union[str, bytes]] = []
if self.knn:
params += [attr for kv in self.knn.query_params.items() for attr in kv]
return params

def validate_sort_fields(self, sort_fields: List[str]):
for sort_field in sort_fields:
field_name = sort_field.lstrip("-")
if self.knn and field_name == self.knn.score_field:
continue
if field_name not in self.model.__fields__:
raise QueryNotSupportedError(
f"You tried sort by {field_name}, but that field "
Expand Down Expand Up @@ -728,10 +765,27 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
return result

async def execute(self, exhaust_results=True, return_raw_result=False):
args = ["ft.search", self.model.Meta.index_name, self.query, *self.pagination]
args: List[Union[str, bytes]] = [
"FT.SEARCH",
self.model.Meta.index_name,
self.query,
*self.pagination,
]
if self.sort_fields:
args += self.resolve_redisearch_sort_fields()

if self.query_params:
args += ["PARAMS", str(len(self.query_params))] + self.query_params

if self.knn:
# Ensure DIALECT is at least 2
if "DIALECT" not in args:
args += ["DIALECT", "2"]
else:
i_dialect = args.index("DIALECT") + 1
if int(args[i_dialect]) < 2:
args[i_dialect] = "2"

if self.nocontent:
args.append("NOCONTENT")

Expand Down Expand Up @@ -917,11 +971,13 @@ def __init__(self, default: Any = Undefined, **kwargs: Any) -> None:
sortable = kwargs.pop("sortable", Undefined)
index = kwargs.pop("index", Undefined)
full_text_search = kwargs.pop("full_text_search", Undefined)
vector_options = kwargs.pop("vector_options", None)
super().__init__(default=default, **kwargs)
self.primary_key = primary_key
self.sortable = sortable
self.index = index
self.full_text_search = full_text_search
self.vector_options = vector_options


class RelationshipInfo(Representation):
Expand All @@ -935,6 +991,94 @@ def __init__(
self.link_model = link_model


@dataclasses.dataclass
class VectorFieldOptions:
class ALGORITHM(Enum):
FLAT = "FLAT"
HNSW = "HNSW"

class TYPE(Enum):
FLOAT32 = "FLOAT32"
FLOAT64 = "FLOAT64"

class DISTANCE_METRIC(Enum):
L2 = "L2"
IP = "IP"
COSINE = "COSINE"

algorithm: ALGORITHM
type: TYPE
dimension: int
distance_metric: DISTANCE_METRIC

# Common optional parameters
initial_cap: Optional[int] = None

# Optional parameters for FLAT
block_size: Optional[int] = None

# Optional parameters for HNSW
m: Optional[int] = None
ef_construction: Optional[int] = None
ef_runtime: Optional[int] = None
epsilon: Optional[float] = None

@staticmethod
def flat(
type: TYPE,
dimension: int,
distance_metric: DISTANCE_METRIC,
initial_cap: Optional[int] = None,
block_size: Optional[int] = None,
):
return VectorFieldOptions(
algorithm=VectorFieldOptions.ALGORITHM.FLAT,
type=type,
dimension=dimension,
distance_metric=distance_metric,
initial_cap=initial_cap,
block_size=block_size,
)

@staticmethod
def hnsw(
type: TYPE,
dimension: int,
distance_metric: DISTANCE_METRIC,
initial_cap: Optional[int] = None,
m: Optional[int] = None,
ef_construction: Optional[int] = None,
ef_runtime: Optional[int] = None,
epsilon: Optional[float] = None,
):
return VectorFieldOptions(
algorithm=VectorFieldOptions.ALGORITHM.HNSW,
type=type,
dimension=dimension,
distance_metric=distance_metric,
initial_cap=initial_cap,
m=m,
ef_construction=ef_construction,
ef_runtime=ef_runtime,
epsilon=epsilon,
)

@property
def schema(self):
attr = []
for k, v in vars(self).items():
if k == "algorithm" or v is None:
continue
attr.extend(
[
k.upper() if k != "dimension" else "DIM",
str(v) if not isinstance(v, Enum) else v.name,
]
)

return " ".join([f"VECTOR {self.algorithm.name} {len(attr)}"] + attr)


def Field(
default: Any = Undefined,
*,
Expand Down Expand Up @@ -964,6 +1108,7 @@ def Field(
sortable: Union[bool, UndefinedType] = Undefined,
index: Union[bool, UndefinedType] = Undefined,
full_text_search: Union[bool, UndefinedType] = Undefined,
vector_options: Optional[VectorFieldOptions] = None,
schema_extra: Optional[Dict[str, Any]] = None,
) -> Any:
current_schema_extra = schema_extra or {}
Expand Down Expand Up @@ -991,6 +1136,7 @@ def Field(
sortable=sortable,
index=index,
full_text_search=full_text_search,
vector_options=vector_options,
**current_schema_extra,
)
field_info._validate()
Expand Down Expand Up @@ -1083,6 +1229,10 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
new_class._meta.primary_key = PrimaryKey(
name=field_name, field=field
)
if field.field_info.vector_options:
score_attr = f"_{field_name}_score"
setattr(new_class, score_attr, None)
new_class.__annotations__[score_attr] = Union[float, None]

if not getattr(new_class._meta, "global_key_prefix", None):
new_class._meta.global_key_prefix = getattr(
Expand Down Expand Up @@ -1216,8 +1366,12 @@ def db(cls):
return cls._meta.database

@classmethod
def find(cls, *expressions: Union[Any, Expression]) -> FindQuery:
return FindQuery(expressions=expressions, model=cls)
def find(
cls,
*expressions: Union[Any, Expression],
knn: Optional[KNNExpression] = None,
) -> FindQuery:
return FindQuery(expressions=expressions, knn=knn, model=cls)

@classmethod
def from_redis(cls, res: Any):
Expand All @@ -1237,7 +1391,7 @@ def to_string(s):
for i in range(1, len(res), step):
if res[i + offset] is None:
continue
fields = dict(
fields: Dict[str, str] = dict(
zip(
map(to_string, res[i + offset][::2]),
map(to_string, res[i + offset][1::2]),
Expand All @@ -1247,6 +1401,9 @@ def to_string(s):
if fields.get("$"):
json_fields = json.loads(fields.pop("$"))
doc = cls(**json_fields)
for k, v in fields.items():
if k.startswith("__") and k.endswith("_score"):
setattr(doc, k[1:], float(v))
else:
doc = cls(**fields)

Expand Down Expand Up @@ -1474,7 +1631,13 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
embedded_cls = embedded_cls[0]
schema = cls.schema_for_type(name, embedded_cls, field_info)
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
schema = f"{name} NUMERIC"
vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
)
if vector_options:
schema = f"{name} {vector_options.schema}"
else:
schema = f"{name} NUMERIC"
elif issubclass(typ, str):
if getattr(field_info, "full_text_search", False) is True:
schema = (
Expand Down Expand Up @@ -1623,10 +1786,22 @@ def schema_for_type(
# Not a class, probably a type annotation
field_is_model = False

vector_options: Optional[VectorFieldOptions] = getattr(
field_info, "vector_options", None
)
try:
is_vector = vector_options and any(
issubclass(get_args(typ)[0], t) for t in NUMERIC_TYPES
)
except IndexError:
raise RedisModelError(
f"Vector field '{name}' must be annotated as a container type"
)

# When we encounter a list or model field, we need to descend
# into the values of the list or the fields of the model to
# find any values marked as indexed.
if is_container_type:
if is_container_type and not is_vector:
field_type = get_origin(typ)
embedded_cls = get_args(typ)
if not embedded_cls:
Expand Down Expand Up @@ -1689,7 +1864,9 @@ def schema_for_type(
)

# TODO: GEO field
if parent_is_container_type or parent_is_model_in_container:
if is_vector and vector_options:
schema = f"{path} AS {index_field_name} {vector_options.schema}"
elif parent_is_container_type or parent_is_model_in_container:
if typ is not str:
raise RedisModelError(
"In this Preview release, list and tuple fields can only "
Expand Down

0 comments on commit 89b6c84

Please sign in to comment.