Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CloudDataTransferServiceRunJobOperator #39154

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,32 @@ def delete_transfer_job(self, job_name: str, project_id: str) -> None:
.execute(num_retries=self.num_retries)
)

@GoogleBaseHook.fallback_to_default_project_id
def run_transfer_job(self, job_name: str, project_id: str) -> dict:
"""Run Google Storage Transfer Service job.

:param job_name: (Required) Name of the job to be fetched
:param project_id: (Optional) the ID of the project that owns the Transfer
Job. If set to None or missing, the default project_id from the Google Cloud
connection is used.
:return: If successful, Operation. See:
https://cloud.google.com/storage-transfer/docs/reference/rest/v1/Operation

.. seealso:: https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run

"""
return (
self.get_conn()
.transferJobs()
.run(
jobName=job_name,
body={
PROJECT_ID: project_id,
},
)
.execute(num_retries=self.num_retries)
)

def cancel_transfer_operation(self, operation_name: str) -> None:
"""Cancel a transfer operation in Google Storage Transfer Service.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,82 @@ def execute(self, context: Context) -> None:
hook.delete_transfer_job(job_name=self.job_name, project_id=self.project_id)


class CloudDataTransferServiceRunJobOperator(GoogleCloudBaseOperator):
"""
Runs a transfer job.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:CloudDataTransferServiceRunJobOperator`

:param job_name: (Required) Name of the job to be run
:param project_id: (Optional) the ID of the project that owns the Transfer
Job. If set to None or missing, the default project_id from the Google Cloud
connection is used.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param api_version: API version used (e.g. v1).
:param google_impersonation_chain: Optional Google service account to impersonate using
short-term credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

# [START gcp_transfer_job_run_template_fields]
template_fields: Sequence[str] = (
"job_name",
"project_id",
"gcp_conn_id",
"api_version",
"google_impersonation_chain",
)
# [END gcp_transfer_job_run_template_fields]
operator_extra_links = (CloudStorageTransferJobLink(),)

def __init__(
self,
*,
job_name: str,
gcp_conn_id: str = "google_cloud_default",
api_version: str = "v1",
project_id: str = PROVIDE_PROJECT_ID,
google_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.job_name = job_name
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
self.google_impersonation_chain = google_impersonation_chain

def _validate_inputs(self) -> None:
if not self.job_name:
raise AirflowException("The required parameter 'job_name' is empty or None")

def execute(self, context: Context) -> dict:
self._validate_inputs()
hook = CloudDataTransferServiceHook(
api_version=self.api_version,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
)

project_id = self.project_id or hook.project_id
if project_id:
CloudStorageTransferJobLink.persist(
context=context,
task_instance=self,
project_id=project_id,
job_name=self.job_name,
)

return hook.run_transfer_job(job_name=self.job_name, project_id=self.project_id)


class CloudDataTransferServiceGetOperationOperator(GoogleCloudBaseOperator):
"""
Gets the latest state of a long-running operation in Google Storage Transfer Service.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,41 @@ See `Google Cloud Transfer Service - REST Resource: transferJobs - Status

.. _howto/operator:CloudDataTransferServiceUpdateJobOperator:

CloudDataTransferServiceRunJobOperator
-----------------------------------------

Runs a transfer job.

For parameter definition, take a look at
:class:`~airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceRunJobOperator`.


Using the operator
""""""""""""""""""

.. exampleinclude:: /../../tests/system/providers/google/cloud/storage_transfer/example_cloud_storage_transfer_service_gcp.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gcp_transfer_run_job]
:end-before: [END howto_operator_gcp_transfer_run_job]

Templating
""""""""""

.. literalinclude:: /../../airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
:language: python
:dedent: 4
:start-after: [START gcp_transfer_job_run_template_fields]
:end-before: [END gcp_transfer_job_run_template_fields]

More information
""""""""""""""""

See `Google Cloud Transfer Service - REST Resource: transferJobs - Run
<https://cloud.google.com/storage-transfer/docs/reference/rest/v1/transferJobs/run>`_

.. _howto/operator:CloudDataTransferServiceRunJobOperator:

CloudDataTransferServiceUpdateJobOperator
-----------------------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,25 @@ def test_delete_transfer_job(self, get_conn):
)
execute_method.assert_called_once_with(num_retries=5)

@mock.patch(
"airflow.providers.google.cloud.hooks.cloud_storage_transfer_service."
"CloudDataTransferServiceHook.get_conn"
)
def test_run_transfer_job(self, get_conn):
run_method = get_conn.return_value.transferJobs.return_value.run
execute_method = run_method.return_value.execute
execute_method.return_value = TEST_TRANSFER_OPERATION

res = self.gct_hook.run_transfer_job(job_name=TEST_TRANSFER_JOB_NAME, project_id=TEST_PROJECT_ID)
assert res == TEST_TRANSFER_OPERATION
run_method.assert_called_once_with(
jobName=TEST_TRANSFER_JOB_NAME,
body={
PROJECT_ID: TEST_PROJECT_ID,
},
)
execute_method.assert_called_once_with(num_retries=5)

@mock.patch(
"airflow.providers.google.cloud.hooks.cloud_storage_transfer_service"
".CloudDataTransferServiceHook.get_conn"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
CloudDataTransferServiceListOperationsOperator,
CloudDataTransferServicePauseOperationOperator,
CloudDataTransferServiceResumeOperationOperator,
CloudDataTransferServiceRunJobOperator,
CloudDataTransferServiceS3ToGCSOperator,
CloudDataTransferServiceUpdateJobOperator,
TransferJobPreprocessor,
Expand Down Expand Up @@ -493,6 +494,61 @@ def test_job_delete_should_throw_ex_when_name_none(self):
CloudDataTransferServiceDeleteJobOperator(job_name="", task_id="task-id")


class TestGcpStorageTransferJobRunOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
def test_job_run(self, mock_hook):
mock_hook.return_value.run_transfer_job.return_value = VALID_OPERATION
op = CloudDataTransferServiceRunJobOperator(
job_name=JOB_NAME,
project_id=GCP_PROJECT_ID,
task_id="task-id",
google_impersonation_chain=IMPERSONATION_CHAIN,
)
result = op.execute(context=mock.MagicMock())
mock_hook.assert_called_once_with(
api_version="v1",
gcp_conn_id="google_cloud_default",
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.run_transfer_job.assert_called_once_with(
job_name=JOB_NAME, project_id=GCP_PROJECT_ID
)
assert result == VALID_OPERATION

# Setting all the operator's input parameters as templated dag_ids
# (could be anything else) just to test if the templating works for all
# fields
@pytest.mark.db_test
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
)
def test_job_run_with_templates(self, _, create_task_instance_of_operator):
dag_id = "test_job_run_with_templates"
ti = create_task_instance_of_operator(
CloudDataTransferServiceRunJobOperator,
dag_id=dag_id,
job_name="{{ dag.dag_id }}",
project_id="{{ dag.dag_id }}",
gcp_conn_id="{{ dag.dag_id }}",
api_version="{{ dag.dag_id }}",
google_impersonation_chain="{{ dag.dag_id }}",
task_id=TASK_ID,
)
ti.render_templates()
assert dag_id == ti.task.job_name
assert dag_id == ti.task.project_id
assert dag_id == ti.task.gcp_conn_id
assert dag_id == ti.task.api_version
assert dag_id == ti.task.google_impersonation_chain

def test_job_run_should_throw_ex_when_name_none(self):
op = CloudDataTransferServiceRunJobOperator(job_name="", task_id="task-id")
with pytest.raises(AirflowException, match="The required parameter 'job_name' is empty or None"):
op.execute(context=mock.MagicMock())


class TestGpcStorageTransferOperationsGetOperator:
@mock.patch(
"airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceHook"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
CloudDataTransferServiceDeleteJobOperator,
CloudDataTransferServiceGetOperationOperator,
CloudDataTransferServiceListOperationsOperator,
CloudDataTransferServiceRunJobOperator,
CloudDataTransferServiceUpdateJobOperator,
)
from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator, GCSDeleteBucketOperator
Expand Down Expand Up @@ -147,6 +148,14 @@
expected_statuses={GcpTransferOperationStatus.SUCCESS},
)

# [START howto_operator_gcp_transfer_run_job]
run_transfer = CloudDataTransferServiceRunJobOperator(
task_id="run_transfer",
job_name="{{task_instance.xcom_pull('create_transfer')['name']}}",
project_id=PROJECT_ID_TRANSFER,
)
# [END howto_operator_gcp_transfer_run_job]

list_operations = CloudDataTransferServiceListOperationsOperator(
task_id="list_operations",
request_filter={
Expand Down Expand Up @@ -180,6 +189,7 @@
>> create_transfer
>> wait_for_transfer
>> update_transfer
>> run_transfer
>> list_operations
>> get_operation
>> [delete_transfer, delete_bucket_src, delete_bucket_dst]
Expand Down
Loading