Skip to content

Commit

Permalink
Add pytest for FeatureViewSyncSensor
Browse files Browse the repository at this point in the history
  • Loading branch information
CYarros10 committed Dec 12, 2024
1 parent 657c6ff commit 740d823
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,13 @@ def setup_method(self):
@mock.patch(FEATURE_STORE_STRING.format("FeatureOnlineStoreAdminServiceClient"), autospec=True)
@mock.patch(BASE_STRING.format("GoogleBaseHook.get_credentials"))
def test_get_feature_online_store_admin_service_client(self, mock_get_credentials, mock_client):
# Test with location
self.hook.get_feature_online_store_admin_service_client(location=TEST_LOCATION)
mock_client.assert_called_once_with(
credentials=mock_get_credentials.return_value, client_info=mock.ANY, client_options=mock.ANY
)
client_options = mock_client.call_args[1]["client_options"]
assert client_options.api_endpoint == f"{TEST_LOCATION}-aiplatform.googleapis.com:443"

# Test without location (global)
mock_client.reset_mock()
self.hook.get_feature_online_store_admin_service_client()
mock_client.assert_called_once_with(
Expand Down Expand Up @@ -112,5 +110,3 @@ def test_sync_feature_view(self, mock_client_getter):

mock_client.sync_feature_view.assert_called_once_with(feature_view=TEST_FEATURE_VIEW)
assert result == "test-sync-operation-name"

# Removing test_list_feature_view_syncs as the method doesn't exist in the hook
157 changes: 157 additions & 0 deletions providers/tests/google/cloud/sensors/test_vertex_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock
from unittest.mock import Mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.sensors.vertex_ai.feature_store import FeatureViewSyncSensor

TASK_ID = "test-task"
GCP_CONN_ID = "test-conn"
GCP_LOCATION = "us-central1"
FEATURE_VIEW_SYNC_NAME = "projects/123/locations/us-central1/featureViews/test-view/operations/sync-123"
TIMEOUT = 120


class TestFeatureViewSyncSensor:
def create_sync_response(self, end_time=None, row_synced=None, total_slot=None):
response = {}
if end_time is not None:
response["end_time"] = end_time
if row_synced is not None and total_slot is not None:
response["sync_summary"] = {"row_synced": str(row_synced), "total_slot": str(total_slot)}
return response

@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook")
def test_sync_completed(self, mock_hook):
mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response(
end_time=1234567890, row_synced=1000, total_slot=5
)

sensor = FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
location=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})

mock_hook.return_value.get_feature_view_sync.assert_called_once_with(
location=GCP_LOCATION,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
)
assert ret

@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook")
def test_sync_running(self, mock_hook):
mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response(
end_time=0, row_synced=0, total_slot=5
)

sensor = FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
location=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})

mock_hook.return_value.get_feature_view_sync.assert_called_once_with(
location=GCP_LOCATION,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
)
assert not ret

@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook")
def test_sync_error_with_retry(self, mock_hook):
mock_hook.return_value.get_feature_view_sync.side_effect = Exception("API Error")

sensor = FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
location=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)
ret = sensor.poke(context={})

mock_hook.return_value.get_feature_view_sync.assert_called_once_with(
location=GCP_LOCATION,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
)
assert not ret

@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook")
def test_timeout_during_running(self, mock_hook):
mock_hook.return_value.get_feature_view_sync.return_value = self.create_sync_response(
end_time=0, row_synced=0, total_slot=5
)

sensor = FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
location=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
wait_timeout=300,
)

sensor._duration = Mock()
sensor._duration.return_value = 301

with pytest.raises(
AirflowException,
match=f"Timeout: Feature View sync {FEATURE_VIEW_SYNC_NAME} not completed after 300s",
):
sensor.poke(context={})

@mock.patch("airflow.providers.google.cloud.sensors.vertex_ai.feature_store.FeatureStoreHook")
def test_timeout_during_error(self, mock_hook):
mock_hook.return_value.get_feature_view_sync.side_effect = Exception("API Error")

sensor = FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
location=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
wait_timeout=300,
)

sensor._duration = Mock()
sensor._duration.return_value = 301

with pytest.raises(
AirflowException,
match=f"Timeout: Feature View sync {FEATURE_VIEW_SYNC_NAME} not completed after 300s",
):
sensor.poke(context={})

def test_missing_location(self):
with pytest.raises(TypeError, match="missing keyword argument 'location'"):
FeatureViewSyncSensor(
task_id=TASK_ID,
feature_view_sync_name=FEATURE_VIEW_SYNC_NAME,
gcp_conn_id=GCP_CONN_ID,
timeout=TIMEOUT,
)

0 comments on commit 740d823

Please sign in to comment.