From 962a34684a160dd1f36d33891d6400a919554564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luis=20Coss=C3=ADo?= Date: Wed, 5 Mar 2025 18:15:18 -0300 Subject: [PATCH] evaluation fixes --- qdrant_client/hybrid/formula.py | 58 ++++++++++++++++++++++----------- 1 file changed, 39 insertions(+), 19 deletions(-) diff --git a/qdrant_client/hybrid/formula.py b/qdrant_client/hybrid/formula.py index d44251bea..52159a462 100644 --- a/qdrant_client/hybrid/formula.py +++ b/qdrant_client/hybrid/formula.py @@ -2,7 +2,6 @@ from qdrant_client._pydantic_compat import construct from qdrant_client.http import models from typing import Union, Any -import math import numpy as np from qdrant_client.local.geo import geo_distance @@ -11,8 +10,6 @@ DEFAULT_SCORE = np.float32(0.0) -DEFAULT_BY_ZERO = np.float32(1.0) - def evaluate_expression( expression: models.Expression, @@ -34,13 +31,17 @@ def evaluate_expression( return np.float32(0.0) elif isinstance(expression, models.MultExpression): - result = np.prod( - [ - evaluate_expression(expr, point_id, scores, payload, has_vector, defaults) - for expr in expression.mult - ], - dtype=np.float32, - ) + factors: list[np.float32] = [] + + for expr in expression.mult: + factor = evaluate_expression(expr, point_id, scores, payload, has_vector, defaults) + # Return early if any factor is zero + if factor == np.float32(0.0): + return factor + + factors.append(factor) + + result = np.prod(factors, dtype=np.float32) return np.float32(result) elif isinstance(expression, models.SumExpression): @@ -75,13 +76,17 @@ def evaluate_expression( expression.div.right, point_id, scores, payload, has_vector, defaults ) - if right == 0.0: + if right == np.float32(0.0): if expression.div.by_zero_default is not None: return np.float32(expression.div.by_zero_default) raise_non_finite_error(f"{left}/{right}") with np.errstate(invalid="ignore"): - return left / right + result = left / right + if np.isfinite(result): + return np.float32(result) + + raise_non_finite_error(f"{left}/{right}") elif isinstance(expression, models.SqrtExpression): value = evaluate_expression( @@ -115,7 +120,13 @@ def evaluate_expression( value = evaluate_expression( expression.exp, point_id, scores, payload, has_vector, defaults ) - return np.exp(value, dtype=np.float32) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + exp_value = np.exp(value, dtype=np.float32) + if np.isfinite(exp_value): + return exp_value + + raise_non_finite_error(f"exp({value})") elif isinstance(expression, models.Log10Expression): value = evaluate_expression( @@ -174,8 +185,10 @@ def evaluate_variable( # Get value from payload value = value_by_key(payload, var) - if value is not None and len(value) > 0: + if isinstance(value, list) and len(value) > 0: value = value[0] + + if is_number(value): try: return np.float32(value) except (TypeError, ValueError): @@ -184,11 +197,14 @@ def evaluate_variable( defined_default = defaults.get(var, None) - print(f"defined_default: {defined_default}") - try: - return np.float32(defined_default) - except ValueError: - return DEFAULT_SCORE + if is_number(defined_default): + try: + return np.float32(defined_default) + except (TypeError, ValueError): + # try to get from defaults + pass + + return DEFAULT_SCORE elif isinstance(var, int): # Get score from scores @@ -243,6 +259,10 @@ def raise_non_finite_error(expression: str): raise ValueError(f"The expression {expression} produced a non-finite number") +def is_number(value: Any) -> bool: + return isinstance(value, (int, float)) and not isinstance(value, bool) + + def test_parsing_variable(): assert parse_variable("$score") == 0 assert parse_variable("$score[0]") == 0