Skip to content

Commit

Permalink
evaluation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
coszio committed Mar 5, 2025
1 parent edf2c91 commit 962a346
Showing 1 changed file with 39 additions and 19 deletions.
58 changes: 39 additions & 19 deletions qdrant_client/hybrid/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from qdrant_client._pydantic_compat import construct
from qdrant_client.http import models
from typing import Union, Any
import math
import numpy as np

from qdrant_client.local.geo import geo_distance
Expand All @@ -11,8 +10,6 @@

DEFAULT_SCORE = np.float32(0.0)

DEFAULT_BY_ZERO = np.float32(1.0)


def evaluate_expression(
expression: models.Expression,
Expand All @@ -34,13 +31,17 @@ def evaluate_expression(
return np.float32(0.0)

elif isinstance(expression, models.MultExpression):
result = np.prod(
[
evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
for expr in expression.mult
],
dtype=np.float32,
)
factors: list[np.float32] = []

for expr in expression.mult:
factor = evaluate_expression(expr, point_id, scores, payload, has_vector, defaults)
# Return early if any factor is zero
if factor == np.float32(0.0):
return factor

factors.append(factor)

result = np.prod(factors, dtype=np.float32)
return np.float32(result)

elif isinstance(expression, models.SumExpression):
Expand Down Expand Up @@ -75,13 +76,17 @@ def evaluate_expression(
expression.div.right, point_id, scores, payload, has_vector, defaults
)

if right == 0.0:
if right == np.float32(0.0):
if expression.div.by_zero_default is not None:
return np.float32(expression.div.by_zero_default)
raise_non_finite_error(f"{left}/{right}")

with np.errstate(invalid="ignore"):
return left / right
result = left / right
if np.isfinite(result):
return np.float32(result)

raise_non_finite_error(f"{left}/{right}")

elif isinstance(expression, models.SqrtExpression):
value = evaluate_expression(
Expand Down Expand Up @@ -115,7 +120,13 @@ def evaluate_expression(
value = evaluate_expression(
expression.exp, point_id, scores, payload, has_vector, defaults
)
return np.exp(value, dtype=np.float32)
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
exp_value = np.exp(value, dtype=np.float32)
if np.isfinite(exp_value):
return exp_value

raise_non_finite_error(f"exp({value})")

elif isinstance(expression, models.Log10Expression):
value = evaluate_expression(
Expand Down Expand Up @@ -174,8 +185,10 @@ def evaluate_variable(
# Get value from payload
value = value_by_key(payload, var)

if value is not None and len(value) > 0:
if isinstance(value, list) and len(value) > 0:
value = value[0]

if is_number(value):
try:
return np.float32(value)
except (TypeError, ValueError):
Expand All @@ -184,11 +197,14 @@ def evaluate_variable(

defined_default = defaults.get(var, None)

print(f"defined_default: {defined_default}")
try:
return np.float32(defined_default)
except ValueError:
return DEFAULT_SCORE
if is_number(defined_default):
try:
return np.float32(defined_default)
except (TypeError, ValueError):
# try to get from defaults
pass

return DEFAULT_SCORE

elif isinstance(var, int):
# Get score from scores
Expand Down Expand Up @@ -243,6 +259,10 @@ def raise_non_finite_error(expression: str):
raise ValueError(f"The expression {expression} produced a non-finite number")


def is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)


def test_parsing_variable():
assert parse_variable("$score") == 0
assert parse_variable("$score[0]") == 0
Expand Down

0 comments on commit 962a346

Please sign in to comment.