Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specification of buffer length for GCS to Samba #38373

Merged
merged 8 commits into from
Apr 5, 2024
15 changes: 12 additions & 3 deletions airflow/providers/samba/hooks/samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,19 @@ def setxattr(self, path, attribute, value, flags=0, follow_symlinks=True):
**self._conn_kwargs,
)

def push_from_local(self, destination_filepath: str, local_filepath: str):
"""Push local file to samba server."""
def push_from_local(self, destination_filepath: str, local_filepath: str, buffer_size: int | None = None):
"""
Push local file to samba server.

:param destination_filepath: the samba location to push to
:param local_filepath: the file to push
:param buffer_size:
size in bytes of the individual chunks of file to send. Larger values may
speed up large file transfers
"""
extra_args = (buffer_size,) if buffer_size else ()
with open(local_filepath, "rb") as f, self.open_file(destination_filepath, mode="wb") as g:
copyfileobj(f, g)
copyfileobj(f, g, *extra_args)

@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
Expand Down
16 changes: 13 additions & 3 deletions airflow/providers/samba/transfers/gcs_to_samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ class GCSToSambaOperator(BaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param buffer_size: Optional size in bytes of the chunks sent to Samba. Larger buffer
sizes may decrease the time to upload large files. If not set, the default is shutil's
``COPY_BUFSIZE``, which is platform dependent (typically 64 KB; 1 MB on Windows).
"""

template_fields: Sequence[str] = (
Expand All @@ -114,6 +117,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
samba_conn_id: str = "samba_default",
impersonation_chain: str | Sequence[str] | None = None,
buffer_size: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -127,6 +131,7 @@ def __init__(
self.samba_conn_id = samba_conn_id
self.impersonation_chain = impersonation_chain
self.sftp_dirs = None
self.buffer_size = buffer_size

def execute(self, context: Context):
gcs_hook = GCSHook(
Expand Down Expand Up @@ -154,12 +159,16 @@ def execute(self, context: Context):

for source_object in objects:
destination_path = self._resolve_destination_path(source_object, prefix=prefix_dirname)
self._copy_single_object(gcs_hook, samba_hook, source_object, destination_path)
self._copy_single_object(
gcs_hook, samba_hook, source_object, destination_path, self.buffer_size
)

self.log.info("Done. Uploaded '%d' files to %s", len(objects), self.destination_path)
else:
destination_path = self._resolve_destination_path(self.source_object)
self._copy_single_object(gcs_hook, samba_hook, self.source_object, destination_path)
self._copy_single_object(
gcs_hook, samba_hook, self.source_object, destination_path, self.buffer_size
)
self.log.info("Done. Uploaded '%s' file to %s", self.source_object, destination_path)

def _resolve_destination_path(self, source_object: str, prefix: str | None = None) -> str:
Expand All @@ -176,6 +185,7 @@ def _copy_single_object(
samba_hook: SambaHook,
source_object: str,
destination_path: str,
buffer_size: int | None = None,
) -> None:
"""Copy single object."""
self.log.info(
Expand All @@ -194,7 +204,7 @@ def _copy_single_object(
object_name=source_object,
filename=tmp.name,
)
samba_hook.push_from_local(destination_path, tmp.name)
samba_hook.push_from_local(destination_path, tmp.name, buffer_size=buffer_size)

if self.move_object:
self.log.info("Executing delete of gs://%s/%s", self.source_bucket, source_object)
Expand Down
53 changes: 49 additions & 4 deletions tests/providers/samba/transfers/test_gcs_to_samba.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_execute_copy_single_file(
bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY
)
samba_hook_mock.return_value.push_from_local.assert_called_with(
os.path.join(DESTINATION_SMB, target_object), mock.ANY
os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None
)
gcs_hook_mock.return_value.delete.assert_not_called()

Expand Down Expand Up @@ -114,7 +114,52 @@ def test_execute_move_single_file(
bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY
)
samba_hook_mock.return_value.push_from_local.assert_called_with(
os.path.join(DESTINATION_SMB, target_object), mock.ANY
os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None
)
gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object)

@pytest.mark.parametrize(
"source_object, target_object, keep_directory_structure",
[
("folder/test_object.txt", "folder/test_object.txt", True),
("folder/subfolder/test_object.txt", "folder/subfolder/test_object.txt", True),
("folder/test_object.txt", "test_object.txt", False),
("folder/subfolder/test_object.txt", "test_object.txt", False),
],
)
@mock.patch("airflow.providers.samba.transfers.gcs_to_samba.GCSHook")
@mock.patch("airflow.providers.samba.transfers.gcs_to_samba.SambaHook")
def test_execute_adjust_buffer_size(
self,
samba_hook_mock,
gcs_hook_mock,
source_object,
target_object,
keep_directory_structure,
):
operator = GCSToSambaOperator(
task_id=TASK_ID,
source_bucket=TEST_BUCKET,
source_object=source_object,
destination_path=DESTINATION_SMB,
keep_directory_structure=keep_directory_structure,
move_object=True,
gcp_conn_id=GCP_CONN_ID,
samba_conn_id=SAMBA_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
buffer_size=128000,
)
operator.execute(None)
gcs_hook_mock.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
samba_hook_mock.assert_called_once_with(samba_conn_id=SAMBA_CONN_ID)
gcs_hook_mock.return_value.download.assert_called_with(
bucket_name=TEST_BUCKET, object_name=source_object, filename=mock.ANY
)
samba_hook_mock.return_value.push_from_local.assert_called_with(
os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=128000
)
gcs_hook_mock.return_value.delete.assert_called_once_with(TEST_BUCKET, source_object)

Expand Down Expand Up @@ -201,7 +246,7 @@ def test_execute_copy_with_wildcard(
)
samba_hook_mock.return_value.push_from_local.assert_has_calls(
[
mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY)
mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None)
for target_object in target_objects
]
)
Expand Down Expand Up @@ -290,7 +335,7 @@ def test_execute_move_with_wildcard(
)
samba_hook_mock.return_value.push_from_local.assert_has_calls(
[
mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY)
mock.call(os.path.join(DESTINATION_SMB, target_object), mock.ANY, buffer_size=None)
for target_object in target_objects
]
)
Expand Down