🐛 Simplify download_with_progress inputs and fix cleanup logic #506

Merged
🐛 Simplify download_with_progress inputs and fix cleanup logic
* File buffers were not being closed correctly when exceptions were raised
* Move buffer management logic to the top-level function
* Passed-in buffers are no longer closed when an exception is raised
* Pull the expected size from the Content-Length header
kevinsantana11 committed Aug 7, 2024
commit 12ca5f666bf731193225439b7dc8d7638ab72a6f
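
After this change, each download request is either a two-element tuple of (source URL, destination) or a three-element tuple that still carries an explicit expected size. A minimal sketch of the two call shapes, assuming download_with_progress is imported from clouddrift.adapters.utils as the adapters do, with a hypothetical URL, output paths, and size:

```python
from io import BytesIO

from clouddrift.adapters.utils import download_with_progress

# Hypothetical URL and output paths, for illustration only.
DATA_URL = "https://example.com/data.nc"

# Two-element request: the expected size is read from the response's
# Content-Length header when the server provides one.
download_with_progress([(DATA_URL, "/tmp/data.nc")])

# Three-element request: an explicit expected size (in bytes) can still be
# supplied; it now only matters when Content-Length is missing.
download_with_progress([(DATA_URL, "/tmp/data.nc", 1_000_000.0)])

# Destinations can also be caller-owned in-memory buffers, which are no
# longer closed by download_with_progress when a download fails.
buffer = BytesIO()
download_with_progress([(DATA_URL, buffer)])
```
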
clouddrift/adapters/andro.py (1 addition, 1 deletion)

```diff
@@ -41,7 +41,7 @@ def to_xarray(tmp_path: str | None = None):
 
     # get or update dataset
     local_file = f"{tmp_path}/{ANDRO_URL.split('/')[-1]}"
-    download_with_progress([(ANDRO_URL, local_file, None)])
+    download_with_progress([(ANDRO_URL, local_file)])
 
     # parse with panda
     col_names = [
```
clouddrift/adapters/mosaic.py (1 addition, 1 deletion)

```diff
@@ -60,7 +60,7 @@ def get_dataframes() -> tuple[pd.DataFrame, pd.DataFrame]:
     )
     sorted_data_urls = [data_urls[i] for i in sorted_indices]
     buffers = [BytesIO() for _ in range(len(sorted_data_urls))]
-    requests = [(url, buffer, None) for url, buffer in zip(sorted_data_urls, buffers)]
+    requests = [(url, buffer) for url, buffer in zip(sorted_data_urls, buffers)]
 
     download_with_progress(requests, desc="Downloading data")
     [b.seek(0) for b in buffers]
```
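
The MOSAiC adapter is the main caller that downloads into caller-owned in-memory buffers. A standalone sketch of that pattern, with hypothetical CSV endpoints standing in for the real data URLs, shows why each buffer is rewound with seek(0) before it is read:

```python
from io import BytesIO

import pandas as pd

from clouddrift.adapters.utils import download_with_progress

# Hypothetical endpoints, for illustration only.
urls = [
    "https://example.com/table_a.csv",
    "https://example.com/table_b.csv",
]

# One caller-owned buffer per URL; with the two-element request form there is
# no need to pad each tuple with a trailing None anymore.
buffers = [BytesIO() for _ in urls]
download_with_progress(list(zip(urls, buffers)), desc="Downloading data")

# Writing leaves each buffer positioned at its end, so rewind before reading.
for buf in buffers:
    buf.seek(0)

frames = [pd.read_csv(buf) for buf in buffers]
```
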
clouddrift/adapters/subsurface_floats.py (1 addition, 1 deletion)

```diff
@@ -34,7 +34,7 @@
 
 
 def download(file: str):
-    download_with_progress([(SUBSURFACE_FLOATS_DATA_URL, file, None)])
+    download_with_progress([(SUBSURFACE_FLOATS_DATA_URL, file)])
 
 
 def to_xarray(
```
clouddrift/adapters/utils.py (30 additions, 12 deletions)

```diff
@@ -52,7 +52,9 @@ def _before_call(rcs: RetryCallState):
 
 
 def download_with_progress(
-    download_map: Sequence[tuple[str, BufferedIOBase | str, float | None]],
+    download_map: Sequence[
+        tuple[str, BufferedIOBase | str] | tuple[str, BufferedIOBase | str, float]
+    ],
     show_list_progress: bool | None = None,
     desc: str = "Downloading files",
     custom_retry_protocol: Callable[[WrappedFn], WrappedFn] | None = None,
@@ -66,10 +68,19 @@ def download_with_progress(
 
     buffer: BufferedIOBase | BufferedWriter
    executor = concurrent.futures.ThreadPoolExecutor()
-    futures: dict[concurrent.futures.Future, tuple[str, BufferedIOBase | str]] = dict()
+    futures: dict[
+        concurrent.futures.Future,
+        tuple[str, BufferedIOBase | str, BufferedIOBase | BufferedWriter],
+    ] = dict()
     bar = None
 
-    for src, dst, exp_size in download_map:
+    for request in download_map:
+        if len(request) > 2:
+            src, dst, exp_size = request
+        else:
+            src, dst = request
+            exp_size = None
+
         if isinstance(dst, (str,)):
             buffer = open(dst, "wb")
         else:
@@ -80,10 +91,11 @@ def download_with_progress(
                 retry_protocol(_download_with_progress),
                 src,
                 buffer,
-                exp_size or 0,
+                exp_size,
                 not show_list_progress,
             )
-        ] = (src, dst)
+        ] = (src, dst, buffer)
+
     try:
         if show_list_progress:
             bar = tqdm(
@@ -94,7 +106,11 @@ def download_with_progress(
             )
 
         for fut in concurrent.futures.as_completed(futures):
-            (src, dst) = futures[fut]
+            src, dst, buffer = futures[fut]
+
+            if isinstance(dst, (str,)):
+                buffer.close()
+
             ex = fut.exception(0)
             if ex is None:
                 _logger.debug(f"Finished download job: ({src}, {dst})")
@@ -108,14 +124,13 @@ def download_with_progress(
            any created resources."
         )
         for x in futures.keys():
-            (src, dst) = futures[x]
+            src, dst, buffer = futures[x]
+
             if not x.done():
                 x.cancel()
 
             if isinstance(dst, (str,)) and os.path.exists(dst):
                 os.remove(dst)
-            elif isinstance(dst, (BufferedIOBase,)):
-                dst.close()
         raise e
     finally:
         executor.shutdown(True)
@@ -126,7 +141,7 @@ def download_with_progress(
 def _download_with_progress(
     url: str,
     output: BufferedIOBase | BufferedWriter,
-    expected_size: float,
+    expected_size: float | None,
     show_progress: bool,
 ):
     if isinstance(output, str) and os.path.exists(output):
@@ -158,13 +173,16 @@ def _download_with_progress(
     try:
         response = requests.get(url, timeout=5, stream=True)
 
+        if (content_length := response.headers.get("Content-Length")) is not None:
+            expected_size = float(content_length)
+
         if show_progress:
             bar = tqdm(
                 desc=url,
-                total=float(response.headers.get("Content-Length", expected_size)),
+                total=expected_size,
                 unit="B",
                 unit_scale=True,
-                unit_divisor=1024,
+                unit_divisor=_CHUNK_SIZE,
                 nrows=2,
                 disable=_DISABLE_SHOW_PROGRESS,
             )
```
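
A short sketch of the revised ownership rule, using a hypothetical URL that will fail once the retry policy gives up: a file destination that download_with_progress opened itself is cleaned up, while a buffer supplied by the caller is left open for the caller to reuse or close. The asserts spell out the behavior described by the commit message and the cleanup block above rather than reproduce a test from this PR.

```python
import os
from io import BytesIO

from clouddrift.adapters.utils import download_with_progress

BAD_URL = "https://example.invalid/missing.nc"  # hypothetical, will not resolve

# Caller-owned buffer: left open on failure, so the caller decides its fate.
buffer = BytesIO()
try:
    download_with_progress([(BAD_URL, buffer)])
except Exception:
    pass
assert not buffer.closed
buffer.close()

# String destination: the function opens the file itself and removes the
# partial file during cleanup if the download raised.
try:
    download_with_progress([(BAD_URL, "/tmp/partial.nc")])
except Exception:
    pass
assert not os.path.exists("/tmp/partial.nc")
```
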
clouddrift/adapters/yomaha.py (2 additions, 2 deletions)

```diff
@@ -46,15 +46,15 @@
 
 def download(tmp_path: str):
     download_requests = [
-        (url, f"{tmp_path}/{url.split('/')[-1]}", None) for url in YOMAHA_URLS[:-1]
+        (url, f"{tmp_path}/{url.split('/')[-1]}") for url in YOMAHA_URLS[:-1]
     ]
     download_with_progress(download_requests)
 
     filename_gz = f"{tmp_path}/{YOMAHA_URLS[-1].split('/')[-1]}"
     filename = filename_gz.removesuffix(".gz")
 
     buffer = BytesIO()
-    download_with_progress([(YOMAHA_URLS[-1], buffer, None)])
+    download_with_progress([(YOMAHA_URLS[-1], buffer)])
 
     decompressed_fp = os.path.join(tmp_path, filename)
     with (
```
pyproject.toml (3 additions, 1 deletion)

```diff
@@ -76,7 +76,9 @@ files = [
 ]
 
 [tool.pytest.ini_options]
-testpaths = ["tests/*_tests.py", "tests/adapters/*_tests.py"]
+testpaths = [
+    "tests/**/*_tests.py",
+]
 
 
 [[tool.mypy.overrides]]
```
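
The single recursive pattern replaces the two directory-specific globs. As a rough check of what it covers, here is a pathlib approximation run from the repository root (pytest's own handling of testpaths globs may differ in detail):

```python
from pathlib import Path

# Matches *_tests.py directly under tests/ and in any subdirectory,
# e.g. tests/<name>_tests.py and tests/adapters/<name>_tests.py.
for path in sorted(Path("tests").glob("**/*_tests.py")):
    print(path)
```
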