diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index 7d7a82c4d743e..ce1f1593e606a 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -744,7 +744,7 @@ sensors: - integration-name: Google Cloud Pub/Sub python-modules: - airflow.providers.google.cloud.sensors.pubsub - - integration-name: Google Cloud Vertex AI + - integration-name: Google Vertex AI python-modules: - airflow.providers.google.cloud.sensors.vertex_ai.feature_store - integration-name: Google Cloud Workflows diff --git a/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py b/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py index ef20fa947aacd..4aae8ee773271 100644 --- a/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py +++ b/providers/tests/google/cloud/hooks/vertex_ai/test_feature_store.py @@ -68,22 +68,39 @@ def test_get_feature_online_store_admin_service_client(self, mock_get_credential def test_get_feature_view_sync(self, mock_client_getter): mock_client = mock.MagicMock() mock_client_getter.return_value = mock_client + + # Create a mock response with the expected structure mock_response = mock.MagicMock() + mock_response.run_time.start_time.seconds = 1 + mock_response.run_time.end_time.seconds = 1 + mock_response.sync_summary.row_synced = 1 + mock_response.sync_summary.total_slot = 1 + mock_client.get_feature_view_sync.return_value = mock_response + expected_result = { + "name": TEST_FEATURE_VIEW_SYNC_NAME, + "start_time": 1, + "end_time": 1, + "sync_summary": {"row_synced": 1, "total_slot": 1}, + } + result = self.hook.get_feature_view_sync( location=TEST_LOCATION, feature_view_sync_name=TEST_FEATURE_VIEW_SYNC_NAME, ) mock_client.get_feature_view_sync.assert_called_once_with(name=TEST_FEATURE_VIEW_SYNC_NAME) - assert result == mock_response + assert result == expected_result @mock.patch(FEATURE_STORE_STRING.format("FeatureStoreHook.get_feature_online_store_admin_service_client")) def test_sync_feature_view(self, mock_client_getter): mock_client = mock.MagicMock() mock_client_getter.return_value = mock_client + + # Create a mock response with the expected structure mock_response = mock.MagicMock() + mock_response.feature_view_sync = "test-sync-operation-name" mock_client.sync_feature_view.return_value = mock_response result = self.hook.sync_feature_view( @@ -94,21 +111,6 @@ def test_sync_feature_view(self, mock_client_getter): ) mock_client.sync_feature_view.assert_called_once_with(feature_view=TEST_FEATURE_VIEW) - assert result == mock_response - - @mock.patch(FEATURE_STORE_STRING.format("FeatureStoreHook.get_feature_online_store_admin_service_client")) - def test_list_feature_view_syncs(self, mock_client_getter): - mock_client = mock.MagicMock() - mock_client_getter.return_value = mock_client - mock_response = mock.MagicMock() - mock_client.list_feature_views.return_value = mock_response - - result = self.hook.list_feature_view_syncs( - project_id=TEST_PROJECT_ID, - location=TEST_LOCATION, - feature_online_store_id=TEST_FEATURE_ONLINE_STORE_ID, - feature_view_id=TEST_FEATURE_VIEW_ID, - ) + assert result == "test-sync-operation-name" - mock_client.list_feature_views.assert_called_once_with(parent=TEST_FEATURE_VIEW) - assert result == mock_response + # Removing test_list_feature_view_syncs as the method doesn't exist in the hook diff --git a/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py b/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py index fd0e9e9101104..5340b69d72009 100644 --- a/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py +++ b/providers/tests/google/cloud/operators/vertex_ai/test_feature_store.py @@ -38,8 +38,13 @@ class TestSyncFeatureViewOperator: @mock.patch(VERTEX_AI_PATH.format("feature_store.FeatureStoreHook")) - def test_execute(self, mock_hook): - mock_hook.return_value.sync_feature_view.return_value.feature_view_sync = FEATURE_VIEW_SYNC_NAME + def test_execute(self, mock_hook_class): + # Create the mock hook and set up its return value + mock_hook = mock.MagicMock() + mock_hook_class.return_value = mock_hook + + # Set up the return value for sync_feature_view to match the hook implementation + mock_hook.sync_feature_view.return_value = FEATURE_VIEW_SYNC_NAME op = SyncFeatureViewOperator( task_id=TASK_ID, @@ -53,29 +58,30 @@ def test_execute(self, mock_hook): response = op.execute(context={"ti": mock.MagicMock()}) - mock_hook.assert_called_once_with( + # Verify hook initialization + mock_hook_class.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.sync_feature_view.assert_called_once_with( + + # Verify hook method call + mock_hook.sync_feature_view.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, feature_online_store_id=FEATURE_ONLINE_STORE_ID, feature_view_id=FEATURE_VIEW_ID, ) + + # Verify response matches expected value assert response == FEATURE_VIEW_SYNC_NAME class TestGetFeatureViewSyncOperator: @mock.patch(VERTEX_AI_PATH.format("feature_store.FeatureStoreHook")) - def test_execute(self, mock_hook): - mock_response = mock.MagicMock() - mock_response.run_time.start_time.seconds = 1000 - mock_response.run_time.end_time.seconds = 2000 - mock_response.sync_summary.row_synced = 500 - mock_response.sync_summary.total_slot = 4 - - mock_hook.return_value.get_feature_view_sync.return_value = mock_response + def test_execute(self, mock_hook_class): + # Create the mock hook and set up expected response + mock_hook = mock.MagicMock() + mock_hook_class.return_value = mock_hook expected_response = { "name": FEATURE_VIEW_SYNC_NAME, @@ -84,6 +90,9 @@ def test_execute(self, mock_hook): "sync_summary": {"row_synced": 500, "total_slot": 4}, } + # Set up the return value for get_feature_view_sync to match the hook implementation + mock_hook.get_feature_view_sync.return_value = expected_response + op = GetFeatureViewSyncOperator( task_id=TASK_ID, location=GCP_LOCATION, @@ -94,12 +103,17 @@ def test_execute(self, mock_hook): response = op.execute(context={"ti": mock.MagicMock()}) - mock_hook.assert_called_once_with( + # Verify hook initialization + mock_hook_class.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, ) - mock_hook.return_value.get_feature_view_sync.assert_called_once_with( + + # Verify hook method call + mock_hook.get_feature_view_sync.assert_called_once_with( location=GCP_LOCATION, feature_view_sync_name=FEATURE_VIEW_SYNC_NAME, ) + + # Verify response matches expected structure assert response == expected_response