Skip to content

Commit

Permalink
Fixed config.json download to go to user-supplied cache directory.
Browse files Browse the repository at this point in the history
  • Loading branch information
ulatekh committed Apr 11, 2024
1 parent a5e5c92 commit e1b14a1
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,10 @@ def pipeline(
pretrained_model_name_or_path = model

if not isinstance(config, PretrainedConfig) and pretrained_model_name_or_path is not None:
# cached_file needs the cache directory, so that config.json is found in the right place.
if "cache_dir" in model_kwargs:
hub_kwargs["cache_dir"] = model_kwargs["cache_dir"]

# We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
resolved_config_file = cached_file(
pretrained_model_name_or_path,
Expand All @@ -785,6 +789,10 @@ def pipeline(
**hub_kwargs,
)
hub_kwargs["_commit_hash"] = extract_commit_hash(resolved_config_file, commit_hash)

# Remove the cache directory from hub_kwargs, so it doesn't conflict with model_kwargs.
if "cache_dir" in model_kwargs:
del hub_kwargs["cache_dir"]
else:
hub_kwargs["_commit_hash"] = getattr(config, "_commit_hash", None)

Expand Down

0 comments on commit e1b14a1

Please sign in to comment.