diff --git a/tests/unit/vertex_rag/conftest.py b/tests/unit/vertex_rag/conftest.py index 6510ac830e..f3047da17f 100644 --- a/tests/unit/vertex_rag/conftest.py +++ b/tests/unit/vertex_rag/conftest.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - # Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -24,11 +22,13 @@ VertexRagDataServiceAsyncClient, VertexRagDataServiceClient, ) -import test_rag_constants as tc +import test_rag_constants_preview import mock import pytest +# -*- coding: utf-8 -*- + _TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) @@ -37,7 +37,7 @@ def google_auth_mock(): with mock.patch.object(auth, "default") as auth_mock: auth_mock.return_value = ( auth_credentials.AnonymousCredentials(), - tc.TEST_PROJECT, + test_rag_constants_preview.TEST_PROJECT, ) yield auth_mock @@ -59,13 +59,17 @@ def rag_data_client_mock(): api_client_mock = mock.Mock(spec=VertexRagDataServiceClient) # get_rag_corpus - api_client_mock.get_rag_corpus.return_value = tc.TEST_GAPIC_RAG_CORPUS + api_client_mock.get_rag_corpus.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS + ) # delete_rag_corpus delete_rag_corpus_lro_mock = mock.Mock(ga_operation.Operation) delete_rag_corpus_lro_mock.result.return_value = DeleteRagCorpusRequest() api_client_mock.delete_rag_corpus.return_value = delete_rag_corpus_lro_mock # get_rag_file - api_client_mock.get_rag_file.return_value = tc.TEST_GAPIC_RAG_FILE + api_client_mock.get_rag_file.return_value = ( + test_rag_constants_preview.TEST_GAPIC_RAG_FILE + ) rag_data_client_mock.return_value = api_client_mock yield rag_data_client_mock diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py new file mode 100644 index 0000000000..7bf576c21d --- /dev/null +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -0,0 +1,531 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under 
the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from google.cloud import aiplatform + +from vertexai.preview.rag import ( + EmbeddingModelConfig, + Pinecone, + RagCorpus, + RagFile, + RagResource, + SharePointSource, + SharePointSources, + SlackChannelsSource, + SlackChannel, + JiraSource, + JiraQuery, + Weaviate, + VertexVectorSearch, + VertexFeatureStore, +) +from google.cloud.aiplatform_v1beta1 import ( + GoogleDriveSource, + RagFileChunkingConfig, + RagFileParsingConfig, + ImportRagFilesConfig, + ImportRagFilesRequest, + ImportRagFilesResponse, + JiraSource as GapicJiraSource, + RagCorpus as GapicRagCorpus, + RagFile as GapicRagFile, + SharePointSources as GapicSharePointSources, + SlackSource as GapicSlackSource, + RagContexts, + RetrieveContextsResponse, + RagVectorDbConfig, +) +from google.cloud.aiplatform_v1beta1.types import api_auth +from google.protobuf import timestamp_pb2 + + +TEST_PROJECT = "test-project" +TEST_PROJECT_NUMBER = "12345678" +TEST_REGION = "us-central1" +TEST_CORPUS_DISPLAY_NAME = "my-corpus-1" +TEST_CORPUS_DISCRIPTION = "My first corpus." 
# Identifiers derived from the project/region constants defined earlier in
# this module.
TEST_RAG_CORPUS_ID = "generate-123"
TEST_API_ENDPOINT = "us-central1-" + aiplatform.constants.base.API_BASE_PATH
TEST_RAG_CORPUS_RESOURCE_NAME = f"projects/{TEST_PROJECT_NUMBER}/locations/{TEST_REGION}/ragCorpora/{TEST_RAG_CORPUS_ID}"

# RagCorpus
# Each vector-DB flavour below comes as a pair built from the same literals:
# an SDK-level config object (Weaviate / Pinecone / ...) and, further down,
# a GAPIC RagCorpus message carrying the equivalent rag_vector_db_config.
TEST_WEAVIATE_HTTP_ENDPOINT = "test.weaviate.com"
TEST_WEAVIATE_COLLECTION_NAME = "test-collection"
TEST_WEAVIATE_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)
TEST_WEAVIATE_CONFIG = Weaviate(
    weaviate_http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
    collection_name=TEST_WEAVIATE_COLLECTION_NAME,
    api_key=TEST_WEAVIATE_API_KEY_SECRET_VERSION,
)
TEST_PINECONE_INDEX_NAME = "test-pinecone-index"
TEST_PINECONE_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)
TEST_PINECONE_CONFIG = Pinecone(
    index_name=TEST_PINECONE_INDEX_NAME,
    api_key=TEST_PINECONE_API_KEY_SECRET_VERSION,
)
TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT = "test-vector-search-index-endpoint"
TEST_VERTEX_VECTOR_SEARCH_INDEX = "test-vector-search-index"
TEST_VERTEX_VECTOR_SEARCH_CONFIG = VertexVectorSearch(
    index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT,
    index=TEST_VERTEX_VECTOR_SEARCH_INDEX,
)
TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME = "test-feature-view-resource-name"
# GAPIC corpus without a vector-DB config; its embedding endpoint is assigned
# on the nested message after construction (next statement).
TEST_GAPIC_RAG_CORPUS = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
)
TEST_GAPIC_RAG_CORPUS.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
    "projects/{}/locations/{}/publishers/google/models/textembedding-gecko".format(
        TEST_PROJECT, TEST_REGION
    )
)
# GAPIC corpus backed by Weaviate; the API key travels in api_auth rather
# than inside the weaviate sub-message.
TEST_GAPIC_RAG_CORPUS_WEAVIATE = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    rag_vector_db_config=RagVectorDbConfig(
        weaviate=RagVectorDbConfig.Weaviate(
            http_endpoint=TEST_WEAVIATE_HTTP_ENDPOINT,
            collection_name=TEST_WEAVIATE_COLLECTION_NAME,
        ),
        api_auth=api_auth.ApiAuth(
            api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                api_key_secret_version=TEST_WEAVIATE_API_KEY_SECRET_VERSION
            ),
        ),
    ),
)
# GAPIC corpus backed by a Vertex Feature Store feature view.
TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    rag_vector_db_config=RagVectorDbConfig(
        vertex_feature_store=RagVectorDbConfig.VertexFeatureStore(
            feature_view_resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME
        ),
    ),
)
# GAPIC corpus backed by Vertex Vector Search.
TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    rag_vector_db_config=RagVectorDbConfig(
        vertex_vector_search=RagVectorDbConfig.VertexVectorSearch(
            index_endpoint=TEST_VERTEX_VECTOR_SEARCH_INDEX_ENDPOINT,
            index=TEST_VERTEX_VECTOR_SEARCH_INDEX,
        ),
    ),
)
# GAPIC corpus backed by Pinecone; as with Weaviate, the key is in api_auth.
TEST_GAPIC_RAG_CORPUS_PINECONE = GapicRagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    rag_vector_db_config=RagVectorDbConfig(
        pinecone=RagVectorDbConfig.Pinecone(index_name=TEST_PINECONE_INDEX_NAME),
        api_auth=api_auth.ApiAuth(
            api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                api_key_secret_version=TEST_PINECONE_API_KEY_SECRET_VERSION
            ),
        ),
    ),
)
TEST_EMBEDDING_MODEL_CONFIG = EmbeddingModelConfig(
    publisher_model="publishers/google/models/textembedding-gecko",
)
TEST_VERTEX_FEATURE_STORE_CONFIG = VertexFeatureStore(
    resource_name=TEST_VERTEX_FEATURE_STORE_RESOURCE_NAME,
)
# SDK-level RagCorpus objects, one per GAPIC corpus above.
# NOTE(review): presumably these are the values the SDK is expected to
# produce from the matching GAPIC messages - confirm against the tests.
TEST_RAG_CORPUS = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG,
)
TEST_RAG_CORPUS_WEAVIATE = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    vector_db=TEST_WEAVIATE_CONFIG,
)
TEST_RAG_CORPUS_VERTEX_FEATURE_STORE = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    vector_db=TEST_VERTEX_FEATURE_STORE_CONFIG,
)
TEST_RAG_CORPUS_PINECONE = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    vector_db=TEST_PINECONE_CONFIG,
)
TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH = RagCorpus(
    name=TEST_RAG_CORPUS_RESOURCE_NAME,
    display_name=TEST_CORPUS_DISPLAY_NAME,
    description=TEST_CORPUS_DISCRIPTION,
    vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG,
)
TEST_PAGE_TOKEN = "test-page-token"

# RagFiles
TEST_PATH = "usr/home/my_file.txt"
TEST_GCS_PATH = "gs://usr/home/data_dir/"
TEST_FILE_DISPLAY_NAME = "my-file.txt"
TEST_FILE_DESCRIPTION = "my file."
TEST_HEADERS = {"X-Goog-Upload-Protocol": "multipart"}
# Upload URL assembled from the endpoint, project number, region and corpus id.
TEST_UPLOAD_REQUEST_URI = "https://{}/upload/v1beta1/projects/{}/locations/{}/ragCorpora/{}/ragFiles:upload".format(
    TEST_API_ENDPOINT, TEST_PROJECT_NUMBER, TEST_REGION, TEST_RAG_CORPUS_ID
)
TEST_RAG_FILE_ID = "generate-456"
TEST_RAG_FILE_RESOURCE_NAME = (
    TEST_RAG_CORPUS_RESOURCE_NAME + f"/ragFiles/{TEST_RAG_FILE_ID}"
)
TEST_UPLOAD_RAG_FILE_RESPONSE_CONTENT = ""
# JSON payloads for a successful upload and for an upload reporting an error.
TEST_RAG_FILE_JSON = {
    "ragFile": {
        "name": TEST_RAG_FILE_RESOURCE_NAME,
        "displayName": TEST_FILE_DISPLAY_NAME,
    }
}
TEST_RAG_FILE_JSON_ERROR = {"error": {"code": 13}}
TEST_CHUNK_SIZE = 512
TEST_CHUNK_OVERLAP = 100
# GCS
# The config starts empty; source URIs and the parsing flag are assigned on
# the nested messages afterwards.
TEST_IMPORT_FILES_CONFIG_GCS = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_GCS.gcs_source.uris = [TEST_GCS_PATH]
TEST_IMPORT_FILES_CONFIG_GCS.rag_file_parsing_config.use_advanced_pdf_parsing = False
TEST_IMPORT_REQUEST_GCS = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_GCS,
)
# Google Drive folders
TEST_DRIVE_FOLDER_ID = "123"
# Two URL spellings that embed the same folder id.
TEST_DRIVE_FOLDER = (
    f"https://drive.google.com/corp/drive/folders/{TEST_DRIVE_FOLDER_ID}"
)
TEST_DRIVE_FOLDER_2 = (
    f"https://drive.google.com/drive/folders/{TEST_DRIVE_FOLDER_ID}?resourcekey=0-eiOT3"
)
# Folder import config with advanced PDF parsing disabled.
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.google_drive_source.resource_ids = [
    GoogleDriveSource.ResourceId(
        resource_id=TEST_DRIVE_FOLDER_ID,
        resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
    )
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER.rag_file_parsing_config.use_advanced_pdf_parsing = (
    False
)
# Same folder source, but with advanced PDF parsing enabled.
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING = ImportRagFilesConfig()
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.google_drive_source.resource_ids = [
    GoogleDriveSource.ResourceId(
        resource_id=TEST_DRIVE_FOLDER_ID,
        resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER,
    )
]
TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING.rag_file_parsing_config.use_advanced_pdf_parsing = (
    True
)
TEST_IMPORT_REQUEST_DRIVE_FOLDER = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER,
)
TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FOLDER_PARSING,
)
# Google Drive files
TEST_DRIVE_FILE_ID = "456"
TEST_DRIVE_FILE = f"https://drive.google.com/file/d/{TEST_DRIVE_FILE_ID}"
# Single-file import config: chunking and parsing are passed to the
# constructor; the rate limit and the Drive source are assigned afterwards.
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    ),
    rag_file_parsing_config=RagFileParsingConfig(use_advanced_pdf_parsing=False),
)
TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.max_embedding_requests_per_min = 800

TEST_IMPORT_FILES_CONFIG_DRIVE_FILE.google_drive_source.resource_ids = [
    GoogleDriveSource.ResourceId(
        resource_id=TEST_DRIVE_FILE_ID,
        resource_type=GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE,
    )
]
TEST_IMPORT_REQUEST_DRIVE_FILE = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_DRIVE_FILE,
)

TEST_IMPORT_RESPONSE = ImportRagFilesResponse(imported_rag_files_count=2)

# GAPIC and SDK representations of the same file.
TEST_GAPIC_RAG_FILE = GapicRagFile(
    name=TEST_RAG_FILE_RESOURCE_NAME,
    display_name=TEST_FILE_DISPLAY_NAME,
    description=TEST_FILE_DESCRIPTION,
)
TEST_RAG_FILE = RagFile(
    name=TEST_RAG_FILE_RESOURCE_NAME,
    display_name=TEST_FILE_DISPLAY_NAME,
    description=TEST_FILE_DESCRIPTION,
)
# Slack sources
TEST_SLACK_CHANNEL_ID = "123"
TEST_SLACK_CHANNEL_ID_2 = "456"
# Both timestamps are filled with the current time when this module is
# imported, so their concrete values differ from run to run.
TEST_SLACK_START_TIME = timestamp_pb2.Timestamp()
TEST_SLACK_START_TIME.GetCurrentTime()
TEST_SLACK_END_TIME = timestamp_pb2.Timestamp()
TEST_SLACK_END_TIME.GetCurrentTime()
TEST_SLACK_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)
TEST_SLACK_API_KEY_SECRET_VERSION_2 = (
    "projects/test-project/secrets/test-secret/versions/2"
)
# SDK-level source: the first channel has a time window, the second has none.
TEST_SLACK_SOURCE = SlackChannelsSource(
    channels=[
        SlackChannel(
            channel_id=TEST_SLACK_CHANNEL_ID,
            api_key=TEST_SLACK_API_KEY_SECRET_VERSION,
            start_time=TEST_SLACK_START_TIME,
            end_time=TEST_SLACK_END_TIME,
        ),
        SlackChannel(
            channel_id=TEST_SLACK_CHANNEL_ID_2,
            api_key=TEST_SLACK_API_KEY_SECRET_VERSION_2,
        ),
    ],
)
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    )
)
# GAPIC form of the Slack source: one SlackChannels group per API key, each
# holding a single channel. The second channel passes None for both times.
TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE.slack_source.channels = [
    GapicSlackSource.SlackChannels(
        channels=[
            GapicSlackSource.SlackChannels.SlackChannel(
                channel_id=TEST_SLACK_CHANNEL_ID,
                start_time=TEST_SLACK_START_TIME,
                end_time=TEST_SLACK_END_TIME,
            ),
        ],
        api_key_config=api_auth.ApiAuth.ApiKeyConfig(
            api_key_secret_version=TEST_SLACK_API_KEY_SECRET_VERSION
        ),
    ),
    GapicSlackSource.SlackChannels(
        channels=[
            GapicSlackSource.SlackChannels.SlackChannel(
                channel_id=TEST_SLACK_CHANNEL_ID_2,
                start_time=None,
                end_time=None,
            ),
        ],
        api_key_config=api_auth.ApiAuth.ApiKeyConfig(
            api_key_secret_version=TEST_SLACK_API_KEY_SECRET_VERSION_2
        ),
    ),
]
TEST_IMPORT_REQUEST_SLACK_SOURCE = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_SLACK_SOURCE,
)
# Jira sources
TEST_JIRA_EMAIL = "test@test.com"
TEST_JIRA_PROJECT = "test-project"
TEST_JIRA_CUSTOM_QUERY = "test-custom-query"
TEST_JIRA_SERVER_URI = "test.atlassian.net"
TEST_JIRA_API_KEY_SECRET_VERSION = (
    "projects/test-project/secrets/test-secret/versions/1"
)
TEST_JIRA_SOURCE = JiraSource(
    queries=[
        JiraQuery(
            email=TEST_JIRA_EMAIL,
            jira_projects=[TEST_JIRA_PROJECT],
            custom_queries=[TEST_JIRA_CUSTOM_QUERY],
            api_key=TEST_JIRA_API_KEY_SECRET_VERSION,
            server_uri=TEST_JIRA_SERVER_URI,
        )
    ],
)
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    )
)
# GAPIC form of the Jira query; the SDK field `jira_projects` corresponds to
# `projects` here and the API key is wrapped in an ApiKeyConfig.
TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE.jira_source.jira_queries = [
    GapicJiraSource.JiraQueries(
        custom_queries=[TEST_JIRA_CUSTOM_QUERY],
        projects=[TEST_JIRA_PROJECT],
        email=TEST_JIRA_EMAIL,
        server_uri=TEST_JIRA_SERVER_URI,
        api_key_config=api_auth.ApiAuth.ApiKeyConfig(
            api_key_secret_version=TEST_JIRA_API_KEY_SECRET_VERSION
        ),
    )
]
TEST_IMPORT_REQUEST_JIRA_SOURCE = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_JIRA_SOURCE,
)

# SharePoint sources
# Baseline source: folder path and drive name both set.
TEST_SHARE_POINT_SOURCE = SharePointSources(
    share_point_sources=[
        SharePointSource(
            sharepoint_folder_path="test-sharepoint-folder-path",
            drive_name="test-drive-name",
            client_id="test-client-id",
            client_secret="test-client-secret",
            tenant_id="test-tenant-id",
            sharepoint_site_name="test-sharepoint-site-name",
        )
    ],
)
# GAPIC form of the baseline source; the client secret string is carried as
# the api_key_secret_version of an ApiKeyConfig.
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    ),
    share_point_sources=GapicSharePointSources(
        share_point_sources=[
            GapicSharePointSources.SharePointSource(
                sharepoint_folder_path="test-sharepoint-folder-path",
                drive_name="test-drive-name",
                client_id="test-client-id",
                client_secret=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version="test-client-secret"
                ),
                tenant_id="test-tenant-id",
                sharepoint_site_name="test-sharepoint-site-name",
            )
        ]
    ),
)

TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE,
)

# Variants of the baseline source with drive/folder identifiers doubled up
# or left out.
# NOTE(review): presumably these drive input-validation tests - confirm
# against the test module.
# Sets both drive_name and drive_id.
TEST_SHARE_POINT_SOURCE_2_DRIVES = SharePointSources(
    share_point_sources=[
        SharePointSource(
            sharepoint_folder_path="test-sharepoint-folder-path",
            drive_name="test-drive-name",
            drive_id="test-drive-id",
            client_id="test-client-id",
            client_secret="test-client-secret",
            tenant_id="test-tenant-id",
            sharepoint_site_name="test-sharepoint-site-name",
        )
    ],
)

# Sets neither drive_name nor drive_id.
TEST_SHARE_POINT_SOURCE_NO_DRIVES = SharePointSources(
    share_point_sources=[
        SharePointSource(
            sharepoint_folder_path="test-sharepoint-folder-path",
            client_id="test-client-id",
            client_secret="test-client-secret",
            tenant_id="test-tenant-id",
            sharepoint_site_name="test-sharepoint-site-name",
        )
    ],
)

# Sets both sharepoint_folder_path and sharepoint_folder_id.
TEST_SHARE_POINT_SOURCE_2_FOLDERS = SharePointSources(
    share_point_sources=[
        SharePointSource(
            sharepoint_folder_path="test-sharepoint-folder-path",
            sharepoint_folder_id="test-sharepoint-folder-id",
            drive_name="test-drive-name",
            client_id="test-client-id",
            client_secret="test-client-secret",
            tenant_id="test-tenant-id",
            sharepoint_site_name="test-sharepoint-site-name",
        )
    ],
)

# SDK-level SharePoint source that names a drive but no folder path or id.
TEST_SHARE_POINT_SOURCE_NO_FOLDERS = SharePointSources(
    share_point_sources=[
        SharePointSource(
            drive_name="test-drive-name",
            client_id="test-client-id",
            client_secret="test-client-secret",
            tenant_id="test-tenant-id",
            sharepoint_site_name="test-sharepoint-site-name",
        )
    ],
)

# GAPIC form of the folder-less source: same fields, with the client secret
# carried as the api_key_secret_version of an ApiKeyConfig.
TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesConfig(
    rag_file_chunking_config=RagFileChunkingConfig(
        chunk_size=TEST_CHUNK_SIZE,
        chunk_overlap=TEST_CHUNK_OVERLAP,
    ),
    share_point_sources=GapicSharePointSources(
        share_point_sources=[
            GapicSharePointSources.SharePointSource(
                drive_name="test-drive-name",
                client_id="test-client-id",
                client_secret=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version="test-client-secret"
                ),
                tenant_id="test-tenant-id",
                sharepoint_site_name="test-sharepoint-site-name",
            )
        ]
    ),
)

# The request wraps the *_NO_FOLDERS config defined directly above. It
# previously pointed at TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE (which
# includes a folder path), so the expected request did not describe the
# folder-less source it is named after.
TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE_NO_FOLDERS = ImportRagFilesRequest(
    parent=TEST_RAG_CORPUS_RESOURCE_NAME,
    import_rag_files_config=TEST_IMPORT_FILES_CONFIG_SHARE_POINT_SOURCE_NO_FOLDERS,
)

# Retrieval
TEST_QUERY_TEXT = "What happen to the fox and the dog?"
+TEST_CONTEXTS = RagContexts( + contexts=[ + RagContexts.Context( + source_uri="https://drive.google.com/file/d/123/view?usp=drivesdk", + text="The quick brown fox jumps over the lazy dog.", + ), + RagContexts.Context(text="The slow red fox jumps over the lazy dog."), + ] +) +TEST_RETRIEVAL_RESPONSE = RetrieveContextsResponse(contexts=TEST_CONTEXTS) +TEST_RAG_RESOURCE = RagResource( + rag_corpus=TEST_RAG_CORPUS_RESOURCE_NAME, + rag_file_ids=[TEST_RAG_FILE_ID], +) +TEST_RAG_RESOURCE_INVALID_NAME = RagResource( + rag_corpus="213lkj-1/23jkl/", + rag_file_ids=[TEST_RAG_FILE_ID], +) diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py new file mode 100644 index 0000000000..588c7eab96 --- /dev/null +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -0,0 +1,852 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
#
import importlib
from google.api_core import operation as ga_operation
from vertexai.preview import rag
from vertexai.preview.rag.utils._gapic_utils import (
    prepare_import_files_request,
    set_embedding_model_config,
)
from google.cloud.aiplatform_v1beta1 import (
    VertexRagDataServiceAsyncClient,
    VertexRagDataServiceClient,
    ListRagCorporaResponse,
    ListRagFilesResponse,
)
from google.cloud import aiplatform
import mock
from unittest.mock import patch
import pytest
import test_rag_constants_preview


def _finished_lro(result):
    """Return a mocked operation that reports done and resolves to ``result``."""
    lro = mock.Mock(ga_operation.Operation)
    lro.done.return_value = True
    lro.result.return_value = result
    return lro


@pytest.fixture
def create_rag_corpus_mock():
    """Patch ``create_rag_corpus`` to resolve to the default GAPIC corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "create_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS
        )
        yield patched_rpc


@pytest.fixture
def create_rag_corpus_mock_weaviate():
    """Patch ``create_rag_corpus`` to resolve to the Weaviate-backed corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "create_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_WEAVIATE
        )
        yield patched_rpc


@pytest.fixture
def create_rag_corpus_mock_vertex_feature_store():
    """Patch ``create_rag_corpus`` to resolve to the Feature Store corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "create_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE
        )
        yield patched_rpc


@pytest.fixture
def create_rag_corpus_mock_vertex_vector_search():
    """Patch ``create_rag_corpus`` to resolve to the Vector Search corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "create_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH
        )
        yield patched_rpc


@pytest.fixture
def create_rag_corpus_mock_pinecone():
    """Patch ``create_rag_corpus`` to resolve to the Pinecone-backed corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "create_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_PINECONE
        )
        yield patched_rpc


@pytest.fixture
def update_rag_corpus_mock_weaviate():
    """Patch ``update_rag_corpus`` to resolve to the Weaviate-backed corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "update_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_WEAVIATE
        )
        yield patched_rpc


@pytest.fixture
def update_rag_corpus_mock_vertex_feature_store():
    """Patch ``update_rag_corpus`` to resolve to the Feature Store corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "update_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_FEATURE_STORE
        )
        yield patched_rpc


@pytest.fixture
def update_rag_corpus_mock_vertex_vector_search():
    """Patch ``update_rag_corpus`` to resolve to the Vector Search corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "update_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_VERTEX_VECTOR_SEARCH
        )
        yield patched_rpc


@pytest.fixture
def update_rag_corpus_mock_pinecone():
    """Patch ``update_rag_corpus`` to resolve to the Pinecone-backed corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "update_rag_corpus"
    ) as patched_rpc:
        patched_rpc.return_value = _finished_lro(
            test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS_PINECONE
        )
        yield patched_rpc


@pytest.fixture
def list_rag_corpora_pager_mock():
    """Patch ``list_rag_corpora`` to return one page holding the default corpus."""
    with mock.patch.object(
        VertexRagDataServiceClient, "list_rag_corpora"
    ) as patched_rpc:
        only_page = ListRagCorporaResponse(
            rag_corpora=[
                test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS,
            ],
            next_page_token=test_rag_constants_preview.TEST_PAGE_TOKEN,
        )
        patched_rpc.return_value = [only_page]
        yield patched_rpc


class MockResponse:
    """Minimal stand-in for an HTTP response: a status code plus a JSON body."""

    def __init__(self, json_data, status_code):
        self.json_data = json_data
        self.status_code = status_code

    def json(self):
        """Return the payload supplied at construction time."""
        return self.json_data


# NOTE: `authorized_session_mock` is a fixture defined outside this part of
# the file; the upload fixtures below patch its `post` method.
@pytest.fixture
def upload_file_mock(authorized_session_mock):
    """Make the session's ``post`` return a 200 response with a RagFile body."""
    with patch.object(authorized_session_mock, "post") as mock_post:
        mock_post.return_value = MockResponse(
            test_rag_constants_preview.TEST_RAG_FILE_JSON, 200
        )
        yield mock_post


@pytest.fixture
def upload_file_not_found_mock(authorized_session_mock):
    """Make the session's ``post`` return a 404 response with no body."""
    with patch.object(authorized_session_mock, "post") as mock_post:
        mock_post.return_value = MockResponse(None, 404)
        yield mock_post


@pytest.fixture
def upload_file_error_mock(authorized_session_mock):
    """Make the session's ``post`` return a 200 response with an error body."""
    with patch.object(authorized_session_mock, "post") as mock_post:
        mock_post.return_value = MockResponse(
            test_rag_constants_preview.TEST_RAG_FILE_JSON_ERROR, 200
        )
        yield mock_post


@pytest.fixture
def open_file_mock():
    """Patch ``builtins.open`` for the duration of the test."""
    with mock.patch("builtins.open") as open_file_mock:
        yield open_file_mock


@pytest.fixture
def import_files_mock():
    """Patch the sync client's ``import_rag_files`` with a mocked operation."""
    with mock.patch.object(
        VertexRagDataServiceClient, "import_rag_files"
    ) as import_files_mock:
        import_files_lro_mock = mock.Mock(ga_operation.Operation)
        import_files_lro_mock.result.return_value = (
            test_rag_constants_preview.TEST_IMPORT_RESPONSE
        )
        import_files_mock.return_value = import_files_lro_mock
        yield import_files_mock


@pytest.fixture
def import_files_async_mock():
    """Patch the async client's ``import_rag_files`` with a mocked operation."""
    with mock.patch.object(
        VertexRagDataServiceAsyncClient, "import_rag_files"
    ) as import_files_async_mock:
        import_files_lro_mock = mock.Mock(ga_operation.Operation)
        import_files_lro_mock.result.return_value = (
            test_rag_constants_preview.TEST_IMPORT_RESPONSE
        )
        import_files_async_mock.return_value = import_files_lro_mock
        yield import_files_async_mock


@pytest.fixture
def list_rag_files_pager_mock():
    """Patch ``list_rag_files`` to return one page holding the test file."""
    with mock.patch.object(
        VertexRagDataServiceClient, "list_rag_files"
    ) as list_rag_files_pager_mock:
        list_rag_files_pager_mock.return_value = [
            ListRagFilesResponse(
                rag_files=[
                    test_rag_constants_preview.TEST_GAPIC_RAG_FILE,
                ],
                next_page_token=test_rag_constants_preview.TEST_PAGE_TOKEN,
            ),
        ]
        yield list_rag_files_pager_mock


def rag_corpus_eq(returned_corpus, expected_corpus):
    """Assert that two corpora agree on name, display name and vector DB.

    Args:
        returned_corpus: Corpus produced by the code under test.
        expected_corpus: Corpus the test expects.

    Raises:
        AssertionError: If any compared attribute differs.
    """
    assert returned_corpus.name == expected_corpus.name
    assert returned_corpus.display_name == expected_corpus.display_name
    # Use `==` rather than calling `__eq__` directly: `__eq__` may return
    # NotImplemented for mismatched types, and NotImplemented is truthy, so
    # the direct call could let the assertion pass for unequal values.
    assert returned_corpus.vector_db == expected_corpus.vector_db


def rag_file_eq(returned_file, expected_file):
    """Assert that two files agree on name and display name.

    Raises:
        AssertionError: If either attribute differs.
    """
    assert returned_file.name == expected_file.name
    assert returned_file.display_name == expected_file.display_name


def import_files_request_eq(returned_request, expected_request):
    """Assert that two import requests agree on parent, sources and parsing.

    Compares the parent, the GCS URIs, the Drive resource ids, the Slack
    channels, the Jira queries and the parsing config. Other fields of the
    import config are not compared here.

    Raises:
        AssertionError: If any compared field differs.
    """
    assert returned_request.parent == expected_request.parent
    assert (
        returned_request.import_rag_files_config.gcs_source.uris
        == expected_request.import_rag_files_config.gcs_source.uris
    )
    assert (
        returned_request.import_rag_files_config.google_drive_source.resource_ids
        == expected_request.import_rag_files_config.google_drive_source.resource_ids
    )
    assert (
        returned_request.import_rag_files_config.slack_source.channels
        == expected_request.import_rag_files_config.slack_source.channels
    )
    assert (
        returned_request.import_rag_files_config.jira_source.jira_queries
        == expected_request.import_rag_files_config.jira_source.jira_queries
    )
    assert (
        returned_request.import_rag_files_config.rag_file_parsing_config
        == expected_request.import_rag_files_config.rag_file_parsing_config
    )


@pytest.mark.usefixtures("google_auth_mock")
class TestRagDataManagement:
    """Tests for corpus and file management through ``vertexai.preview.rag``."""

    def setup_method(self):
        """Reload aiplatform and initialise it with the test project/region."""
        # Reloading gives every test a fresh initializer before init() runs.
        importlib.reload(aiplatform.initializer)
        importlib.reload(aiplatform)
        aiplatform.init(
            project=test_rag_constants_preview.TEST_PROJECT,
            location=test_rag_constants_preview.TEST_REGION,
        )

    def teardown_method(self):
        """Shut down the initializer's global pool, waiting for completion."""
        aiplatform.initializer.global_pool.shutdown(wait=True)

@pytest.mark.usefixtures("create_rag_corpus_mock") + def test_create_corpus_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + + @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate") + def test_create_corpus_weaviate_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) + + @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_feature_store") + def test_create_corpus_vertex_feature_store_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_VERTEX_FEATURE_STORE_CONFIG, + ) + + rag_corpus_eq( + rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE + ) + + @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search") + def test_create_corpus_vertex_vector_search_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, + ) + + rag_corpus_eq( + rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH + ) + + @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone") + def test_create_corpus_pinecone_success(self): + rag_corpus = rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE) + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def 
test_create_corpus_failure(self): + with pytest.raises(RuntimeError) as e: + rag.create_corpus( + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME + ) + e.match("Failed in RagCorpus creation due to") + + @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") + def test_update_corpus_weaviate_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) + + @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") + def test_update_corpus_weaviate_no_display_name_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) + + @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") + def test_update_corpus_weaviate_with_description_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + description=test_rag_constants_preview.TEST_CORPUS_DISCRIPTION, + vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) + + @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") + def test_update_corpus_weaviate_with_description_and_display_name_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + description=test_rag_constants_preview.TEST_CORPUS_DISCRIPTION, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, + ) + + rag_corpus_eq(rag_corpus, 
test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) + + @pytest.mark.usefixtures("update_rag_corpus_mock_vertex_feature_store") + def test_update_corpus_vertex_feature_store_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_VERTEX_FEATURE_STORE_CONFIG, + ) + + rag_corpus_eq( + rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE + ) + + @pytest.mark.usefixtures("update_rag_corpus_mock_vertex_vector_search") + def test_update_corpus_vertex_vector_search_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, + ) + rag_corpus_eq( + rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH + ) + + @pytest.mark.usefixtures("update_rag_corpus_mock_pinecone") + def test_update_corpus_pinecone_success(self): + rag_corpus = rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, + ) + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE) + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_update_corpus_failure(self): + with pytest.raises(RuntimeError) as e: + rag.update_corpus( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, + ) + e.match("Failed in RagCorpus update due to") + + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_corpus_success(self): + rag_corpus = rag.get_corpus( + test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME + 
) + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_corpus_id_success(self): + rag_corpus = rag.get_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_ID) + rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_get_corpus_failure(self): + with pytest.raises(RuntimeError) as e: + rag.get_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME) + e.match("Failed in getting the RagCorpus due to") + + def test_list_corpora_pager_success(self, list_rag_corpora_pager_mock): + aiplatform.init( + project=test_rag_constants_preview.TEST_PROJECT, + location=test_rag_constants_preview.TEST_REGION, + ) + pager = rag.list_corpora(page_size=1) + + list_rag_corpora_pager_mock.assert_called_once() + assert pager[0].next_page_token == test_rag_constants_preview.TEST_PAGE_TOKEN + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_list_corpora_failure(self): + with pytest.raises(RuntimeError) as e: + rag.list_corpora() + e.match("Failed in listing the RagCorpora due to") + + def test_delete_corpus_success(self, rag_data_client_mock): + rag.delete_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME) + assert rag_data_client_mock.call_count == 2 + + def test_delete_corpus_id_success(self, rag_data_client_mock): + rag.delete_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_ID) + assert rag_data_client_mock.call_count == 2 + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_delete_corpus_failure(self): + with pytest.raises(RuntimeError) as e: + rag.delete_corpus(test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME) + e.match("Failed in RagCorpus deletion due to") + + @pytest.mark.usefixtures("open_file_mock") + def test_upload_file_success( + self, + upload_file_mock, + ): + aiplatform.init( + project=test_rag_constants_preview.TEST_PROJECT, + 
location=test_rag_constants_preview.TEST_REGION, + ) + rag_file = rag.upload_file( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + path=test_rag_constants_preview.TEST_PATH, + display_name=test_rag_constants_preview.TEST_FILE_DISPLAY_NAME, + ) + + upload_file_mock.assert_called_once() + _, mock_kwargs = upload_file_mock.call_args + assert mock_kwargs["url"] == test_rag_constants_preview.TEST_UPLOAD_REQUEST_URI + assert mock_kwargs["headers"] == test_rag_constants_preview.TEST_HEADERS + + rag_file_eq(rag_file, test_rag_constants_preview.TEST_RAG_FILE) + + @pytest.mark.usefixtures("rag_data_client_mock_exception", "open_file_mock") + def test_upload_file_failure(self): + with pytest.raises(RuntimeError) as e: + rag.upload_file( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + path=test_rag_constants_preview.TEST_PATH, + display_name=test_rag_constants_preview.TEST_FILE_DISPLAY_NAME, + ) + e.match("Failed in uploading the RagFile due to") + + @pytest.mark.usefixtures("open_file_mock", "upload_file_not_found_mock") + def test_upload_file_not_found(self): + with pytest.raises(ValueError) as e: + rag.upload_file( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + path=test_rag_constants_preview.TEST_PATH, + display_name=test_rag_constants_preview.TEST_FILE_DISPLAY_NAME, + ) + e.match("is not found") + + @pytest.mark.usefixtures("open_file_mock", "upload_file_error_mock") + def test_upload_file_error(self): + with pytest.raises(RuntimeError) as e: + rag.upload_file( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + path=test_rag_constants_preview.TEST_PATH, + display_name=test_rag_constants_preview.TEST_FILE_DISPLAY_NAME, + ) + e.match("Failed in indexing the RagFile due to") + + def test_import_files(self, import_files_mock): + response = rag.import_files( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + 
paths=[test_rag_constants_preview.TEST_GCS_PATH], + ) + import_files_mock.assert_called_once() + + assert response.imported_rag_files_count == 2 + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_import_files_failure(self): + with pytest.raises(RuntimeError) as e: + rag.import_files( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=[test_rag_constants_preview.TEST_GCS_PATH], + ) + e.match("Failed in importing the RagFiles due to") + + @pytest.mark.asyncio + async def test_import_files_async(self, import_files_async_mock): + response = await rag.import_files_async( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=[test_rag_constants_preview.TEST_GCS_PATH], + ) + import_files_async_mock.assert_called_once() + + assert response.result().imported_rag_files_count == 2 + + @pytest.mark.asyncio + @pytest.mark.usefixtures("rag_data_async_client_mock_exception") + async def test_import_files_async_failure(self): + with pytest.raises(RuntimeError) as e: + await rag.import_files_async( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=[test_rag_constants_preview.TEST_GCS_PATH], + ) + e.match("Failed in importing the RagFiles due to") + + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_file_success(self): + rag_file = rag.get_file(test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME) + rag_file_eq(rag_file, test_rag_constants_preview.TEST_RAG_FILE) + + @pytest.mark.usefixtures("rag_data_client_mock") + def test_get_file_id_success(self): + rag_file = rag.get_file( + name=test_rag_constants_preview.TEST_RAG_FILE_ID, + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ) + rag_file_eq(rag_file, test_rag_constants_preview.TEST_RAG_FILE) + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_get_file_failure(self): + with pytest.raises(RuntimeError) as e: + 
rag.get_file(test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME) + e.match("Failed in getting the RagFile due to") + + def test_list_files_pager_success(self, list_rag_files_pager_mock): + files = rag.list_files( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + page_size=1, + ) + list_rag_files_pager_mock.assert_called_once() + assert files[0].next_page_token == test_rag_constants_preview.TEST_PAGE_TOKEN + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_list_files_failure(self): + with pytest.raises(RuntimeError) as e: + rag.list_files( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME + ) + e.match("Failed in listing the RagFiles due to") + + def test_delete_file_success(self, rag_data_client_mock): + rag.delete_file(test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME) + assert rag_data_client_mock.call_count == 2 + + def test_delete_file_id_success(self, rag_data_client_mock): + rag.delete_file( + name=test_rag_constants_preview.TEST_RAG_FILE_ID, + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ) + # Passing corpus_name will result in 3 calls to rag_data_client + assert rag_data_client_mock.call_count == 3 + + @pytest.mark.usefixtures("rag_data_client_mock_exception") + def test_delete_file_failure(self): + with pytest.raises(RuntimeError) as e: + rag.delete_file(test_rag_constants_preview.TEST_RAG_FILE_RESOURCE_NAME) + e.match("Failed in RagFile deletion due to") + + def test_prepare_import_files_request_list_gcs_uris(self): + paths = [test_rag_constants_preview.TEST_GCS_PATH] + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_GCS + ) + + @pytest.mark.parametrize( + "path", + [ + 
test_rag_constants_preview.TEST_DRIVE_FOLDER, + test_rag_constants_preview.TEST_DRIVE_FOLDER_2, + ], + ) + def test_prepare_import_files_request_drive_folders(self, path): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=[path], + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER + ) + + @pytest.mark.parametrize( + "path", + [ + test_rag_constants_preview.TEST_DRIVE_FOLDER, + test_rag_constants_preview.TEST_DRIVE_FOLDER_2, + ], + ) + def test_prepare_import_files_request_drive_folders_with_pdf_parsing(self, path): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=[path], + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + use_advanced_pdf_parsing=True, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FOLDER_PARSING + ) + + def test_prepare_import_files_request_drive_files(self): + paths = [test_rag_constants_preview.TEST_DRIVE_FILE] + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + max_embedding_requests_per_min=800, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_DRIVE_FILE + ) + + def test_prepare_import_files_request_invalid_drive_path(self): + with pytest.raises(ValueError) as e: + paths = ["https://drive.google.com/bslalsdfk/whichever_file/456"] + prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + 
chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + e.match("is not a valid Google Drive url") + + def test_prepare_import_files_request_invalid_path(self): + with pytest.raises(ValueError) as e: + paths = ["https://whereever.com/whichever_file/456"] + prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + paths=paths, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + e.match("path must be a Google Cloud Storage uri or a Google Drive url") + + def test_prepare_import_files_request_slack_source(self): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SLACK_SOURCE, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_SLACK_SOURCE + ) + + def test_prepare_import_files_request_jira_source(self): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_JIRA_SOURCE, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, test_rag_constants_preview.TEST_IMPORT_REQUEST_JIRA_SOURCE + ) + + def test_prepare_import_files_request_sharepoint_source(self): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, 
test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE + ) + + def test_prepare_import_files_request_sharepoint_source_2_drives(self): + with pytest.raises(ValueError) as e: + prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_DRIVES, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + e.match("drive_name and drive_id cannot both be set.") + + def test_prepare_import_files_request_sharepoint_source_2_folders(self): + with pytest.raises(ValueError) as e: + prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_2_FOLDERS, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + e.match("sharepoint_folder_path and sharepoint_folder_id cannot both be set.") + + def test_prepare_import_files_request_sharepoint_source_no_drives(self): + with pytest.raises(ValueError) as e: + prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_DRIVES, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + e.match("Either drive_name and drive_id must be set.") + + def test_prepare_import_files_request_sharepoint_source_no_folders(self): + request = prepare_import_files_request( + corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, + source=test_rag_constants_preview.TEST_SHARE_POINT_SOURCE_NO_FOLDERS, + chunk_size=test_rag_constants_preview.TEST_CHUNK_SIZE, + chunk_overlap=test_rag_constants_preview.TEST_CHUNK_OVERLAP, + ) + import_files_request_eq( + request, + 
test_rag_constants_preview.TEST_IMPORT_REQUEST_SHARE_POINT_SOURCE_NO_FOLDERS, + ) + + def test_set_embedding_model_config_set_both_error(self): + embedding_model_config = rag.EmbeddingModelConfig( + publisher_model="whatever", + endpoint="whatever", + ) + with pytest.raises(ValueError) as e: + set_embedding_model_config( + embedding_model_config, + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS, + ) + e.match("publisher_model and endpoint cannot be set at the same time") + + def test_set_embedding_model_config_not_set_error(self): + embedding_model_config = rag.EmbeddingModelConfig() + with pytest.raises(ValueError) as e: + set_embedding_model_config( + embedding_model_config, + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS, + ) + e.match("At least one of publisher_model and endpoint must be set") + + def test_set_embedding_model_config_wrong_publisher_model_format_error(self): + embedding_model_config = rag.EmbeddingModelConfig(publisher_model="whatever") + with pytest.raises(ValueError) as e: + set_embedding_model_config( + embedding_model_config, + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS, + ) + e.match("publisher_model must be of the format ") + + def test_set_embedding_model_config_wrong_endpoint_format_error(self): + embedding_model_config = rag.EmbeddingModelConfig(endpoint="whatever") + with pytest.raises(ValueError) as e: + set_embedding_model_config( + embedding_model_config, + test_rag_constants_preview.TEST_GAPIC_RAG_CORPUS, + ) + e.match("endpoint must be of the format ") diff --git a/tests/unit/vertex_rag/test_rag_retrieval_preview.py b/tests/unit/vertex_rag/test_rag_retrieval_preview.py new file mode 100644 index 0000000000..21bdc7b4cd --- /dev/null +++ b/tests/unit/vertex_rag/test_rag_retrieval_preview.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +from google.cloud import aiplatform +from vertexai.preview import rag +from google.cloud.aiplatform_v1beta1 import ( + VertexRagServiceClient, +) +import mock +import pytest +import test_rag_constants_preview + + +@pytest.fixture +def retrieve_contexts_mock(): + with mock.patch.object( + VertexRagServiceClient, + "retrieve_contexts", + ) as retrieve_contexts_mock: + retrieve_contexts_mock.return_value = ( + test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + yield retrieve_contexts_mock + + +@pytest.fixture +def rag_client_mock_exception(): + with mock.patch.object( + rag.utils._gapic_utils, "create_rag_service_client" + ) as rag_client_mock_exception: + api_client_mock = mock.Mock(spec=VertexRagServiceClient) + # retrieve_contexts + api_client_mock.retrieve_contexts.side_effect = Exception + rag_client_mock_exception.return_value = api_client_mock + yield rag_client_mock_exception + + +def retrieve_contexts_eq(response, expected_response): + assert len(response.contexts.contexts) == len(expected_response.contexts.contexts) + assert ( + response.contexts.contexts[0].text + == expected_response.contexts.contexts[0].text + ) + assert ( + response.contexts.contexts[0].source_uri + == expected_response.contexts.contexts[0].source_uri + ) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestRagRetrieval: + def setup_method(self): + importlib.reload(aiplatform.initializer) + importlib.reload(aiplatform) + aiplatform.init() + + def teardown_method(self): + aiplatform.initializer.global_pool.shutdown(wait=True) + + 
@pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_resources_success(self): + response = rag.retrieval_query( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + vector_search_alpha=0.5, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("retrieve_contexts_mock") + def test_retrieval_query_rag_corpora_success(self): + response = rag.retrieval_query( + rag_corpora=[test_rag_constants_preview.TEST_RAG_CORPUS_ID], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + retrieve_contexts_eq( + response, test_rag_constants_preview.TEST_RETRIEVAL_RESPONSE + ) + + @pytest.mark.usefixtures("rag_client_mock_exception") + def test_retrieval_query_failure(self): + with pytest.raises(RuntimeError) as e: + rag.retrieval_query( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Failed in retrieving contexts due to") + + def test_retrieval_query_invalid_name(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Invalid RagCorpus name") + + def test_retrieval_query_multiple_rag_corpora(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_corpora=[ + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Currently only support 1 RagCorpus") + + def 
test_retrieval_query_multiple_rag_resources(self): + with pytest.raises(ValueError) as e: + rag.retrieval_query( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], + text=test_rag_constants_preview.TEST_QUERY_TEXT, + similarity_top_k=2, + vector_distance_threshold=0.5, + ) + e.match("Currently only support 1 RagResource") diff --git a/tests/unit/vertex_rag/test_rag_store_preview.py b/tests/unit/vertex_rag/test_rag_store_preview.py new file mode 100644 index 0000000000..6d733b7baf --- /dev/null +++ b/tests/unit/vertex_rag/test_rag_store_preview.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from vertexai.preview import rag +from vertexai.preview.generative_models import Tool +import pytest +import test_rag_constants_preview + + +@pytest.mark.usefixtures("google_auth_mock") +class TestRagStoreValidations: + def test_retrieval_tool_invalid_name(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Invalid RagCorpus name") + + def test_retrieval_tool_multiple_rag_corpora(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_corpora=[ + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + test_rag_constants_preview.TEST_RAG_CORPUS_ID, + ], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Currently only support 1 RagCorpus") + + def test_retrieval_tool_multiple_rag_resources(self): + with pytest.raises(ValueError) as e: + Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], + similarity_top_k=3, + vector_distance_threshold=0.4, + ), + ) + ) + e.match("Currently only support 1 RagResource")