diff --git a/airflow/providers/samba/hooks/samba.py b/airflow/providers/samba/hooks/samba.py index 535ec267ccf42..895c885d92205 100644 --- a/airflow/providers/samba/hooks/samba.py +++ b/airflow/providers/samba/hooks/samba.py @@ -245,10 +245,19 @@ def setxattr(self, path, attribute, value, flags=0, follow_symlinks=True): **self._conn_kwargs, ) - def push_from_local(self, destination_filepath: str, local_filepath: str): - """Push local file to samba server.""" + def push_from_local(self, destination_filepath: str, local_filepath: str, buffer_size: int | None = None): + """ + Push local file to samba server. + + :param destination_filepath: the samba location to push to + :param local_filepath: the file to push + :param buffer_size: + size in bytes of the individual chunks of file to send. Larger values may + speed up large file transfers + """ + extra_args = (buffer_size,) if buffer_size else () with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g: - copyfileobj(f, g) + copyfileobj(f, g, *extra_args) @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: diff --git a/airflow/providers/samba/transfers/gcs_to_samba.py b/airflow/providers/samba/transfers/gcs_to_samba.py index fb1cb6ad98b07..bddc038b736ed 100644 --- a/airflow/providers/samba/transfers/gcs_to_samba.py +++ b/airflow/providers/samba/transfers/gcs_to_samba.py @@ -93,6 +93,9 @@ class GCSToSambaOperator(BaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param buffer_size: Optional specification of the size in bytes of the chunks sent to + Samba. Larger buffer lengths may decrease the time to upload large files. The default + length is determined by shutil, which is 64 KB. """ template_fields: Sequence[str] = ( @@ -114,6 +117,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", samba_conn_id: str = "samba_default", impersonation_chain: str | Sequence[str] | None = None, + buffer_size: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -127,6 +131,7 @@ def __init__( self.samba_conn_id = samba_conn_id self.impersonation_chain = impersonation_chain self.sftp_dirs = None + self.buffer_size = buffer_size def execute(self, context: Context): gcs_hook = GCSHook( @@ -154,12 +159,16 @@ def execute(self, context: Context): for source_object in objects: destination_path = self._resolve_destination_path(source_object, prefix=prefix_dirname) - self._copy_single_object(gcs_hook, samba_hook, source_object, destination_path) + self._copy_single_object( + gcs_hook, samba_hook, source_object, destination_path, self.buffer_size + ) self.log.info("Done. Uploaded '%d' files to %s", len(objects), self.destination_path) else: destination_path = self._resolve_destination_path(self.source_object) - self._copy_single_object(gcs_hook, samba_hook, self.source_object, destination_path) + self._copy_single_object( + gcs_hook, samba_hook, self.source_object, destination_path, self.buffer_size + ) self.log.info("Done. Uploaded '%s' file to %s", self.source_object, destination_path) def _resolve_destination_path(self, source_object: str, prefix: str | None = None) -> str: @@ -176,6 +185,7 @@ def _copy_single_object( samba_hook: SambaHook, source_object: str, destination_path: str, + buffer_size: int | None = None, ) -> None: """Copy single object.""" self.log.info( @@ -194,7 +204,7 @@ def _copy_single_object( object_name=source_object, filename=tmp.name, ) - samba_hook.push_from_local(destination_path, tmp.name) + samba_hook.push_from_local(destination_path, tmp.name, buffer_size=buffer_size) if self.move_object: self.log.info("Executing delete of gs://%s/%s", self.source_bucket, source_object) diff --git a/tests/providers/samba/transfers/test_gcs_to_samba.py b/tests/providers/samba/transfers/test_gcs_to_samba.py index 100fde5f7dc75..f335d7842371b 100644 --- a/tests/providers/samba/transfers/test_gcs_to_samba.py +++ b/tests/providers/samba/transfers/test_gcs_to_samba.py @@ -70,7 +70,7 @@ def test_execute_copy_single_file( bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY ) samba_hook_mock.return_value.push_from_local.assert_called_with( - os.path.join(DESTINATION_SMB, target_object), mock.ANY + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None ) gcs_hook_mock.return_value.delete.assert_not_called() @@ -114,7 +114,52 @@ def test_execute_move_single_file( bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY ) samba_hook_mock.return_value.push_from_local.assert_called_with( - os.path.join(DESTINATION_SMB, target_object), mock.ANY + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None + ) + gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object) + + @pytest.mark.parametrize( + "source_object, target_object, keep_directory_structure", + [ + ("folder/test_object.txt", "folder/test_object.txt", True), + ("folder/subfolder/test_object.txt", "folder/subfolder/test_object.txt", True), + ("folder/test_object.txt", "test_object.txt", False), + ("folder/subfolder/test_object.txt", "test_object.txt", False), + ], + ) + @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.GCSHook") + @mock.patch("airflow.providers.samba.transfers.gcs_to_samba.SambaHook") + def test_execute_adjust_buffer_size( + self, + samba_hook_mock, + gcs_hook_mock, + source_object, + target_object, + keep_directory_structure, + ): + operator = GCSToSambaOperator( + task_id=TASK_ID, + source_bucket=TEST_BUCKET, + source_object=source_object, + destination_path=DESTINATION_SMB, + keep_directory_structure=keep_directory_structure, + move_object=True, + gcp_conn_id=GCP_CONN_ID, + samba_conn_id=SAMBA_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + buffer_size=128000, + ) + operator.execute(None) + gcs_hook_mock.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + samba_hook_mock.assert_called_once_with(samba_conn_id=SAMBA_CONN_ID) + gcs_hook_mock.return_value.download.assert_called_with( + bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY + ) + samba_hook_mock.return_value.push_from_local.assert_called_with( + os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=128000 ) gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object) @@ -201,7 +246,7 @@ def test_execute_copy_with_wildcard( ) samba_hook_mock.return_value.push_from_local.assert_has_calls( [ - mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY) + mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None) for target_object in target_objects ] ) @@ -290,7 +335,7 @@ def test_execute_move_with_wildcard( ) samba_hook_mock.return_value.push_from_local.assert_has_calls( [ - mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY) + mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None) for target_object in target_objects ] )