Skip to content

Commit

Permalink
Improve metrics for find/chat/ask (#2216)
Browse files — browse the repository at this point in the history
  • Loading branch information
lferran authored Jun 4, 2024
1 parent b43122a commit 2ef7862
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 100 deletions.
106 changes: 60 additions & 46 deletions nucliadb/nucliadb/search/search/chat/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
IncompleteFindResultsError,
InvalidQueryError,
)
from nucliadb.search.search.metrics import RAGMetrics
from nucliadb.search.search.query import QueryParser
from nucliadb.search.utilities import get_predict
from nucliadb_models.search import (
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
prompt_context: PromptContext,
prompt_context_order: PromptContextOrder,
auditor: ChatAuditor,
metrics: RAGMetrics,
):
# Initial attributes
self.kbid = kbid
Expand All @@ -99,7 +101,8 @@ def __init__(
self.predict_answer_stream = predict_answer_stream
self.prompt_context = prompt_context
self.prompt_context_order = prompt_context_order
self.auditor = auditor
self.auditor: ChatAuditor = auditor
self.metrics: RAGMetrics = metrics

# Computed from the predict chat answer stream
self._answer_text = ""
Expand Down Expand Up @@ -264,11 +267,12 @@ async def json(self) -> str:

async def get_relations_results(self) -> Relations:
if self._relations is None:
self._relations = await get_relations_results(
kbid=self.kbid,
text_answer=self._answer_text,
target_shard_replicas=self.ask_request.shards,
)
with self.metrics.time("relations"):
self._relations = await get_relations_results(
kbid=self.kbid,
text_answer=self._answer_text,
target_shard_replicas=self.ask_request.shards,
)
return self._relations

async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
Expand All @@ -279,10 +283,14 @@ async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
This method does not assume any order in the stream of items, but it assumes that at least
the answer text is streamed in order.
"""
first_answer_chunk_yielded = False
async for generative_chunk in self.predict_answer_stream:
item = generative_chunk.chunk
if isinstance(item, TextGenerativeResponse):
self._answer_text += item.text
if not first_answer_chunk_yielded:
self.metrics.record_first_chunk_yielded()
first_answer_chunk_yielded = True
yield item.text
elif isinstance(item, StatusGenerativeResponse):
self._status = item
Expand Down Expand Up @@ -323,7 +331,7 @@ async def json(self) -> str:
answer=NOT_ENOUGH_CONTEXT_ANSWER,
retrieval_results=self.find_results,
status=AnswerStatusCode.NO_CONTEXT,
).json(exclude_unset=True)
).model_dump_json(exclude_unset=True)


async def ask(
Expand All @@ -336,21 +344,23 @@ async def ask(
resource: Optional[str] = None,
) -> AskResult:
start_time = time()
metrics = RAGMetrics()
chat_history = ask_request.context or []
user_context = ask_request.extra_context or []
user_query = ask_request.query

# Maybe rephrase the query
rephrased_query = None
if len(chat_history) > 0 or len(user_context) > 0:
rephrased_query = await rephrase_query(
kbid,
chat_history=chat_history,
query=user_query,
user_id=user_id,
user_context=user_context,
generative_model=ask_request.generative_model,
)
with metrics.time("rephrase"):
rephrased_query = await rephrase_query(
kbid,
chat_history=chat_history,
query=user_query,
user_id=user_id,
user_context=user_context,
generative_model=ask_request.generative_model,
)

# Retrieval is not needed if we are chatting on a specific
# resource and the full_resource strategy is enabled
Expand All @@ -364,15 +374,17 @@ async def ask(

# Maybe do a retrieval query
if needs_retrieval:
find_results, query_parser = await get_find_results(
kbid=kbid,
# Prefer the rephrased query if available
query=rephrased_query or user_query,
chat_request=ask_request,
ndb_client=client_type,
user=user_id,
origin=origin,
)
with metrics.time("retrieval"):
find_results, query_parser = await get_find_results(
kbid=kbid,
# Prefer the rephrased query if available
query=rephrased_query or user_query,
chat_request=ask_request,
ndb_client=client_type,
user=user_id,
origin=origin,
metrics=metrics,
)
if len(find_results.resources) == 0:
return NotEnoughContextAskResult(find_results=find_results)

Expand All @@ -389,23 +401,24 @@ async def ask(
)

# Now we build the prompt context
query_parser.max_tokens = ask_request.max_tokens # type: ignore
max_tokens_context = await query_parser.get_max_tokens_context()
prompt_context_builder = PromptContextBuilder(
kbid=kbid,
find_results=find_results,
resource=resource,
user_context=user_context,
strategies=ask_request.rag_strategies,
image_strategies=ask_request.rag_images_strategies,
max_context_characters=tokens_to_chars(max_tokens_context),
visual_llm=await query_parser.get_visual_llm_enabled(),
)
(
prompt_context,
prompt_context_order,
prompt_context_images,
) = await prompt_context_builder.build()
with metrics.time("context_building"):
query_parser.max_tokens = ask_request.max_tokens # type: ignore
max_tokens_context = await query_parser.get_max_tokens_context()
prompt_context_builder = PromptContextBuilder(
kbid=kbid,
find_results=find_results,
resource=resource,
user_context=user_context,
strategies=ask_request.rag_strategies,
image_strategies=ask_request.rag_images_strategies,
max_context_characters=tokens_to_chars(max_tokens_context),
visual_llm=await query_parser.get_visual_llm_enabled(),
)
(
prompt_context,
prompt_context_order,
prompt_context_images,
) = await prompt_context_builder.build()

# Parse the user prompt (if any)
user_prompt = None
Expand All @@ -426,10 +439,11 @@ async def ask(
max_tokens=query_parser.get_max_tokens_answer(),
query_context_images=prompt_context_images,
)
predict = get_predict()
nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(
kbid, chat_model
)
with metrics.time("stream_start"):
predict = get_predict()
nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(
kbid, chat_model
)

auditor = ChatAuditor(
kbid=kbid,
Expand All @@ -444,7 +458,6 @@ async def ask(
query_context=prompt_context,
query_context_order=prompt_context_order,
)

return AskResult(
kbid=kbid,
ask_request=ask_request,
Expand All @@ -454,6 +467,7 @@ async def ask(
prompt_context=prompt_context,
prompt_context_order=prompt_context_order,
auditor=auditor,
metrics=metrics,
)


Expand Down
77 changes: 44 additions & 33 deletions nucliadb/nucliadb/search/search/chat/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from nucliadb.search.search.exceptions import IncompleteFindResultsError
from nucliadb.search.search.find import find
from nucliadb.search.search.merge import merge_relations_results
from nucliadb.search.search.metrics import RAGMetrics
from nucliadb.search.search.query import QueryParser
from nucliadb.search.utilities import get_predict
from nucliadb_models.search import (
Expand Down Expand Up @@ -138,6 +139,7 @@ async def get_find_results(
ndb_client: NucliaDBClientType,
user: str,
origin: str,
metrics: RAGMetrics = RAGMetrics(),
) -> tuple[KnowledgeboxFindResults, QueryParser]:
find_request = FindRequest()
find_request.resource_filters = chat_request.resource_filters
Expand Down Expand Up @@ -174,6 +176,7 @@ async def get_find_results(
user,
origin,
generative_model=chat_request.generative_model,
metrics=metrics,
)
if incomplete:
raise IncompleteFindResultsError()
Expand Down Expand Up @@ -224,6 +227,7 @@ async def chat(
origin: str,
resource: Optional[str] = None,
) -> ChatResult:
metrics = RAGMetrics()
start_time = time()
nuclia_learning_id: Optional[str] = None
chat_history = chat_request.context or []
Expand All @@ -234,14 +238,15 @@ async def chat(
prompt_context_order: PromptContextOrder = {}

if len(chat_history) > 0 or len(user_context) > 0:
rephrased_query = await rephrase_query(
kbid,
chat_history=chat_history,
query=user_query,
user_id=user_id,
user_context=user_context,
generative_model=chat_request.generative_model,
)
with metrics.time("rephrase"):
rephrased_query = await rephrase_query(
kbid,
chat_history=chat_history,
query=user_query,
user_id=user_id,
user_context=user_context,
generative_model=chat_request.generative_model,
)

# Retrieval is not needed if we are chatting on a specific
# resource and the full_resource strategy is enabled
Expand All @@ -254,14 +259,16 @@ async def chat(
needs_retrieval = False

if needs_retrieval:
find_results, query_parser = await get_find_results(
kbid=kbid,
query=rephrased_query or user_query,
chat_request=chat_request,
ndb_client=client_type,
user=user_id,
origin=origin,
)
with metrics.time("retrieval"):
find_results, query_parser = await get_find_results(
kbid=kbid,
query=rephrased_query or user_query,
chat_request=chat_request,
ndb_client=client_type,
user=user_id,
origin=origin,
metrics=metrics,
)
status_code = FoundStatusCode()
if len(find_results.resources) == 0:
# If no resources were found on the retrieval, we return
Expand Down Expand Up @@ -290,23 +297,24 @@ async def chat(
min_score=MinScore(),
)

query_parser.max_tokens = chat_request.max_tokens # type: ignore
max_tokens_context = await query_parser.get_max_tokens_context()
prompt_context_builder = PromptContextBuilder(
kbid=kbid,
find_results=find_results,
resource=resource,
user_context=user_context,
strategies=chat_request.rag_strategies,
image_strategies=chat_request.rag_images_strategies,
max_context_characters=tokens_to_chars(max_tokens_context),
visual_llm=await query_parser.get_visual_llm_enabled(),
)
(
prompt_context,
prompt_context_order,
prompt_context_images,
) = await prompt_context_builder.build()
with metrics.time("context_building"):
query_parser.max_tokens = chat_request.max_tokens # type: ignore
max_tokens_context = await query_parser.get_max_tokens_context()
prompt_context_builder = PromptContextBuilder(
kbid=kbid,
find_results=find_results,
resource=resource,
user_context=user_context,
strategies=chat_request.rag_strategies,
image_strategies=chat_request.rag_images_strategies,
max_context_characters=tokens_to_chars(max_tokens_context),
visual_llm=await query_parser.get_visual_llm_enabled(),
)
(
prompt_context,
prompt_context_order,
prompt_context_images,
) = await prompt_context_builder.build()
user_prompt = None
if chat_request.prompt is not None:
user_prompt = UserPrompt(prompt=chat_request.prompt)
Expand All @@ -331,6 +339,9 @@ async def _wrapped_stream():
# so we can audit after streamed out answer
text_answer = b""
async for chunk in format_generated_answer(predict_generator, status_code):
if text_answer == b"":
# first chunk
metrics.record_first_chunk_yielded()
text_answer += chunk
yield chunk

Expand Down
Loading

3 comments on commit 2ef7862

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2ef7862 Previous: 08db1e8 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 3125.7514785879184 iter/sec (stddev: 0.0000016026942151304698) 3041.132211072051 iter/sec (stddev: 0.0000011628635974660105) 0.97

This comment was automatically generated by a workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2ef7862 Previous: 08db1e8 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 2954.56124000865 iter/sec (stddev: 0.000007149226315013667) 3041.132211072051 iter/sec (stddev: 0.0000011628635974660105) 1.03

This comment was automatically generated by a workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2ef7862 Previous: 08db1e8 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 2999.5501715998803 iter/sec (stddev: 0.000001557408575987662) 3041.132211072051 iter/sec (stddev: 0.0000011628635974660105) 1.01

This comment was automatically generated by a workflow using github-action-benchmark.

Please sign in to comment.