Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

community: add include_labels option to ConfluenceLoader #28259

Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions libs/community/langchain_community/document_loaders/confluence.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def __init__(
include_archived_content: bool = False,
include_attachments: bool = False,
include_comments: bool = False,
include_labels: bool = False,
content_format: ContentFormat = ContentFormat.STORAGE,
limit: Optional[int] = 50,
max_pages: Optional[int] = 1000,
Expand All @@ -181,6 +182,7 @@ def __init__(
self.include_archived_content = include_archived_content
self.include_attachments = include_attachments
self.include_comments = include_comments
self.include_labels = include_labels
self.content_format = content_format
self.limit = limit
self.max_pages = max_pages
Expand Down Expand Up @@ -327,12 +329,20 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
)
include_attachments = self._resolve_param("include_attachments", kwargs)
include_comments = self._resolve_param("include_comments", kwargs)
include_labels = self._resolve_param("include_labels", kwargs)
content_format = self._resolve_param("content_format", kwargs)
limit = self._resolve_param("limit", kwargs)
max_pages = self._resolve_param("max_pages", kwargs)
ocr_languages = self._resolve_param("ocr_languages", kwargs)
keep_markdown_format = self._resolve_param("keep_markdown_format", kwargs)
keep_newlines = self._resolve_param("keep_newlines", kwargs)
expand = ",".join(
[
content_format.value,
"version",
*(["metadata.labels"] if include_labels else []),
]
)
Comment on lines +339 to +345
Copy link
Contributor Author

@nakamasato nakamasato Nov 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`expand` is a comma-separated query parameter. It was originally hardcoded as f"{content_format.value},version".

I extracted it into a variable so that we can add more options in the future if necessary, as the `expand` parameter supports many more values.

Screenshot 2024-11-21 at 22 27 32
(ref)


if not space_key and not page_ids and not label and not cql:
raise ValueError(
Expand All @@ -347,13 +357,14 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
limit=limit,
max_pages=max_pages,
status="any" if include_archived_content else "current",
expand=f"{content_format.value},version",
expand=expand,
)
yield from self.process_pages(
pages,
include_restricted_content,
include_attachments,
include_comments,
include_labels,
content_format,
ocr_languages=ocr_languages,
keep_markdown_format=keep_markdown_format,
Expand All @@ -380,13 +391,14 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
limit=limit,
max_pages=max_pages,
include_archived_spaces=include_archived_content,
expand=f"{content_format.value},version",
expand=expand,
)
yield from self.process_pages(
pages,
include_restricted_content,
include_attachments,
include_comments,
False, # labels are not included in the search results
content_format,
ocr_languages,
keep_markdown_format,
Expand All @@ -408,14 +420,16 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]:
before_sleep=before_sleep_log(logger, logging.WARNING),
)(self.confluence.get_page_by_id)
page = get_page(
page_id=page_id, expand=f"{content_format.value},version"
page_id=page_id,
expand=expand,
)
if not include_restricted_content and not self.is_public_page(page):
continue
yield self.process_page(
page,
include_attachments,
include_comments,
include_labels,
content_format,
ocr_languages,
keep_markdown_format,
Expand Down Expand Up @@ -498,6 +512,7 @@ def process_pages(
include_restricted_content: bool,
include_attachments: bool,
include_comments: bool,
include_labels: bool,
content_format: ContentFormat,
ocr_languages: Optional[str] = None,
keep_markdown_format: Optional[bool] = False,
Expand All @@ -511,6 +526,7 @@ def process_pages(
page,
include_attachments,
include_comments,
include_labels,
content_format,
ocr_languages=ocr_languages,
keep_markdown_format=keep_markdown_format,
Expand All @@ -522,6 +538,7 @@ def process_page(
page: dict,
include_attachments: bool,
include_comments: bool,
include_labels: bool,
content_format: ContentFormat,
ocr_languages: Optional[str] = None,
keep_markdown_format: Optional[bool] = False,
Expand Down Expand Up @@ -575,10 +592,19 @@ def process_page(
]
text = text + "".join(comment_texts)

if include_labels:
labels = [
label["name"]
for label in page.get("metadata", {})
.get("labels", {})
.get("results", [])
]
Comment on lines +596 to +601
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Each label looks something like this: {'prefix': 'global', 'name': 'database', 'id': '111111111'}

Screenshot 2024-11-21 at 22 23 03

ref


metadata = {
"title": page["title"],
"id": page["id"],
"source": self.base_url.strip("/") + page["_links"]["webui"],
**({"labels": labels} if include_labels else {}),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set the `labels` key only when `include_labels` is set to true.

}

if "version" in page and "when" in page["version"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,36 @@ def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled(
assert mock_confluence.cql.call_count == 0
assert mock_confluence.get_page_child_by_type.call_count == 0

@pytest.mark.requires("markdownify")
def test_confluence_loader_when_include_lables_set_to_true(
    self, mock_confluence: MagicMock
) -> None:
    """Loading with ``include_labels=True`` puts a ``labels`` list in metadata.

    Two mock pages are served from a single space listing: page "123" is
    requested from the mock-page helper with ``include_labels=True`` and
    page "456" with ``include_labels=False``. The loader is expected to
    report ``["l1", "l2"]`` for the first document and an empty list for
    the second.
    """
    # A single space-listing response that carries both pages.
    labelled_page = self._get_mock_page("123", include_labels=True)
    unlabelled_page = self._get_mock_page("456", include_labels=False)
    mock_confluence.get_all_pages_from_space.return_value = [
        labelled_page,
        unlabelled_page,
    ]
    # One restrictions payload per page, consumed in page order.
    mock_confluence.get_all_restrictions_for_content.side_effect = [
        self._get_mock_page_restrictions(page_id) for page_id in ("123", "456")
    ]

    confluence_loader = self._get_mock_confluence_loader(
        mock_confluence,
        space_key=self.MOCK_SPACE_KEY,
        include_labels=True,
        max_pages=2,
    )
    loaded = confluence_loader.load()

    # The space listing must have been fetched exactly once.
    assert mock_confluence.get_all_pages_from_space.call_count == 1

    assert len(loaded) == 2
    assert all(isinstance(item, Document) for item in loaded)
    assert loaded[0].metadata["labels"] == ["l1", "l2"]
    assert loaded[1].metadata["labels"] == []

def _get_mock_confluence_loader(
self, mock_confluence: MagicMock, **kwargs: Any
) -> ConfluenceLoader:
Expand All @@ -208,14 +238,31 @@ def _get_mock_confluence_loader(
return confluence_loader

def _get_mock_page(
self, page_id: str, content_format: ContentFormat = ContentFormat.STORAGE
self,
page_id: str,
content_format: ContentFormat = ContentFormat.STORAGE,
include_labels: bool = False,
) -> Dict:
return {
"id": f"{page_id}",
"title": f"Page {page_id}",
"body": {
f"{content_format.name.lower()}": {"value": f"<p>Content {page_id}</p>"}
},
**(
{
"metadata": {
"labels": {
"results": [
{"prefix": "global", "name": "l1", "id": "111"},
{"prefix": "global", "name": "l2", "id": "222"},
]
}
}
if include_labels
else {},
}
),
"status": "current",
"type": "page",
"_links": {
Expand Down
Loading