Skip to content

Commit

Permalink
make deduplicate public
Browse files Browse the repository at this point in the history
  • Loading branch information
ccurme committed Mar 21, 2024
1 parent 83b6d76 commit 0a099cd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 36 deletions.
48 changes: 24 additions & 24 deletions backend/server/extraction_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,6 @@ class ExtractResponse(TypedDict):
data: List[Any]


def _deduplicate(
extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
"""Deduplicate the results.
The deduplication is done by comparing the serialized JSON of each of the results
and only keeping the unique ones.
"""
unique_extracted = []
seen = set()
for response in extract_responses:
for data_item in response["data"]:
# Serialize the data item for comparison purposes
serialized = json.dumps(data_item, sort_keys=True)
if serialized not in seen:
seen.add(serialized)
unique_extracted.append(data_item)

return {
"data": unique_extracted,
}


def _cast_example_to_dict(example: Example) -> Dict[str, Any]:
"""Cast example record to dictionary."""
return {
Expand Down Expand Up @@ -147,6 +124,29 @@ def _make_prompt_template(
# PUBLIC API


def deduplicate(
    extract_responses: Sequence[ExtractResponse],
) -> ExtractResponse:
    """Combine extraction responses into one, removing duplicate items.

    Items are compared through their JSON serialization with sorted keys,
    so dictionaries that differ only in key order count as the same item.
    Only the first occurrence of each item is kept, and encounter order is
    preserved.
    """
    fingerprints = set()
    kept = []
    for resp in extract_responses:
        for item in resp["data"]:
            # Key-sorted JSON text serves as a hashable stand-in for the item.
            fingerprint = json.dumps(item, sort_keys=True)
            if fingerprint in fingerprints:
                continue
            fingerprints.add(fingerprint)
            kept.append(item)

    return {"data": kept}


def get_examples_from_extractor(extractor: Extractor) -> List[Dict[str, Any]]:
    """Return the extractor's examples, each converted to a dictionary."""
    return list(map(_cast_example_to_dict, extractor.examples))
Expand Down Expand Up @@ -206,4 +206,4 @@ async def extract_entire_document(
extraction_requests, {"max_concurrency": 1}
)
# Deduplicate the results
return _deduplicate(extract_responses)
return deduplicate(extract_responses)
10 changes: 3 additions & 7 deletions backend/server/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,22 @@

from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.runnables import RunnableLambda
from langchain_openai import OpenAIEmbeddings

from db.models import Extractor
from server.extraction_runnable import (
_deduplicate,
ExtractRequest,
ExtractResponse,
deduplicate,
extraction_runnable,
get_examples_from_extractor,
)


def _make_extract_requests(input_dict: Dict[str, Any]) -> List[ExtractRequest]:
    """Build one extraction request per document in ``input_dict["text"]``.

    Each element stored under the ``"text"`` key must expose a
    ``page_content`` attribute, which becomes the request text. Every other
    key of ``input_dict`` is forwarded unchanged as a keyword argument to
    ``ExtractRequest``.

    Raises:
        KeyError: If ``input_dict`` has no ``"text"`` key.
    """
    docs = input_dict["text"]
    # Copy the remaining keys instead of popping "text", so the caller's
    # dictionary is left intact and the same input can be passed again.
    shared_kwargs = {
        key: value for key, value in input_dict.items() if key != "text"
    }
    return [ExtractRequest(text=doc.page_content, **shared_kwargs) for doc in docs]


async def extract_from_content(
Expand Down Expand Up @@ -69,4 +65,4 @@ async def extract_from_content(
"model_name": model_name,
}
)
return _deduplicate(result)
return deduplicate(result)
10 changes: 5 additions & 5 deletions backend/tests/unit_tests/test_deduplication.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from server.extraction_runnable import ExtractResponse, _deduplicate
from server.extraction_runnable import ExtractResponse, deduplicate


async def test_deduplication_different_resutls() -> None:
"""Test deduplication of extraction results."""
result = _deduplicate(
result = deduplicate(
[
{"data": [{"name": "Chester", "age": 42}]},
{"data": [{"name": "Jane", "age": 42}]},
Expand All @@ -17,7 +17,7 @@ async def test_deduplication_different_resutls() -> None:
)
assert expected == result

result = _deduplicate(
result = deduplicate(
[
{
"data": [
Expand All @@ -44,11 +44,11 @@ async def test_deduplication_different_resutls() -> None:
assert expected == result

# Test with data being a list of strings
result = _deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": ["1", "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=["1", "2", "3"])
assert expected == result

# Test with data being a mix of integer and string
result = _deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
result = deduplicate([{"data": [1, "2"]}, {"data": ["1", "3"]}])
expected = ExtractResponse(data=[1, "2", "1", "3"])
assert expected == result

0 comments on commit 0a099cd

Please sign in to comment.