Skip to content

Commit

Permalink
Fix hierarchy rag strategy and add test (#2273)
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran authored Jun 25, 2024
1 parent 37a3537 commit 5db3bc9
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 85 deletions.
183 changes: 103 additions & 80 deletions nucliadb/src/nucliadb/search/search/chat/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import copy
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, cast

Expand All @@ -38,11 +39,13 @@
ImageRagStrategy,
ImageRagStrategyName,
KnowledgeboxFindResults,
PageImageStrategy,
PromptContext,
PromptContextImages,
PromptContextOrder,
RagStrategy,
RagStrategyName,
TableImageStrategy,
)
from nucliadb_protos import resources_pb2
from nucliadb_utils.asyncio_utils import ConcurrentRunner, run_concurrently
Expand Down Expand Up @@ -340,6 +343,96 @@ async def composed_prompt_context(
context[paragraph.id] = _clean_paragraph_text(paragraph)


async def hierarchy_prompt_context(
    context: CappedPromptContext,
    kbid: str,
    ordered_paragraphs: list[FindParagraph],
    paragraphs_extra_characters: int = 0,
) -> None:
    """
    Fill `context` following the hierarchy RAG strategy.

    Paragraphs belonging to the same resource are merged into a single context
    entry, stored under the id of the first paragraph of that resource. The entry
    carries the resource title and summary followed by one "EXTRACTED BLOCK" per
    matching paragraph, so that the LLM sees every paragraph next to the document
    it comes from.

    When `paragraphs_extra_characters` is greater than zero, each paragraph text is
    fetched again with its end offset extended by that amount; otherwise the text
    already present on the paragraph is used. Negative values are treated as zero.

    The paragraphs received are never modified: all the rewriting happens on a deep
    copy, because the original list is also returned in the response to the user.
    """
    extra_chars = max(paragraphs_extra_characters, 0)
    working_paragraphs = copy.deepcopy(ordered_paragraphs)
    text_cache = paragraphs.ExtractedTextCache()
    by_resource: Dict[str, ExtraCharsParagraph] = {}

    # First pass: resolve the (possibly extended) text of every paragraph and
    # group the paragraphs by the resource they belong to.
    for para in working_paragraphs:
        # Paragraph ids look like "<rid>/<field_type>/<field>/.../<start>-<end>"
        id_parts = para.id.split("/")
        rid, field_type, field = id_parts[:3]
        start, end = id_parts[-1].split("-")
        range_start = int(start)
        range_end = int(end) + extra_chars

        if extra_chars > 0:
            para_text = await paragraphs.get_paragraph_text(
                kbid=kbid,
                rid=rid,
                field="/".join([rid, field_type, field]),
                start=range_start,
                end=range_end,
                extracted_text_cache=text_cache,
            )
        else:
            para_text = para.text

        group = by_resource.get(rid)
        if group is not None:
            group.paragraphs.append((para, para_text))
            continue

        # First time this resource shows up: fetch its title and summary once
        resource_title = await paragraphs.get_paragraph_text(
            kbid=kbid,
            rid=rid,
            field="/a/title",
            start=0,
            end=500,
            extracted_text_cache=text_cache,
        )
        resource_summary = await paragraphs.get_paragraph_text(
            kbid=kbid,
            rid=rid,
            field="/a/summary",
            start=0,
            end=1000,
            extracted_text_cache=text_cache,
        )
        by_resource[rid] = ExtraCharsParagraph(
            title=resource_title,
            summary=resource_summary,
            paragraphs=[(para, para_text)],
        )

    # Second pass: for every resource, blank all of its paragraphs and make the
    # first one hold the title, the summary and all the extracted blocks.
    for group in by_resource.values():
        entries = group.paragraphs
        if not entries:
            continue
        blocks = "".join(
            "\n EXTRACTED BLOCK: \n " + block_text + " \n\n " for _, block_text in entries
        )
        for para, _ in entries:
            para.text = ""
        head_paragraph = entries[0][0]
        head_paragraph.text = (
            f"DOCUMENT: {group.title} \n SUMMARY: {group.summary} \n RESOURCE CONTENT: {blocks}"
        )

    # Final pass: only the paragraphs that still have text (one per resource)
    # make it into the context.
    for para in working_paragraphs:
        if para.text != "":
            context[para.id] = _clean_paragraph_text(para)


class PromptContextBuilder:
"""
Builds the context for the LLM prompt.
Expand Down Expand Up @@ -386,18 +479,18 @@ async def build(
return context, context_order, context_images

async def _build_context_images(self, context: CappedPromptContext) -> None:
flatten_strategies = []
page_count = 5
gather_pages = False
gather_tables = False
if self.image_strategies is not None:
for strategy in self.image_strategies:
flatten_strategies.append(strategy.name)
if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
strategy = cast(PageImageStrategy, strategy)
gather_pages = True
if strategy.count is not None: # type: ignore
page_count = strategy.count # type: ignore
if strategy.name == ImageRagStrategyName.TABLES:
if strategy.count is not None and strategy.count > 0:
page_count = strategy.count
elif strategy.name == ImageRagStrategyName.TABLES:
strategy = cast(TableImageStrategy, strategy)
gather_tables = True

for paragraph in self.ordered_paragraphs:
Expand Down Expand Up @@ -452,10 +545,12 @@ async def _build_context(self, context: CappedPromptContext) -> None:
return

if hierarchy_strategy:
await inject_hierarchy_in_paragraphs(
self.kbid, self.ordered_paragraphs, hierarchy_paragraphs_extended_characters
await hierarchy_prompt_context(
context,
self.kbid,
self.ordered_paragraphs,
hierarchy_paragraphs_extended_characters,
)
await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
return

await composed_prompt_context(
Expand All @@ -474,78 +569,6 @@ class ExtraCharsParagraph:
paragraphs: List[Tuple[FindParagraph, str]]


async def inject_hierarchy_in_paragraphs(
    kbid: str, ordered_paragraphs: list[FindParagraph], extra_characters: int
) -> None:
    """
    Rewrite, in place, the text of `ordered_paragraphs` so that it carries hierarchy information.

    The text of every paragraph is fetched with its end offset extended by `extra_characters`.
    Then, for each resource, the first of its paragraphs is rewritten to hold the resource title,
    the resource summary and the extended text of all the paragraphs of that resource, while the
    text of the remaining paragraphs of the resource is set to the empty string.

    Negative values of `extra_characters` are treated as zero, so that the requested range never
    ends before the original end of the paragraph.

    NOTE: this is kind of ugly and should be refactored so that, instead of modifying the paragraphs
    in place, we simply output a context to be sent to the llm with the desired structure.
    """
    # A negative amount would shrink (or even invert) the requested range
    extra_characters = max(extra_characters, 0)
    etcache = paragraphs.ExtractedTextCache()
    resources: Dict[str, ExtraCharsParagraph] = {}

    # Iterate paragraphs to get the extended text and group them by resource
    for paragraph in ordered_paragraphs:
        # Paragraph ids look like "<rid>/<field_type>/<field>/.../<start>-<end>"
        id_parts = paragraph.id.split("/")
        rid, field_type, field = id_parts[:3]
        field_path = "/".join([rid, field_type, field])
        start, end = id_parts[-1].split("-")
        int_start = int(start)
        int_end = int(end) + extra_characters

        extended_paragraph_text = await paragraphs.get_paragraph_text(
            kbid=kbid,
            rid=rid,
            field=field_path,
            start=int_start,
            end=int_end,
            extracted_text_cache=etcache,
        )
        if rid in resources:
            resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
            continue

        # Title and summary are requested only once per resource
        title_text = await paragraphs.get_paragraph_text(
            kbid=kbid,
            rid=rid,
            field="/a/title",
            start=0,
            end=500,
            extracted_text_cache=etcache,
        )
        summary_text = await paragraphs.get_paragraph_text(
            kbid=kbid,
            rid=rid,
            field="/a/summary",
            start=0,
            end=1000,
            extracted_text_cache=etcache,
        )
        resources[rid] = ExtraCharsParagraph(
            title=title_text,
            summary=summary_text,
            paragraphs=[(paragraph, extended_paragraph_text)],
        )

    for values in resources.values():
        if not values.paragraphs:
            continue
        text_with_hierarchy = "".join(
            "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
            for _, extended_paragraph_text in values.paragraphs
        )
        # All paragraphs of the resource are cleared except the first one, which will be the
        # one containing the whole hierarchy information
        for paragraph, _ in values.paragraphs:
            paragraph.text = ""
        first_paragraph = values.paragraphs[0][0]
        first_paragraph.text = f"DOCUMENT: {values.title} \n SUMMARY: {values.summary} \n RESOURCE CONTENT: {text_with_hierarchy}"


def _clean_paragraph_text(paragraph: FindParagraph) -> str:
text = paragraph.text.strip()
# Do not send highlight marks on prompt context
Expand Down
50 changes: 50 additions & 0 deletions nucliadb/tests/search/unit/search/test_chat_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,53 @@ def test_capped_prompt_context():
context["key1"] = "foo" * int(1e6)

assert context.output == {"key1": "foo" * int(1e6)}


@pytest.mark.asyncio
async def test_hierarchy_promp_context(kb):
    """
    Two paragraphs of the same resource must end up merged in a single context entry,
    and the paragraphs handed to the function must keep their original text.
    """
    first_id = "r1/f/f1/0-10"
    second_id = "r1/f/f1/10-20"
    matched_paragraphs = {
        first_id: FindParagraph(
            id=first_id,
            score=10,
            score_type=SCORE_TYPE.BM25,
            order=0,
            text="First Paragraph text",
        ),
        second_id: FindParagraph(
            id=second_id,
            score=8,
            score_type=SCORE_TYPE.BM25,
            order=1,
            text="Second paragraph text",
        ),
    }
    find_results = KnowledgeboxFindResults(
        resources={
            "r1": FindResource(
                id="r1",
                fields={"f/f1": FindField(paragraphs=matched_paragraphs)},
            )
        },
    )
    context = chat_prompt.CappedPromptContext(max_size=int(1e6))
    expected_entry = "DOCUMENT: Title text \n SUMMARY: Summary text \n RESOURCE CONTENT: \n EXTRACTED BLOCK: \n First Paragraph text \n\n \n EXTRACTED BLOCK: \n Second paragraph text"  # noqa

    # With no extra characters requested, only the title and the summary are fetched
    with mock.patch(
        "nucliadb.search.search.chat.prompt.paragraphs.get_paragraph_text",
        side_effect=["Title text", "Summary text"],
    ):
        ordered_paragraphs = chat_prompt.get_ordered_paragraphs(find_results)
        await chat_prompt.hierarchy_prompt_context(
            context,
            "kbid",
            ordered_paragraphs,
            paragraphs_extra_characters=0,
        )

    assert context.output[first_id] == expected_entry
    # Check that the original text of the paragraphs is preserved
    assert ordered_paragraphs[0].text == "First Paragraph text"
    assert ordered_paragraphs[1].text == "Second paragraph text"
10 changes: 5 additions & 5 deletions nucliadb_models/src/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,16 +938,16 @@ class FullResourceStrategy(RagStrategy):
name: Literal["full_resource"]
count: Optional[int] = Field(
default=None,
title="Resources",
description="Maximum number of full documents to retrieve",
title="Count",
description="Maximum number of full documents to retrieve. If not specified, all matching documents are retrieved.",
)


class HierarchyResourceStrategy(RagStrategy):
name: Literal["hierarchy"]
count: Optional[int] = Field(
title="Resources",
default=None,
title="Count",
description="Number of extra characters that are added to each matching paragraph when adding to the context.",
)

Expand All @@ -959,9 +959,9 @@ class TableImageStrategy(ImageRagStrategy):
class PageImageStrategy(ImageRagStrategy):
name: Literal["page_image"]
count: Optional[int] = Field(
title="Images",
default=None,
description="How many images to retrieve",
title="Count",
description="Maximum number of images to retrieve from the page. By default, at most 5 images are retrieved.",
)


Expand Down

2 comments on commit 5db3bc9

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 5db3bc9 Previous: 0d03d9f Ratio
tests/search/unit/search/test_fetch.py::test_highligh_error 2223.4664627064126 iter/sec (stddev: 0.000004857220048919636) 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) 1.28

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 5db3bc9 Previous: 0d03d9f Ratio
tests/search/unit/search/test_fetch.py::test_highligh_error 2079.439129550575 iter/sec (stddev: 0.000002835180817652127) 2841.0684406726436 iter/sec (stddev: 0.000004954958228416619) 1.37

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.