Skip to content

Commit

Permalink
SagemakerProcessingOperator stopped honoring existing_jobs_found (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi authored and Adityamalik123 committed Nov 12, 2022
1 parent 83a32ad commit 194dec5
Show file tree
Hide file tree
Showing 5 changed files with 123 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/actions/configure-aws-credentials
45 changes: 40 additions & 5 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tarfile
import tempfile
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast
Expand Down Expand Up @@ -939,13 +940,47 @@ def _list_request(
next_token = response["NextToken"]

def find_processing_job_by_name(self, processing_job_name: str) -> bool:
"""Query processing job by name"""
"""
Query processing job by name
This method is deprecated.
Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.
"""
warnings.warn(
"This method is deprecated. "
"Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.",
DeprecationWarning,
stacklevel=2,
)
return bool(self.count_processing_jobs_by_name(processing_job_name))

def count_processing_jobs_by_name(
self,
processing_job_name: str,
throttle_retry_delay: int = 2,
retries: int = 3,
) -> int:
"""
Returns the number of processing jobs found with the provided name prefix.
:param processing_job_name: The prefix to look for.
:param throttle_retry_delay: Seconds to wait if a ThrottlingException is hit.
:param retries: The max number of times to retry.
:returns: The number of processing jobs that start with the provided prefix.
"""
try:
self.get_conn().describe_processing_job(ProcessingJobName=processing_job_name)
return True
jobs = self.get_conn().list_processing_jobs(NameContains=processing_job_name)
return len(jobs["ProcessingJobSummaries"])
except ClientError as e:
if e.response["Error"]["Code"] in ["ValidationException", "ResourceNotFound"]:
return False
if e.response["Error"]["Code"] == "ResourceNotFound":
# No jobs found with that name. This is good, return 0.
return 0
if e.response["Error"]["Code"] == "ThrottlingException":
# If we hit a ThrottlingException, back off a little and try again.
if retries:
time.sleep(throttle_retry_delay)
return self.count_processing_jobs_by_name(
processing_job_name, throttle_retry_delay * 2, retries - 1
)
raise

def delete_model(self, model_name: str):
Expand Down
17 changes: 12 additions & 5 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,18 @@ def expand_role(self) -> None:
def execute(self, context: Context) -> dict:
self.preprocess_config()
processing_job_name = self.config["ProcessingJobName"]
if self.hook.find_processing_job_by_name(processing_job_name):
raise AirflowException(
f"A SageMaker processing job with name {processing_job_name} already exists."
)
self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"])
existing_jobs_found = self.hook.count_processing_jobs_by_name(processing_job_name)
if existing_jobs_found:
if self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker processing job with name {processing_job_name} already exists."
)
elif self.action_if_job_exists == "increment":
self.log.info("Found existing processing job with name '%s'.", processing_job_name)
new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}"
self.config["ProcessingJobName"] = new_processing_job_name
self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)

response = self.hook.create_processing_job(
self.config,
wait_for_completion=self.wait_for_completion,
Expand Down
68 changes: 63 additions & 5 deletions tests/providers/amazon/aws/hooks/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,15 @@ def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, m
assert mock_session.describe_training_job.call_count == 1

@mock.patch.object(SageMakerHook, "get_conn")
def test_find_processing_job_by_name(self, _):
def test_find_processing_job_by_name(self, mock_conn):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
ret = hook.find_processing_job_by_name("existing_job")
assert ret
mock_conn().list_processing_jobs.return_value = {
"ProcessingJobSummaries": [{"ProcessingJobName": "existing_job"}]
}

with pytest.warns(DeprecationWarning):
ret = hook.find_processing_job_by_name("existing_job")
assert ret

@mock.patch.object(SageMakerHook, "get_conn")
def test_find_processing_job_by_name_job_not_exists_should_return_false(self, mock_conn):
Expand All @@ -634,8 +639,61 @@ def test_find_processing_job_by_name_job_not_exists_should_return_false(self, mo
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.find_processing_job_by_name("existing_job")
assert not ret
with pytest.warns(DeprecationWarning):
ret = hook.find_processing_job_by_name("existing_job")
assert not ret

@mock.patch.object(SageMakerHook, "get_conn")
def test_count_processing_jobs_by_name(self, mock_conn):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
existing_job_name = "existing_job"
mock_conn().list_processing_jobs.return_value = {
"ProcessingJobSummaries": [{"ProcessingJobName": existing_job_name}]
}
ret = hook.count_processing_jobs_by_name(existing_job_name)
assert ret == 1

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch("time.sleep", return_value=None)
def test_count_processing_jobs_by_name_retries_on_throttle_exception(self, _, mock_conn):
throttle_exception = ClientError(
error_response={"Error": {"Code": "ThrottlingException"}}, operation_name="empty"
)
successful_result = {"ProcessingJobSummaries": [{"ProcessingJobName": "existing_job"}]}
# Return a ThrottleException on the first call, then a mocked successful value the second.
mock_conn().list_processing_jobs.side_effect = [throttle_exception, successful_result]
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.count_processing_jobs_by_name("existing_job")

assert mock_conn().list_processing_jobs.call_count == 2
assert ret == 1

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch("time.sleep", return_value=None)
def test_count_processing_jobs_by_name_fails_after_max_retries(self, _, mock_conn):
mock_conn().list_processing_jobs.side_effect = ClientError(
error_response={"Error": {"Code": "ThrottlingException"}}, operation_name="empty"
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

with pytest.raises(ClientError) as raised_exception:
hook.count_processing_jobs_by_name("existing_job")

# One initial call plus retries
assert mock_conn().list_processing_jobs.call_count == 4
assert raised_exception.value.response["Error"]["Code"] == "ThrottlingException"

@mock.patch.object(SageMakerHook, "get_conn")
def test_count_processing_jobs_by_name_job_not_exists_should_return_falsy(self, mock_conn):
error_resp = {"Error": {"Code": "ResourceNotFound"}}
mock_conn().list_processing_jobs.side_effect = ClientError(
error_response=error_resp, operation_name="empty"
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.count_processing_jobs_by_name("existing_job")
assert ret == 0

@mock_sagemaker
def test_delete_model(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def setUp(self):
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -113,7 +113,7 @@ def test_integer_fields_without_stopping_condition(
assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -136,7 +136,7 @@ def test_integer_fields_with_stopping_condition(self, serialize, mock_processing
sagemaker.config[key1][key2] == int(sagemaker.config[key1][key2])

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -153,7 +153,7 @@ def test_execute(self, serialize, mock_processing, mock_hook, mock_client):
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand Down Expand Up @@ -187,12 +187,12 @@ def test_execute_with_failure(self, mock_processing, mock_client):

@unittest.skip("Currently, the auto-increment jobname functionality is not missing.")
@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1)
@mock.patch.object(
SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
)
def test_execute_with_existing_job_increment(
self, mock_create_processing_job, find_processing_job_by_name, mock_client
self, mock_create_processing_job, count_processing_jobs_by_name, mock_client
):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
Expand All @@ -211,7 +211,7 @@ def test_execute_with_existing_job_increment(
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1)
@mock.patch.object(
SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
)
Expand Down

0 comments on commit 194dec5

Please sign in to comment.