diff --git a/nucliadb/nucliadb/search/search/chat/ask.py b/nucliadb/nucliadb/search/search/chat/ask.py
index 33afe264bf..d823c6e272 100644
--- a/nucliadb/nucliadb/search/search/chat/ask.py
+++ b/nucliadb/nucliadb/search/search/chat/ask.py
@@ -46,6 +46,7 @@
     IncompleteFindResultsError,
     InvalidQueryError,
 )
+from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.search import (
@@ -90,6 +91,7 @@ def __init__(
         prompt_context: PromptContext,
         prompt_context_order: PromptContextOrder,
         auditor: ChatAuditor,
+        metrics: RAGMetrics,
     ):
         # Initial attributes
         self.kbid = kbid
@@ -99,7 +101,8 @@
         self.predict_answer_stream = predict_answer_stream
         self.prompt_context = prompt_context
         self.prompt_context_order = prompt_context_order
-        self.auditor = auditor
+        self.auditor: ChatAuditor = auditor
+        self.metrics: RAGMetrics = metrics

         # Computed from the predict chat answer stream
         self._answer_text = ""
@@ -264,11 +267,12 @@ async def json(self) -> str:

     async def get_relations_results(self) -> Relations:
         if self._relations is None:
-            self._relations = await get_relations_results(
-                kbid=self.kbid,
-                text_answer=self._answer_text,
-                target_shard_replicas=self.ask_request.shards,
-            )
+            with self.metrics.time("relations"):
+                self._relations = await get_relations_results(
+                    kbid=self.kbid,
+                    text_answer=self._answer_text,
+                    target_shard_replicas=self.ask_request.shards,
+                )
         return self._relations

     async def _stream_predict_answer_text(self) -> AsyncGenerator[str, None]:
@@ -279,10 +283,14 @@
         This method does not assume any order in the stream of items, but it
         assumes that at least the answer text is streamed in order.
""" + first_answer_chunk_yielded = False async for generative_chunk in self.predict_answer_stream: item = generative_chunk.chunk if isinstance(item, TextGenerativeResponse): self._answer_text += item.text + if not first_answer_chunk_yielded: + self.metrics.record_first_chunk_yielded() + first_answer_chunk_yielded = True yield item.text elif isinstance(item, StatusGenerativeResponse): self._status = item @@ -323,7 +331,7 @@ async def json(self) -> str: answer=NOT_ENOUGH_CONTEXT_ANSWER, retrieval_results=self.find_results, status=AnswerStatusCode.NO_CONTEXT, - ).json(exclude_unset=True) + ).model_dump_json(exclude_unset=True) async def ask( @@ -336,6 +344,7 @@ async def ask( resource: Optional[str] = None, ) -> AskResult: start_time = time() + metrics = RAGMetrics() chat_history = ask_request.context or [] user_context = ask_request.extra_context or [] user_query = ask_request.query @@ -343,14 +352,15 @@ async def ask( # Maybe rephrase the query rephrased_query = None if len(chat_history) > 0 or len(user_context) > 0: - rephrased_query = await rephrase_query( - kbid, - chat_history=chat_history, - query=user_query, - user_id=user_id, - user_context=user_context, - generative_model=ask_request.generative_model, - ) + with metrics.time("rephrase"): + rephrased_query = await rephrase_query( + kbid, + chat_history=chat_history, + query=user_query, + user_id=user_id, + user_context=user_context, + generative_model=ask_request.generative_model, + ) # Retrieval is not needed if we are chatting on a specific # resource and the full_resource strategy is enabled @@ -364,15 +374,17 @@ async def ask( # Maybe do a retrieval query if needs_retrieval: - find_results, query_parser = await get_find_results( - kbid=kbid, - # Prefer the rephrased query if available - query=rephrased_query or user_query, - chat_request=ask_request, - ndb_client=client_type, - user=user_id, - origin=origin, - ) + with metrics.time("retrieval"): + find_results, query_parser = await get_find_results( + kbid=kbid, + # Prefer the rephrased query if available + query=rephrased_query or user_query, + chat_request=ask_request, + ndb_client=client_type, + user=user_id, + origin=origin, + metrics=metrics, + ) if len(find_results.resources) == 0: return NotEnoughContextAskResult(find_results=find_results) @@ -389,23 +401,24 @@ async def ask( ) # Now we build the prompt context - query_parser.max_tokens = ask_request.max_tokens # type: ignore - max_tokens_context = await query_parser.get_max_tokens_context() - prompt_context_builder = PromptContextBuilder( - kbid=kbid, - find_results=find_results, - resource=resource, - user_context=user_context, - strategies=ask_request.rag_strategies, - image_strategies=ask_request.rag_images_strategies, - max_context_characters=tokens_to_chars(max_tokens_context), - visual_llm=await query_parser.get_visual_llm_enabled(), - ) - ( - prompt_context, - prompt_context_order, - prompt_context_images, - ) = await prompt_context_builder.build() + with metrics.time("context_building"): + query_parser.max_tokens = ask_request.max_tokens # type: ignore + max_tokens_context = await query_parser.get_max_tokens_context() + prompt_context_builder = PromptContextBuilder( + kbid=kbid, + find_results=find_results, + resource=resource, + user_context=user_context, + strategies=ask_request.rag_strategies, + image_strategies=ask_request.rag_images_strategies, + max_context_characters=tokens_to_chars(max_tokens_context), + visual_llm=await query_parser.get_visual_llm_enabled(), + ) + ( + prompt_context, + 
+            prompt_context_order,
+            prompt_context_images,
+        ) = await prompt_context_builder.build()

     # Parse the user prompt (if any)
     user_prompt = None
@@ -426,10 +439,11 @@
         max_tokens=query_parser.get_max_tokens_answer(),
         query_context_images=prompt_context_images,
     )
-    predict = get_predict()
-    nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(
-        kbid, chat_model
-    )
+    with metrics.time("stream_start"):
+        predict = get_predict()
+        nuclia_learning_id, predict_answer_stream = await predict.chat_query_ndjson(
+            kbid, chat_model
+        )

     auditor = ChatAuditor(
         kbid=kbid,
@@ -444,7 +458,6 @@
         query_context=prompt_context,
         query_context_order=prompt_context_order,
     )
-
     return AskResult(
         kbid=kbid,
         ask_request=ask_request,
@@ -454,6 +467,7 @@
         prompt_context=prompt_context,
         prompt_context_order=prompt_context_order,
         auditor=auditor,
+        metrics=metrics,
     )
diff --git a/nucliadb/nucliadb/search/search/chat/query.py b/nucliadb/nucliadb/search/search/chat/query.py
index 0d214e51f8..81afa3aa47 100644
--- a/nucliadb/nucliadb/search/search/chat/query.py
+++ b/nucliadb/nucliadb/search/search/chat/query.py
@@ -31,6 +31,7 @@
 from nucliadb.search.search.exceptions import IncompleteFindResultsError
 from nucliadb.search.search.find import find
 from nucliadb.search.search.merge import merge_relations_results
+from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
 from nucliadb.search.utilities import get_predict
 from nucliadb_models.search import (
@@ -138,6 +139,7 @@ async def get_find_results(
     ndb_client: NucliaDBClientType,
     user: str,
     origin: str,
+    metrics: RAGMetrics = RAGMetrics(),
 ) -> tuple[KnowledgeboxFindResults, QueryParser]:
     find_request = FindRequest()
     find_request.resource_filters = chat_request.resource_filters
@@ -174,6 +176,7 @@
         user,
         origin,
         generative_model=chat_request.generative_model,
+        metrics=metrics,
     )
     if incomplete:
         raise IncompleteFindResultsError()
@@ -224,6 +227,7 @@
     origin: str,
     resource: Optional[str] = None,
 ) -> ChatResult:
+    metrics = RAGMetrics()
     start_time = time()
     nuclia_learning_id: Optional[str] = None
     chat_history = chat_request.context or []
@@ -234,14 +238,15 @@
     prompt_context_order: PromptContextOrder = {}

     if len(chat_history) > 0 or len(user_context) > 0:
-        rephrased_query = await rephrase_query(
-            kbid,
-            chat_history=chat_history,
-            query=user_query,
-            user_id=user_id,
-            user_context=user_context,
-            generative_model=chat_request.generative_model,
-        )
+        with metrics.time("rephrase"):
+            rephrased_query = await rephrase_query(
+                kbid,
+                chat_history=chat_history,
+                query=user_query,
+                user_id=user_id,
+                user_context=user_context,
+                generative_model=chat_request.generative_model,
+            )

     # Retrieval is not needed if we are chatting on a specific
     # resource and the full_resource strategy is enabled
@@ -254,14 +259,16 @@
     needs_retrieval = False

     if needs_retrieval:
-        find_results, query_parser = await get_find_results(
-            kbid=kbid,
-            query=rephrased_query or user_query,
-            chat_request=chat_request,
-            ndb_client=client_type,
-            user=user_id,
-            origin=origin,
-        )
+        with metrics.time("retrieval"):
+            find_results, query_parser = await get_find_results(
+                kbid=kbid,
+                query=rephrased_query or user_query,
+                chat_request=chat_request,
+                ndb_client=client_type,
+                user=user_id,
+                origin=origin,
+                metrics=metrics,
+            )

     status_code = FoundStatusCode()
     if len(find_results.resources) == 0:
         # If no resources were found on the retrieval, we return
@@ -290,23 +297,24 @@
         min_score=MinScore(),
     )

-    query_parser.max_tokens = chat_request.max_tokens  # type: ignore
-    max_tokens_context = await query_parser.get_max_tokens_context()
-    prompt_context_builder = PromptContextBuilder(
-        kbid=kbid,
-        find_results=find_results,
-        resource=resource,
-        user_context=user_context,
-        strategies=chat_request.rag_strategies,
-        image_strategies=chat_request.rag_images_strategies,
-        max_context_characters=tokens_to_chars(max_tokens_context),
-        visual_llm=await query_parser.get_visual_llm_enabled(),
-    )
-    (
-        prompt_context,
-        prompt_context_order,
-        prompt_context_images,
-    ) = await prompt_context_builder.build()
+    with metrics.time("context_building"):
+        query_parser.max_tokens = chat_request.max_tokens  # type: ignore
+        max_tokens_context = await query_parser.get_max_tokens_context()
+        prompt_context_builder = PromptContextBuilder(
+            kbid=kbid,
+            find_results=find_results,
+            resource=resource,
+            user_context=user_context,
+            strategies=chat_request.rag_strategies,
+            image_strategies=chat_request.rag_images_strategies,
+            max_context_characters=tokens_to_chars(max_tokens_context),
+            visual_llm=await query_parser.get_visual_llm_enabled(),
+        )
+        (
+            prompt_context,
+            prompt_context_order,
+            prompt_context_images,
+        ) = await prompt_context_builder.build()

     user_prompt = None
     if chat_request.prompt is not None:
         user_prompt = UserPrompt(prompt=chat_request.prompt)
@@ -331,6 +339,9 @@ async def _wrapped_stream():
         # so we can audit after streamed out answer
         text_answer = b""
         async for chunk in format_generated_answer(predict_generator, status_code):
+            if text_answer == b"":
+                # first chunk
+                metrics.record_first_chunk_yielded()
             text_answer += chunk
             yield chunk
diff --git a/nucliadb/nucliadb/search/search/find.py b/nucliadb/nucliadb/search/search/find.py
index 053b23c60e..ae2cfbd70e 100644
--- a/nucliadb/nucliadb/search/search/find.py
+++ b/nucliadb/nucliadb/search/search/find.py
@@ -23,6 +23,7 @@

 from nucliadb.search.requesters.utils import Method, debug_nodes_info, node_query
 from nucliadb.search.search.find_merge import find_merge_results
+from nucliadb.search.search.metrics import RAGMetrics
 from nucliadb.search.search.query import QueryParser
 from nucliadb.search.search.utils import (
     min_score_from_payload,
@@ -47,6 +48,7 @@ async def find(
     x_nucliadb_user: str,
     x_forwarded_for: str,
     generative_model: Optional[str] = None,
+    metrics: RAGMetrics = RAGMetrics(),
 ) -> tuple[KnowledgeboxFindResults, bool, QueryParser]:
     audit = get_audit()
     start_time = time()
@@ -82,26 +84,30 @@
         generative_model=generative_model,
         rephrase=item.rephrase,
     )
-    pb_query, incomplete_results, autofilters = await query_parser.parse()
-    results, query_incomplete_results, queried_nodes = await node_query(
-        kbid, Method.SEARCH, pb_query, target_shard_replicas=item.shards
-    )
+    with metrics.time("query_parse"):
+        pb_query, incomplete_results, autofilters = await query_parser.parse()
+
+    with metrics.time("node_query"):
+        results, query_incomplete_results, queried_nodes = await node_query(
+            kbid, Method.SEARCH, pb_query, target_shard_replicas=item.shards
+        )
     incomplete_results = incomplete_results or query_incomplete_results

     # We need to merge
-    search_results = await find_merge_results(
-        results,
-        count=item.page_size,
-        page=item.page_number,
-        kbid=kbid,
-        show=item.show,
-        field_type_filter=item.field_type_filter,
-        extracted=item.extracted,
-        requested_relations=pb_query.relation_subgraph,
-        min_score_bm25=query_parser.min_score.bm25,
-        min_score_semantic=query_parser.min_score.semantic,
-        highlight=item.highlight,
-    )
+    with metrics.time("results_merge"):
+        search_results = await find_merge_results(
+            results,
+            count=item.page_size,
+            page=item.page_number,
+            kbid=kbid,
+            show=item.show,
+            field_type_filter=item.field_type_filter,
+            extracted=item.extracted,
+            requested_relations=pb_query.relation_subgraph,
+            min_score_bm25=query_parser.min_score.bm25,
+            min_score_semantic=query_parser.min_score.semantic,
+            highlight=item.highlight,
+        )

     search_time = time() - start_time
     if audit is not None:
@@ -121,16 +127,30 @@
     search_results.shards = queried_shards
     search_results.autofilters = autofilters

-    if search_time > settings.slow_find_log_threshold:
+    if metrics.elapsed("node_query") > settings.slow_node_query_log_threshold:
+        logger.warning(
+            "Slow node query",
+            extra={
+                "kbid": kbid,
+                "user": x_nucliadb_user,
+                "client": x_ndb_client,
+                "query": item.model_dump_json(),
+                "time": search_time,
+                "nodes": debug_nodes_info(queried_nodes),
+            },
+        )
+    elif search_time > settings.slow_find_log_threshold:
         logger.warning(
-            "Slow query",
+            "Slow find query",
             extra={
                 "kbid": kbid,
                 "user": x_nucliadb_user,
                 "client": x_ndb_client,
-                "query": item.json(),
+                "query": item.model_dump_json(),
                 "time": search_time,
                 "nodes": debug_nodes_info(queried_nodes),
+                # Include step times in the log
+                **{step: metrics.elapsed(step) for step in metrics.steps()},
             },
         )
diff --git a/nucliadb/nucliadb/search/search/metrics.py b/nucliadb/nucliadb/search/search/metrics.py
index 3bd02ce878..9a4ade5c57 100644
--- a/nucliadb/nucliadb/search/search/metrics.py
+++ b/nucliadb/nucliadb/search/search/metrics.py
@@ -17,6 +17,10 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
+import contextlib
+import time
+from typing import Optional
+
 from nucliadb_telemetry import metrics

 merge_observer = metrics.Observer("merge_results", labels={"type": ""})
@@ -24,3 +28,69 @@
 query_parse_dependency_observer = metrics.Observer(
     "query_parse_dependency", labels={"type": ""}
 )
+
+buckets = [
+    0.005,
+    0.01,
+    0.025,
+    0.05,
+    0.075,
+    0.1,
+    0.25,
+    0.5,
+    0.75,
+    1.0,
+    2.5,
+    5.0,
+    7.5,
+    10.0,
+    30.0,
+    60.0,
+    metrics.INF,
+]
+
+generative_first_chunk_histogram = metrics.Histogram(
+    name="generative_first_chunk",
+    buckets=buckets,
+)
+rag_histogram = metrics.Histogram(
+    name="rag",
+    labels={"step": ""},
+    buckets=buckets,
+)
+
+
+class RAGMetrics:
+    def __init__(self):
+        self.global_start = time.monotonic()
+        self._start_times: dict[str, float] = {}
+        self._end_times: dict[str, float] = {}
+        self.first_chunk_yielded_at: Optional[float] = None
+
+    @contextlib.contextmanager
+    def time(self, step: str):
+        self._start(step)
+        try:
+            yield
+        finally:
+            self._end(step)
+
+    def steps(self):
+        return list(self._start_times.keys())
+
+    def elapsed(self, step: str) -> float:
+        return self._end_times[step] - self._start_times[step]
+
+    def record_first_chunk_yielded(self):
+        self.first_chunk_yielded_at = time.monotonic()
+        generative_first_chunk_histogram.observe(
+            self.first_chunk_yielded_at - self.global_start
+        )
+
+    def _start(self, step: str):
+        self._start_times[step] = time.monotonic()
+
+    def _end(self, step: str):
+        self._end_times[step] = time.monotonic()
+        elapsed = self.elapsed(step)
+        rag_histogram.observe(elapsed, labels={"step": step})
diff --git a/nucliadb/nucliadb/search/settings.py b/nucliadb/nucliadb/search/settings.py
index ae46cdff30..4285414d9f 100644
--- a/nucliadb/nucliadb/search/settings.py
+++ b/nucliadb/nucliadb/search/settings.py
@@ -28,7 +28,13 @@ class Settings(DriverSettings):
     slow_find_log_threshold: float = Field(
         default=3.0,
         title="Slow query log threshold",
-        description="The threshold in seconds for logging slow queries",
+        description="The threshold in seconds for logging slow find queries",
+    )
+
+    slow_node_query_log_threshold: float = Field(
+        default=2.0,
+        title="Slow node query log threshold",
+        description="The threshold in seconds for logging slow node queries",
     )
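
Reviewer note: below is a minimal usage sketch of the RAGMetrics helper added in metrics.py, to make the intended flow concrete. It relies only on the RAGMetrics API visible in this diff; the fake_step coroutine and the printed timings dict are hypothetical illustration, not part of the change.

import asyncio

from nucliadb.search.search.metrics import RAGMetrics


async def fake_step():
    # Hypothetical stand-in for a real pipeline step (rephrase, retrieval, ...).
    await asyncio.sleep(0.05)
    return ["chunk-1", "chunk-2"]


async def main():
    # One RAGMetrics instance is created per request; global_start is taken
    # at construction time (see ask() and chat() in this diff).
    metrics = RAGMetrics()

    # metrics.time() is a context manager: on exit it observes the elapsed
    # time in the "rag" histogram under the given step label.
    with metrics.time("retrieval"):
        chunks = await fake_step()

    # Streaming loop: the first yielded chunk records time-to-first-chunk
    # into the generative_first_chunk histogram, measured from global_start.
    for _chunk in chunks:
        if metrics.first_chunk_yielded_at is None:
            metrics.record_first_chunk_yielded()

    # Per-step elapsed times, as used by the slow-query logging in find.py.
    print({step: metrics.elapsed(step) for step in metrics.steps()})


asyncio.run(main())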