Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SagemakerProcessingOperator stopped honoring existing_jobs_found #27456

Merged
merged 3 commits into from
Nov 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/actions/configure-aws-credentials
45 changes: 40 additions & 5 deletions airflow/providers/amazon/aws/hooks/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tarfile
import tempfile
import time
import warnings
from datetime import datetime
from functools import partial
from typing import Any, Callable, Generator, cast
Expand Down Expand Up @@ -939,13 +940,47 @@ def _list_request(
next_token = response["NextToken"]

def find_processing_job_by_name(self, processing_job_name: str) -> bool:
"""Query processing job by name"""
"""
Query processing job by name

This method is deprecated.
Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.
"""
warnings.warn(
"This method is deprecated. "
"Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.",
DeprecationWarning,
stacklevel=2,
)
return bool(self.count_processing_jobs_by_name(processing_job_name))

def count_processing_jobs_by_name(
self,
processing_job_name: str,
throttle_retry_delay: int = 2,
retries: int = 3,
) -> int:
"""
Returns the number of processing jobs whose name contains the provided string.

:param processing_job_name: The name (or name fragment) to look for; matched via ``NameContains``.
:param throttle_retry_delay: Seconds to wait if a ThrottlingException is hit; doubled on each retry.
:param retries: The max number of times to retry.
:returns: The number of processing jobs in the response whose name contains the provided string.
"""
try:
self.get_conn().describe_processing_job(ProcessingJobName=processing_job_name)
return True
jobs = self.get_conn().list_processing_jobs(NameContains=processing_job_name)
return len(jobs["ProcessingJobSummaries"])
except ClientError as e:
if e.response["Error"]["Code"] in ["ValidationException", "ResourceNotFound"]:
return False
if e.response["Error"]["Code"] == "ResourceNotFound":
# No jobs found with that name. This is good, return 0.
return 0
if e.response["Error"]["Code"] == "ThrottlingException":
# If we hit a ThrottlingException, back off a little and try again.
if retries:
time.sleep(throttle_retry_delay)
return self.count_processing_jobs_by_name(
processing_job_name, throttle_retry_delay * 2, retries - 1
)
raise

def delete_model(self, model_name: str):
Expand Down
17 changes: 12 additions & 5 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,18 @@ def expand_role(self) -> None:
def execute(self, context: Context) -> dict:
self.preprocess_config()
processing_job_name = self.config["ProcessingJobName"]
if self.hook.find_processing_job_by_name(processing_job_name):
raise AirflowException(
f"A SageMaker processing job with name {processing_job_name} already exists."
)
self.log.info("Creating SageMaker processing job %s.", self.config["ProcessingJobName"])
existing_jobs_found = self.hook.count_processing_jobs_by_name(processing_job_name)
if existing_jobs_found:
if self.action_if_job_exists == "fail":
raise AirflowException(
f"A SageMaker processing job with name {processing_job_name} already exists."
)
elif self.action_if_job_exists == "increment":
self.log.info("Found existing processing job with name '%s'.", processing_job_name)
new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}"
self.config["ProcessingJobName"] = new_processing_job_name
self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)

response = self.hook.create_processing_job(
self.config,
wait_for_completion=self.wait_for_completion,
Expand Down
68 changes: 63 additions & 5 deletions tests/providers/amazon/aws/hooks/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,15 @@ def test_training_with_logs(self, mock_describe, mock_client, mock_log_client, m
assert mock_session.describe_training_job.call_count == 1

@mock.patch.object(SageMakerHook, "get_conn")
def test_find_processing_job_by_name(self, _):
def test_find_processing_job_by_name(self, mock_conn):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
ret = hook.find_processing_job_by_name("existing_job")
assert ret
mock_conn().list_processing_jobs.return_value = {
"ProcessingJobSummaries": [{"ProcessingJobName": "existing_job"}]
}

with pytest.warns(DeprecationWarning):
ret = hook.find_processing_job_by_name("existing_job")
assert ret

@mock.patch.object(SageMakerHook, "get_conn")
def test_find_processing_job_by_name_job_not_exists_should_return_false(self, mock_conn):
Expand All @@ -634,8 +639,61 @@ def test_find_processing_job_by_name_job_not_exists_should_return_false(self, mo
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.find_processing_job_by_name("existing_job")
assert not ret
with pytest.warns(DeprecationWarning):
ret = hook.find_processing_job_by_name("existing_job")
assert not ret

@mock.patch.object(SageMakerHook, "get_conn")
def test_count_processing_jobs_by_name(self, mock_conn):
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")
existing_job_name = "existing_job"
mock_conn().list_processing_jobs.return_value = {
"ProcessingJobSummaries": [{"ProcessingJobName": existing_job_name}]
}
ret = hook.count_processing_jobs_by_name(existing_job_name)
assert ret == 1

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch("time.sleep", return_value=None)
def test_count_processing_jobs_by_name_retries_on_throttle_exception(self, _, mock_conn):
throttle_exception = ClientError(
error_response={"Error": {"Code": "ThrottlingException"}}, operation_name="empty"
)
successful_result = {"ProcessingJobSummaries": [{"ProcessingJobName": "existing_job"}]}
# Raise a ThrottlingException on the first call, then return a mocked successful value on the second.
mock_conn().list_processing_jobs.side_effect = [throttle_exception, successful_result]
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.count_processing_jobs_by_name("existing_job")

assert mock_conn().list_processing_jobs.call_count == 2
assert ret == 1

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch("time.sleep", return_value=None)
def test_count_processing_jobs_by_name_fails_after_max_retries(self, _, mock_conn):
mock_conn().list_processing_jobs.side_effect = ClientError(
error_response={"Error": {"Code": "ThrottlingException"}}, operation_name="empty"
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

with pytest.raises(ClientError) as raised_exception:
hook.count_processing_jobs_by_name("existing_job")

# One initial call plus retries
assert mock_conn().list_processing_jobs.call_count == 4
assert raised_exception.value.response["Error"]["Code"] == "ThrottlingException"

@mock.patch.object(SageMakerHook, "get_conn")
def test_count_processing_jobs_by_name_job_not_exists_should_return_falsy(self, mock_conn):
error_resp = {"Error": {"Code": "ResourceNotFound"}}
mock_conn().list_processing_jobs.side_effect = ClientError(
error_response=error_resp, operation_name="empty"
)
hook = SageMakerHook(aws_conn_id="sagemaker_test_conn_id")

ret = hook.count_processing_jobs_by_name("existing_job")
assert ret == 0

@mock_sagemaker
def test_delete_model(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def setUp(self):
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -113,7 +113,7 @@ def test_integer_fields_without_stopping_condition(
assert sagemaker.config[key1][key2][key3] == int(sagemaker.config[key1][key2][key3])

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -136,7 +136,7 @@ def test_integer_fields_with_stopping_condition(self, serialize, mock_processing
sagemaker.config[key1][key2] == int(sagemaker.config[key1][key2])

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand All @@ -153,7 +153,7 @@ def test_execute(self, serialize, mock_processing, mock_hook, mock_client):
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=False)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=0)
@mock.patch.object(
SageMakerHook,
"create_processing_job",
Expand Down Expand Up @@ -187,12 +187,12 @@ def test_execute_with_failure(self, mock_processing, mock_client):

@unittest.skip("Currently, the auto-increment jobname functionality is not missing.")
@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1)
@mock.patch.object(
SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
)
def test_execute_with_existing_job_increment(
self, mock_create_processing_job, find_processing_job_by_name, mock_client
self, mock_create_processing_job, count_processing_jobs_by_name, mock_client
):
sagemaker = SageMakerProcessingOperator(
**self.processing_config_kwargs, config=CREATE_PROCESSING_PARAMS
Expand All @@ -211,7 +211,7 @@ def test_execute_with_existing_job_increment(
)

@mock.patch.object(SageMakerHook, "get_conn")
@mock.patch.object(SageMakerHook, "find_processing_job_by_name", return_value=True)
@mock.patch.object(SageMakerHook, "count_processing_jobs_by_name", return_value=1)
@mock.patch.object(
SageMakerHook, "create_processing_job", return_value={"ResponseMetadata": {"HTTPStatusCode": 200}}
)
Expand Down