diff --git a/libs/community/langchain_community/document_loaders/confluence.py b/libs/community/langchain_community/document_loaders/confluence.py index 70c86e7dce962..263c0c8d31fe2 100644 --- a/libs/community/langchain_community/document_loaders/confluence.py +++ b/libs/community/langchain_community/document_loaders/confluence.py @@ -166,6 +166,7 @@ def __init__( include_archived_content: bool = False, include_attachments: bool = False, include_comments: bool = False, + include_labels: bool = False, content_format: ContentFormat = ContentFormat.STORAGE, limit: Optional[int] = 50, max_pages: Optional[int] = 1000, @@ -181,6 +182,7 @@ def __init__( self.include_archived_content = include_archived_content self.include_attachments = include_attachments self.include_comments = include_comments + self.include_labels = include_labels self.content_format = content_format self.limit = limit self.max_pages = max_pages @@ -327,12 +329,20 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: ) include_attachments = self._resolve_param("include_attachments", kwargs) include_comments = self._resolve_param("include_comments", kwargs) + include_labels = self._resolve_param("include_labels", kwargs) content_format = self._resolve_param("content_format", kwargs) limit = self._resolve_param("limit", kwargs) max_pages = self._resolve_param("max_pages", kwargs) ocr_languages = self._resolve_param("ocr_languages", kwargs) keep_markdown_format = self._resolve_param("keep_markdown_format", kwargs) keep_newlines = self._resolve_param("keep_newlines", kwargs) + expand = ",".join( + [ + content_format.value, + "version", + *(["metadata.labels"] if include_labels else []), + ] + ) if not space_key and not page_ids and not label and not cql: raise ValueError( @@ -347,13 +357,14 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: limit=limit, max_pages=max_pages, status="any" if include_archived_content else "current", - expand=f"{content_format.value},version", + expand=expand, ) yield from self.process_pages( pages, include_restricted_content, include_attachments, include_comments, + include_labels, content_format, ocr_languages=ocr_languages, keep_markdown_format=keep_markdown_format, @@ -380,13 +391,14 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: limit=limit, max_pages=max_pages, include_archived_spaces=include_archived_content, - expand=f"{content_format.value},version", + expand=expand, ) yield from self.process_pages( pages, include_restricted_content, include_attachments, include_comments, + False, # labels are not included in the search results content_format, ocr_languages, keep_markdown_format, @@ -408,7 +420,8 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: before_sleep=before_sleep_log(logger, logging.WARNING), )(self.confluence.get_page_by_id) page = get_page( - page_id=page_id, expand=f"{content_format.value},version" + page_id=page_id, + expand=expand, ) if not include_restricted_content and not self.is_public_page(page): continue @@ -416,6 +429,7 @@ def _lazy_load(self, **kwargs: Any) -> Iterator[Document]: page, include_attachments, include_comments, + include_labels, content_format, ocr_languages, keep_markdown_format, @@ -498,6 +512,7 @@ def process_pages( include_restricted_content: bool, include_attachments: bool, include_comments: bool, + include_labels: bool, content_format: ContentFormat, ocr_languages: Optional[str] = None, keep_markdown_format: Optional[bool] = False, @@ -511,6 +526,7 @@ def process_pages( page, include_attachments, include_comments, + include_labels, content_format, ocr_languages=ocr_languages, keep_markdown_format=keep_markdown_format, @@ -522,6 +538,7 @@ def process_page( page: dict, include_attachments: bool, include_comments: bool, + include_labels: bool, content_format: ContentFormat, ocr_languages: Optional[str] = None, keep_markdown_format: Optional[bool] = False, @@ -575,10 +592,19 @@ def process_page( ] text = text + "".join(comment_texts) + if include_labels: + labels = [ + label["name"] + for label in page.get("metadata", {}) + .get("labels", {}) + .get("results", []) + ] + metadata = { "title": page["title"], "id": page["id"], "source": self.base_url.strip("/") + page["_links"]["webui"], + **({"labels": labels} if include_labels else {}), } if "version" in page and "when" in page["version"]: diff --git a/libs/community/tests/unit_tests/document_loaders/test_confluence.py b/libs/community/tests/unit_tests/document_loaders/test_confluence.py index feecb1588b571..abb47326beef7 100644 --- a/libs/community/tests/unit_tests/document_loaders/test_confluence.py +++ b/libs/community/tests/unit_tests/document_loaders/test_confluence.py @@ -195,6 +195,36 @@ def test_confluence_loader_when_content_format_and_keep_markdown_format_enabled( assert mock_confluence.cql.call_count == 0 assert mock_confluence.get_page_child_by_type.call_count == 0 + @pytest.mark.requires("markdownify") + def test_confluence_loader_when_include_lables_set_to_true( + self, mock_confluence: MagicMock + ) -> None: + # one response with two pages + mock_confluence.get_all_pages_from_space.return_value = [ + self._get_mock_page("123", include_labels=True), + self._get_mock_page("456", include_labels=False), + ] + mock_confluence.get_all_restrictions_for_content.side_effect = [ + self._get_mock_page_restrictions("123"), + self._get_mock_page_restrictions("456"), + ] + + conflence_loader = self._get_mock_confluence_loader( + mock_confluence, + space_key=self.MOCK_SPACE_KEY, + include_labels=True, + max_pages=2, + ) + + documents = conflence_loader.load() + + assert mock_confluence.get_all_pages_from_space.call_count == 1 + + assert len(documents) == 2 + assert all(isinstance(doc, Document) for doc in documents) + assert documents[0].metadata["labels"] == ["l1", "l2"] + assert documents[1].metadata["labels"] == [] + def _get_mock_confluence_loader( self, mock_confluence: MagicMock, **kwargs: Any ) -> ConfluenceLoader: @@ -208,7 +238,10 @@ def _get_mock_confluence_loader( return confluence_loader def _get_mock_page( - self, page_id: str, content_format: ContentFormat = ContentFormat.STORAGE + self, + page_id: str, + content_format: ContentFormat = ContentFormat.STORAGE, + include_labels: bool = False, ) -> Dict: return { "id": f"{page_id}", @@ -216,6 +249,20 @@ def _get_mock_page( "body": { f"{content_format.name.lower()}": {"value": f"

Content {page_id}

"} }, + **( + { + "metadata": { + "labels": { + "results": [ + {"prefix": "global", "name": "l1", "id": "111"}, + {"prefix": "global", "name": "l2", "id": "222"}, + ] + } + } + if include_labels + else {}, + } + ), "status": "current", "type": "page", "_links": {