Skip to content

Commit

Permalink
enhance download_and_extract (#8216)
Browse files Browse the repository at this point in the history
Fixes #5463 

### Description

According to issue, the error messages are not very intuitive. 
I think maybe we can check if the file name matches the downloaded
file’s base name before starting the download.
If it doesn’t match, it will notify user.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: jerome_Hsieh <jerome910810@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
  • Loading branch information
3 people authored Dec 21, 2024
1 parent e1e3d8e commit efff647
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
39 changes: 38 additions & 1 deletion monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import json
import logging
import os
import re
import shutil
import sys
import tarfile
Expand All @@ -30,7 +31,9 @@
from monai.config.type_definitions import PathLike
from monai.utils import look_up_option, min_version, optional_import

requests, has_requests = optional_import("requests")
gdown, has_gdown = optional_import("gdown", "4.7.3")
BeautifulSoup, has_bs4 = optional_import("bs4", name="BeautifulSoup")

if TYPE_CHECKING:
from tqdm import tqdm
Expand Down Expand Up @@ -298,6 +301,29 @@ def extractall(
)


def get_filename_from_url(data_url: str) -> str:
"""
Get the filename from the URL link.
"""
try:
response = requests.head(data_url, allow_redirects=True)
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename = re.findall('filename="?([^";]+)"?', content_disposition)
if filename:
return str(filename[0])
if "drive.google.com" in data_url:
response = requests.get(data_url)
if "text/html" in response.headers.get("Content-Type", ""):
soup = BeautifulSoup(response.text, "html.parser")
filename_div = soup.find("span", {"class": "uc-name-size"})
if filename_div:
return str(filename_div.find("a").text)
return _basename(data_url)
except Exception as e:
raise Exception(f"Error processing URL: {e}") from e


def download_and_extract(
url: str,
filepath: PathLike = "",
Expand Down Expand Up @@ -327,7 +353,18 @@ def download_and_extract(
be False.
progress: whether to display progress bar.
"""
url_filename_ext = "".join(Path(get_filename_from_url(url)).suffixes)
filepath_ext = "".join(Path(_basename(filepath)).suffixes)
if filepath not in ["", "."]:
if filepath_ext == "":
new_filepath = Path(filepath).with_suffix(url_filename_ext)
logger.warning(
f"filepath={filepath}, which missing file extension. Auto-appending extension to: {new_filepath}"
)
filepath = new_filepath
if filepath_ext and filepath_ext != url_filename_ext:
raise ValueError(f"File extension mismatch: expected extension {url_filename_ext}, but get {filepath_ext}")
with tempfile.TemporaryDirectory() as tmp_dir:
filename = filepath or Path(tmp_dir, _basename(url)).resolve()
filename = filepath or Path(tmp_dir, get_filename_from_url(url)).resolve()
download_url(url=url, filepath=filename, hash_val=hash_val, hash_type=hash_type, progress=progress)
extractall(filepath=filename, output_dir=output_dir, file_type=file_type, has_base=has_base)
3 changes: 2 additions & 1 deletion tests/test_download_and_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from parameterized import parameterized

from monai.apps import download_and_extract, download_url, extractall
from tests.utils import skip_if_downloading_fails, skip_if_quick, testing_data_config
from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick, testing_data_config


@SkipIfNoModule("requests")
class TestDownloadAndExtract(unittest.TestCase):

@skip_if_quick
Expand Down

0 comments on commit efff647

Please sign in to comment.