From 2a61d34cca8807bd3d922eb696541c4cfd3b8ed0 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 23 Oct 2024 10:53:29 -0700
Subject: [PATCH 1/3] All default pretrained weights pushed to HF hub,
 stragglers uploaded to timm org for simplicity.

* OpenAI models no longer use a special path that loads from a torchscript
  archive; they use the same path as other models
* Handling of QuickGELU is consistent between OpenAI and non-OpenAI models,
  made a bit safer, with a warning on mismatch
* safetensors is the default weight load if available
---
 src/open_clip/factory.py                        | 252 +++++++++++-------
 .../model_configs/RN50x16-quickgelu.json        |  22 ++
 .../model_configs/RN50x4-quickgelu.json         |  22 ++
 .../model_configs/RN50x64-quickgelu.json        |  22 ++
 src/open_clip/model_configs/ViT-H-14-378.json   |  17 ++
 .../model_configs/ViT-L-14-336-quickgelu.json   |  17 ++
 src/open_clip/pretrained.py                     | 242 +++++++++++------
 src/open_clip/push_to_hf_hub.py                 |   1 +
 8 files changed, 413 insertions(+), 182 deletions(-)
 create mode 100644 src/open_clip/model_configs/RN50x16-quickgelu.json
 create mode 100644 src/open_clip/model_configs/RN50x4-quickgelu.json
 create mode 100644 src/open_clip/model_configs/RN50x64-quickgelu.json
 create mode 100644 src/open_clip/model_configs/ViT-H-14-378.json
 create mode 100644 src/open_clip/model_configs/ViT-L-14-336-quickgelu.json
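For context, a minimal sketch of the loading behaviour these patches aim at
(model and tag names are illustrative, taken from the configs touched below):

    import open_clip

    # OpenAI checkpoints load through the same code path as every other
    # pretrained tag; safetensors copies on the HF hub are preferred when
    # present.
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32-quickgelu', pretrained='openai')

    # Non-QuickGELU weights keep using the base configs.
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32', pretrained='laion2b_s34b_b79k')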
diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 2a722eec6..3e40355a6 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -2,6 +2,7 @@
 import logging
 import os
 import re
+import warnings
 from copy import deepcopy
 from dataclasses import asdict
 from pathlib import Path
@@ -222,8 +223,58 @@ def create_model(
         cache_dir: Optional[str] = None,
         output_dict: Optional[bool] = None,
         require_pretrained: bool = False,
+        load_weights_only: bool = True,
         **model_kwargs,
 ):
+    """Creates and configures a contrastive vision-language model.
+
+    Args:
+        model_name: Name of the model architecture to create. Can be a local model name
+            or a Hugging Face model ID prefixed with 'hf-hub:'.
+        pretrained: Tag/path for pretrained model weights. Can be:
+            - A pretrained tag name (e.g., 'openai')
+            - A path to local weights
+            - None to initialize with random weights
+        precision: Model precision/AMP configuration. Options:
+            - 'fp32': 32-bit floating point
+            - 'fp16'/'bf16': Mixed precision with FP32 for certain layers
+            - 'pure_fp16'/'pure_bf16': Pure 16-bit precision
+        device: Device to load the model on ('cpu', 'cuda', or torch.device object)
+        jit: If True, JIT compile the model
+        force_quick_gelu: Force use of QuickGELU activation
+        force_custom_text: Force use of custom text encoder
+        force_patch_dropout: Override default patch dropout value
+        force_image_size: Override default image size for vision encoder
+        force_preprocess_cfg: Override default preprocessing configuration
+        pretrained_image: Load pretrained weights for timm vision models
+        pretrained_hf: Load pretrained weights for HF text models when not loading CLIP weights
+        cache_dir: Override default cache directory for downloaded model files
+        output_dict: If True and model supports it, return dictionary of features
+        require_pretrained: Raise error if pretrained weights cannot be loaded
+        load_weights_only: Only deserialize model weights when unpickling torch checkpoints (for safety)
+        **model_kwargs: Additional keyword arguments passed to model constructor
+
+    Returns:
+        Created and configured model instance
+
+    Raises:
+        RuntimeError: If model config is not found or required pretrained weights
+            cannot be loaded
+
+    Examples:
+        # Create basic CLIP model
+        model = create_model('ViT-B/32')
+
+        # Create CLIP model with mixed precision on GPU
+        model = create_model('ViT-B/32', precision='fp16', device='cuda')
+
+        # Load pretrained OpenAI weights
+        model = create_model('ViT-B/32', pretrained='openai')
+
+        # Load Hugging Face model
+        model = create_model('hf-hub:organization/model-name')
+    """
     force_preprocess_cfg = force_preprocess_cfg or {}
     preprocess_cfg = asdict(PreprocessCfg())
     has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX)
@@ -242,112 +293,113 @@ def create_model(
     if isinstance(device, str):
         device = torch.device(device)
 
-    if pretrained and pretrained.lower() == 'openai':
-        logging.info(f'Loading pretrained {model_name} from OpenAI.')
-        model = load_openai_model(
-            model_name,
-            precision=precision,
-            device=device,
-            cache_dir=cache_dir,
-        )
+    model_cfg = model_cfg or get_model_config(model_name)
+    if model_cfg is not None:
+        logging.info(f'Loaded {model_name} model config.')
     else:
-        model_cfg = model_cfg or get_model_config(model_name)
-        if model_cfg is not None:
-            logging.info(f'Loaded {model_name} model config.')
+        logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
+        raise RuntimeError(f'Model config for {model_name} not found.')
+
+    if force_quick_gelu:
+        # override for use of QuickGELU on non-OpenAI transformer models
+        model_cfg["quick_gelu"] = True
+
+    if force_patch_dropout is not None:
+        # override the default patch dropout value
+        model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
+
+    if force_image_size is not None:
+        # override model config's image size
+        model_cfg["vision_cfg"]["image_size"] = force_image_size
+
+    is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
+    if pretrained_image:
+        if is_timm_model:
+            # pretrained weight loading for timm models set via vision_cfg
+            model_cfg['vision_cfg']['timm_model_pretrained'] = True
         else:
-            logging.error(f'Model config for {model_name} not found; available models {list_models()}.')
-            raise RuntimeError(f'Model config for {model_name} not found.')
-
-        if force_quick_gelu:
-            # override for use of QuickGELU on non-OpenAI transformer models
-            model_cfg["quick_gelu"] = True
-
-        if force_patch_dropout is not None:
-            # override the default patch dropout value
-            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout
-
-        if force_image_size is not None:
-            # override model config's image size
-            model_cfg["vision_cfg"]["image_size"] = force_image_size
-
-        is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {})
-        if pretrained_image:
-            if is_timm_model:
-                # pretrained weight loading for timm models set via vision_cfg
-                model_cfg['vision_cfg']['timm_model_pretrained'] = True
-            else:
-                assert False, 'pretrained image towers currently only supported for timm models'
-
-        # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
-        cast_dtype = get_cast_dtype(precision)
-        is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
-        if is_hf_model:
-            # load pretrained weights for HF text model IFF no CLIP weights being loaded
-            model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
-        custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
-
-        model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
-        if custom_text:
-            if "multimodal_cfg" in model_cfg:
-                model = CoCa(**model_cfg, cast_dtype=cast_dtype)
-            else:
-                model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+            assert False, 'pretrained image towers currently only supported for timm models'
+
+    # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes
+    cast_dtype = get_cast_dtype(precision)
+    is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {})
+    if is_hf_model:
+        # load pretrained weights for HF text model IFF no CLIP weights being loaded
+        model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained
+    custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model
+
+    model_cfg = dict(model_cfg, **model_kwargs)  # merge cfg dict w/ kwargs (kwargs overrides cfg)
+    if custom_text:
+        if "multimodal_cfg" in model_cfg:
+            model = CoCa(**model_cfg, cast_dtype=cast_dtype)
         else:
-            model = CLIP(**model_cfg, cast_dtype=cast_dtype)
-
-        if precision in ("fp16", "bf16"):
-            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
-            # manual mixed precision that matches original OpenAI behaviour
-            if is_timm_model:
-                # FIXME this is a bit janky, create timm based model in low-precision and
-                # then cast only LayerNormFp32 instances back to float32 so they don't break.
-                # Why? The convert_weights_to_lp fn only works with native models.
-                model.to(device=device, dtype=dtype)
-                from .transformer import LayerNormFp32
-
-                def _convert_ln(m):
-                    if isinstance(m, LayerNormFp32):
-                        m.weight.data = m.weight.data.to(torch.float32)
-                        m.bias.data = m.bias.data.to(torch.float32)
-                model.apply(_convert_ln)
-            else:
-                model.to(device=device)
-                convert_weights_to_lp(model, dtype=dtype)
-        elif precision in ("pure_fp16", "pure_bf16"):
-            dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
+            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
+    else:
+        model = CLIP(**model_cfg, cast_dtype=cast_dtype)
+
+    if precision in ("fp16", "bf16"):
+        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
+        # manual mixed precision that matches original OpenAI behaviour
+        if is_timm_model:
+            # FIXME this is a bit janky, create timm based model in low-precision and
+            # then cast only LayerNormFp32 instances back to float32 so they don't break.
+            # Why? The convert_weights_to_lp fn only works with native models.
+            model.to(device=device, dtype=dtype)
+            from .transformer import LayerNormFp32
+
+            def _convert_ln(m):
+                if isinstance(m, LayerNormFp32):
+                    m.weight.data = m.weight.data.to(torch.float32)
+                    m.bias.data = m.bias.data.to(torch.float32)
+            model.apply(_convert_ln)
         else:
             model.to(device=device)
-
-        pretrained_loaded = False
-        if pretrained:
-            checkpoint_path = ''
-            pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
-            if pretrained_cfg:
-                checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
-                preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
-            elif os.path.exists(pretrained):
-                checkpoint_path = pretrained
-
-            if checkpoint_path:
-                logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
-                load_checkpoint(model, checkpoint_path)
-            else:
-                error_str = (
-                    f'Pretrained weights ({pretrained}) not found for model {model_name}.'
-                    f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
-                logging.warning(error_str)
-                raise RuntimeError(error_str)
-            pretrained_loaded = True
-        elif has_hf_hub_prefix:
-            logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
-            load_checkpoint(model, checkpoint_path)
-            pretrained_loaded = True
-
-        if require_pretrained and not pretrained_loaded:
-            # callers of create_model_from_pretrained always expect pretrained weights
-            raise RuntimeError(
-                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
+            convert_weights_to_lp(model, dtype=dtype)
+    elif precision in ("pure_fp16", "pure_bf16"):
+        dtype = torch.float16 if 'fp16' in precision else torch.bfloat16
+        model.to(device=device, dtype=dtype)
+    else:
+        model.to(device=device)
+
+    pretrained_loaded = False
+    if pretrained:
+        checkpoint_path = ''
+        pretrained_cfg = get_pretrained_cfg(model_name, pretrained)
+        if pretrained_cfg:
+            checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir)
+            preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg)
+            pretrained_quick_gelu = pretrained_cfg.get('quick_gelu', False)
+            model_quick_gelu = model_cfg.get('quick_gelu', False)
+            if pretrained_quick_gelu and not model_quick_gelu:
+                warnings.warn(
+                    f'These pretrained weights were trained with QuickGELU activation but the model config does '
+                    f'not have that enabled. Consider using a model config with a "-quickgelu" suffix or enable with a flag.')
+            elif not pretrained_quick_gelu and model_quick_gelu:
+                warnings.warn(
+                    f'The pretrained weights were not trained with QuickGELU but this activation is enabled in the '
+                    f'model config, consider using a model config without QuickGELU or disable override flags.')
+        elif os.path.exists(pretrained):
+            checkpoint_path = pretrained
+
+        if checkpoint_path:
+            logging.info(f'Loading pretrained {model_name} weights ({pretrained}).')
+            load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
+        else:
+            error_str = (
+                f'Pretrained weights ({pretrained}) not found for model {model_name}.'
+                f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.')
+            logging.warning(error_str)
+            raise RuntimeError(error_str)
+        pretrained_loaded = True
+    elif has_hf_hub_prefix:
+        logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).')
+        load_checkpoint(model, checkpoint_path, weights_only=load_weights_only)
+        pretrained_loaded = True
+
+    if require_pretrained and not pretrained_loaded:
+        # callers of create_model_from_pretrained always expect pretrained weights
+        raise RuntimeError(
+            f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')
 
     if output_dict and hasattr(model, "output_dict"):
         model.output_dict = True
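A sketch of how the new QuickGELU mismatch warning and load_weights_only flag
are expected to behave (assuming the quick_gelu tag metadata defined in
pretrained.py further below):

    import warnings
    import open_clip

    # 'openai' weights are tagged quick_gelu=True, so pairing them with the
    # plain ViT-B-32 config should emit the warning added above.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        model = open_clip.create_model('ViT-B-32', pretrained='openai')
    assert any('QuickGELU' in str(w.message) for w in caught)

    # Trusted local checkpoints that need full pickle deserialization can opt
    # out of safe weights-only loading ('/tmp/ckpt.pt' is a hypothetical path).
    model = open_clip.create_model(
        'ViT-B-32', pretrained='/tmp/ckpt.pt', load_weights_only=False)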
diff --git a/src/open_clip/model_configs/RN50x16-quickgelu.json b/src/open_clip/model_configs/RN50x16-quickgelu.json
new file mode 100644
index 000000000..989bb87c6
--- /dev/null
+++ b/src/open_clip/model_configs/RN50x16-quickgelu.json
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 768,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 384,
+        "layers": [
+            6,
+            8,
+            18,
+            8
+        ],
+        "width": 96,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/RN50x4-quickgelu.json b/src/open_clip/model_configs/RN50x4-quickgelu.json
new file mode 100644
index 000000000..9bf11fc3a
--- /dev/null
+++ b/src/open_clip/model_configs/RN50x4-quickgelu.json
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 640,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 288,
+        "layers": [
+            4,
+            6,
+            10,
+            6
+        ],
+        "width": 80,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 640,
+        "heads": 10,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/RN50x64-quickgelu.json b/src/open_clip/model_configs/RN50x64-quickgelu.json
new file mode 100644
index 000000000..6da9d7e21
--- /dev/null
+++ b/src/open_clip/model_configs/RN50x64-quickgelu.json
@@ -0,0 +1,22 @@
+{
+    "embed_dim": 1024,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 448,
+        "layers": [
+            3,
+            15,
+            36,
+            10
+        ],
+        "width": 128,
+        "patch_size": null
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 12
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/ViT-H-14-378.json b/src/open_clip/model_configs/ViT-H-14-378.json
new file mode 100644
index 000000000..04b2e62d6
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-H-14-378.json
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 1024,
+    "vision_cfg": {
+        "image_size": 378,
+        "layers": 32,
+        "width": 1280,
+        "head_width": 80,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 1024,
+        "heads": 16,
+        "layers": 24
+    }
+}
\ No newline at end of file
diff --git a/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json b/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json
new file mode 100644
index 000000000..d928c0284
--- /dev/null
+++ b/src/open_clip/model_configs/ViT-L-14-336-quickgelu.json
@@ -0,0 +1,17 @@
+{
+    "embed_dim": 768,
+    "quick_gelu": true,
+    "vision_cfg": {
+        "image_size": 336,
+        "layers": 24,
+        "width": 1024,
+        "patch_size": 14
+    },
+    "text_cfg": {
+        "context_length": 77,
+        "vocab_size": 49408,
+        "width": 768,
+        "heads": 12,
+        "layers": 12
+    }
+}
\ No newline at end of file
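The new JSON files above are plain architecture configs; a quick sketch of
reading one back through the factory helper:

    from open_clip.factory import get_model_config

    cfg = get_model_config('ViT-L-14-336-quickgelu')
    assert cfg['quick_gelu'] is True
    assert cfg['vision_cfg']['image_size'] == 336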
diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py
index 8c89d3035..aac87619d 100644
--- a/src/open_clip/pretrained.py
+++ b/src/open_clip/pretrained.py
@@ -1,3 +1,4 @@
+import copy
 import hashlib
 import os
 import urllib
@@ -91,60 +92,81 @@ def _mccfg(url='', hf_hub='', **kwargs):
 
 _RN50 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
-    yfcc15m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
-    cc12m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
-)
-
-_RN50_quickgelu = dict(
-    openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+        hf_hub="timm/resnet50_clip.openai/",
+        quick_gelu=True,
+    ),
     yfcc15m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
+        hf_hub="timm/resnet50_clip.yfcc15m/",
+        quick_gelu=True,
+    ),
     cc12m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
+        hf_hub="timm/resnet50_clip.cc12m/",
+        quick_gelu=True,
+    ),
 )
 
 _RN101 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
-    yfcc15m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
-)
-
-_RN101_quickgelu = dict(
-    openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+        hf_hub="timm/resnet101_clip.openai/",
+        quick_gelu=True,
+    ),
     yfcc15m=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
+        hf_hub="timm/resnet101_clip.yfcc15m/",
+        quick_gelu=True,
+    ),
 )
 
 _RN50x4 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+        hf_hub="timm/resnet50x4_clip.openai/",
+        quick_gelu=True,
+    ),
 )
 
 _RN50x16 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+        hf_hub="timm/resnet50x16_clip.openai/",
+        quick_gelu=True,
+    ),
 )
 
 _RN50x64 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
+        hf_hub="timm/resnet50x64_clip.openai/",
+        quick_gelu=True,
+    ),
 )
 
 _VITB32 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+        hf_hub="timm/vit_base_patch32_clip_224.openai/",
+        quick_gelu=True,
+    ),
+    # LAION 400M (quick gelu)
     laion400m_e31=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
+        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e31/",
+        quick_gelu=True,
+    ),
     laion400m_e32=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
+        hf_hub="timm/vit_base_patch32_clip_224.laion400m_e32/",
+        quick_gelu=True,
+    ),
+    # LAION 2B-en
     laion2b_e16=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth",
+        hf_hub="timm/vit_base_patch32_clip_224.laion2b_e16/",
+    ),
     laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'),
     # DataComp-XL models
     datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'),
@@ -164,19 +186,15 @@ def _mccfg(url='', hf_hub='', **kwargs):
     commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'),
     commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'),
     commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'),
-)
-
-_VITB32_quickgelu = dict(
-    openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
-    laion400m_e31=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
-    laion400m_e32=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
+    # MetaClip models (NOTE quick-gelu activation used)
     metaclip_400m=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt",
+        quick_gelu=True,
+    ),
     metaclip_fullcc=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt",
+        quick_gelu=True,
+    ),
 )
 
 _VITB32_256 = dict(
@@ -185,11 +203,20 @@ def _mccfg(url='', hf_hub='', **kwargs):
 
 _VITB16 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+        hf_hub="timm/vit_base_patch16_clip_224.openai/",
+        quick_gelu=True,
+    ),
+    # LAION-400M
     laion400m_e31=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt",
+        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e31/",
+    ),
     laion400m_e32=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt",
+        hf_hub="timm/vit_base_patch16_clip_224.laion400m_e32/",
+    ),
+    # LAION-2B
     laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
     # DataComp-XL models
     datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'),
@@ -202,30 +229,50 @@ def _mccfg(url='', hf_hub='', **kwargs):
     commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'),
     commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'),
     # DFN
-    dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/')
-)
-
-_VITB16_quickgelu = dict(
+    dfn2b=_pcfg(
+        hf_hub='apple/DFN2B-CLIP-ViT-B-16/',
+        quick_gelu=True,
+    ),
+    # MetaCLIP (these are quick-gelu)
     metaclip_400m=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt",
+        hf_hub="timm/vit_base_patch16_clip_224.metaclip_400m/",
+        quick_gelu=True,
+    ),
     metaclip_fullcc=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt",
+        hf_hub="timm/vit_base_patch16_clip_224.metaclip_2pt5b/",
+        quick_gelu=True,
+    ),
 )
 
 _VITB16_PLUS_240 = dict(
     laion400m_e31=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt",
+        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
+    ),
     laion400m_e32=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt",
+        hf_hub="timm/vit_base_patch16_plus_clip_240.laion400m_e31/",
+    ),
 )
 
 _VITL14 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
+        hf_hub="timm/vit_large_patch14_clip_224.openai/",
+        quick_gelu=True,
+    ),
+    # LAION-400M
     laion400m_e31=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt",
+        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e31/",
+    ),
     laion400m_e32=_pcfg(
-        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
+        url="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt",
+        hf_hub="timm/vit_large_patch14_clip_224.laion400m_e32/",
+    ),
+    # LAION-2B-en
     laion2b_s32b_b82k=_pcfg(
         hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
         mean=INCEPTION_MEAN, std=INCEPTION_STD),
@@ -234,38 +281,55 @@ def _mccfg(url='', hf_hub='', **kwargs):
     commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'),
     commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'),
     commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'),
-)
-
-_VITL14_quickgelu = dict(
+    # MetaCLIP
     metaclip_400m=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt",
+        hf_hub="timm/vit_large_patch14_clip_224.metaclip_400m/",
+        quick_gelu=True,
+    ),
     metaclip_fullcc=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"),
-    dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt",
+        hf_hub="timm/vit_large_patch14_clip_224.metaclip_2pt5b/",
+        quick_gelu=True,
+    ),
+    # DFN-2B (quick-gelu)
+    dfn2b=_pcfg(
+        hf_hub='apple/DFN2B-CLIP-ViT-L-14/',
+        quick_gelu=True,
+    ),
 )
 
 _VITL14_336 = dict(
     openai=_pcfg(
-        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
+        url="https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
+        hf_hub="timm/vit_large_patch14_clip_336.openai/",
+        quick_gelu=True,
+    ),
 )
 
 _VITH14 = dict(
+    # LAION-2B-en
     laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
-)
-
-_VITH14_quickgelu = dict(
+    # MetaCLIP (quick-gelu)
     metaclip_fullcc=_pcfg(
-        "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"),
+        url="https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt",
+        hf_hub="timm/vit_huge_patch14_clip_224.metaclip_2pt5b/",
+        quick_gelu=True,
+    ),
+    # DFN-5B (quick-gelu)
     dfn5b=_pcfg(
         hf_hub='apple/DFN5B-CLIP-ViT-H-14/',
+        quick_gelu=True,
         interpolation="bicubic",
         resize_mode="squash"
     ),
 )
 
-_VITH14_378_quickgelu = dict(
+_VITH14_378 = dict(
+    # DFN-5B (quick-gelu)
     dfn5b=_pcfg(
         hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/',
+        quick_gelu=True,
         interpolation="bicubic",
         resize_mode="squash"
     ),
@@ -277,11 +341,14 @@ def _mccfg(url='', hf_hub='', **kwargs):
 )
 
 _VITbigG14 = dict(
+    # LAION-2B-en
     laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
-)
-
-_VITbigG14_quickgelu = dict(
-    metaclip_fullcc=_pcfg(url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt'),
+    # MetaCLIP (quick-gelu)
+    metaclip_fullcc=_pcfg(
+        url='https://dl.fbaipublicfiles.com/MMPT/metaclip/G14_fullcc2.5b.pt',
+        hf_hub="timm/vit_gigantic_patch14_clip_224.metaclip_2pt5b/",
+        quick_gelu=True,
+    ),
 )
 
 _robertaViTB32 = dict(
@@ -339,28 +406,21 @@ def _mccfg(url='', hf_hub='', **kwargs):
 
 _PRETRAINED = {
     "RN50": _RN50,
-    "RN50-quickgelu": _RN50_quickgelu,
     "RN101": _RN101,
-    "RN101-quickgelu": _RN101_quickgelu,
     "RN50x4": _RN50x4,
     "RN50x16": _RN50x16,
     "RN50x64": _RN50x64,
     "ViT-B-32": _VITB32,
     "ViT-B-32-256": _VITB32_256,
-    "ViT-B-32-quickgelu": _VITB32_quickgelu,
     "ViT-B-16": _VITB16,
-    "ViT-B-16-quickgelu": _VITB16_quickgelu,
     "ViT-B-16-plus-240": _VITB16_PLUS_240,
     "ViT-L-14": _VITL14,
-    "ViT-L-14-quickgelu": _VITL14_quickgelu,
     "ViT-L-14-336": _VITL14_336,
     "ViT-H-14": _VITH14,
-    "ViT-H-14-quickgelu": _VITH14_quickgelu,
-    "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu,
+    "ViT-H-14-378": _VITH14_378,
     "ViT-g-14": _VITg14,
     "ViT-bigG-14": _VITbigG14,
-    "ViT-bigG-14-quickgelu": _VITbigG14_quickgelu,
     "roberta-ViT-B-32": _robertaViTB32,
     "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
@@ -531,6 +591,15 @@ def _mccfg(url='', hf_hub='', **kwargs):
     ),
 }
 
+_PRETRAINED_quickgelu = {}
+for k, v in _PRETRAINED.items():
+    quick_gelu_tags = {}
+    for tk, tv in v.items():
+        if tv.get('quick_gelu', False):
+            quick_gelu_tags[tk] = copy.deepcopy(tv)
+    if quick_gelu_tags:
+        _PRETRAINED_quickgelu[k + '-quickgelu'] = quick_gelu_tags
+_PRETRAINED.update(_PRETRAINED_quickgelu)
 
 def _clean_tag(tag: str):
     # normalize pretrained tags
@@ -662,7 +731,11 @@ def download_pretrained_from_hf(
     for safe_filename in _get_safe_alternatives(filename):
         try:
             cached_file = hf_hub_download(
-                repo_id=model_id, filename=safe_filename, revision=revision, cache_dir=cache_dir)
+                repo_id=model_id,
+                filename=safe_filename,
+                revision=revision,
+                cache_dir=cache_dir,
+            )
             return cached_file
         except Exception:
             pass
@@ -670,7 +743,11 @@ def download_pretrained_from_hf(
     try:
         # Attempt to download the file
        cached_file = hf_hub_download(
-            repo_id=model_id, filename=filename, revision=revision, cache_dir=cache_dir)
+            repo_id=model_id,
+            filename=filename,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
         return cached_file  # Return the path to the downloaded file if successful
     except Exception as e:
         raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}")
@@ -678,17 +755,18 @@
 
 def download_pretrained(
         cfg: Dict,
-        force_hf_hub: bool = False,
+        prefer_hf_hub: bool = True,
         cache_dir: Optional[str] = None,
 ):
     target = ''
     if not cfg:
         return target
 
+    has_hub = has_hf_hub()
     download_url = cfg.get('url', '')
     download_hf_hub = cfg.get('hf_hub', '')
-    if download_hf_hub and force_hf_hub:
-        # use HF hub even if url exists
+    if has_hub and prefer_hf_hub and download_hf_hub:
+        # prefer to use HF hub, remove url info
         download_url = ''
 
     if download_url:
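With the hand-written *_quickgelu dicts removed, '-quickgelu' registry names
are now derived mechanically from tags flagged quick_gelu=True (see the
generation loop near the end of the diff above). A sketch of the expected
effect:

    import open_clip

    # each quick_gelu=True tag is mirrored under '<model>-quickgelu'
    tags = open_clip.list_pretrained_tags_by_model('ViT-B-32-quickgelu')
    # expected to include e.g. 'openai', 'laion400m_e31', 'metaclip_400m',
    # while non-QuickGELU tags such as 'laion2b_s34b_b79k' stay on 'ViT-B-32'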
diff --git a/src/open_clip/push_to_hf_hub.py b/src/open_clip/push_to_hf_hub.py
index 867a6d5f3..6a8eeedb9 100644
--- a/src/open_clip/push_to_hf_hub.py
+++ b/src/open_clip/push_to_hf_hub.py
@@ -114,6 +114,7 @@ def push_to_hf_hub(
     try:
         repo_files = set(list_repo_files(repo_id))
         repo_exists = True
+        print('Repo exists', repo_files)
     except Exception as e:
         print('Repo does not exist', e)

From c3c3d7c6b55f64d9e97c1739db1ccb5dfa1cc2fc Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 23 Oct 2024 11:02:42 -0700
Subject: [PATCH 2/3] Remove safeglobals add, not worth having with all
 pretrained weights on hub and numpy 1 vs 2 issues

---
 src/open_clip/factory.py | 16 ----------------
 1 file changed, 16 deletions(-)

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 3e40355a6..88dc03aef 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -10,13 +10,11 @@
 
 import torch
 
-from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
 from .convert import convert_state_dict
 from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
     resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg
 from .coca_model import CoCa
 from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss
-from .openai import load_openai_model
 from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\
     list_pretrained_tags_by_model, download_pretrained_from_hf
 from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs
@@ -27,20 +25,6 @@
 
 _MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs
 
-try:
-    import _codecs
-    import numpy as np
-    # add safe globals that are known to be needed for metaclip weights loading in weights_only=True mode
-    torch.serialization.add_safe_globals([
-        _codecs.encode,  # this one not needed for PyTorch >= 2.5.0
-        np.core.multiarray.scalar,
-        np.dtype,
-        np.dtypes.Float64DType,
-    ])
-except Exception:
-    pass
-
-
 def _natural_key(string_):
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
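Context for the removal above: once every default tag resolves to hub-hosted
safetensors, weights_only unpickling no longer needs a numpy allowlist. A
rough sketch of the two load paths involved (file names are hypothetical):

    import torch
    from safetensors.torch import load_file

    # safetensors involves no pickle, so nothing needs allowlisting
    state_dict = load_file('open_clip_pytorch_model.safetensors')

    # legacy .pt/.bin checkpoints under weights_only=True are restricted to
    # tensors and primitives; pickles that need numpy globals must now opt out
    # explicitly (load_weights_only=False) rather than rely on a library-side
    # add_safe_globals() call
    checkpoint = torch.load('legacy_ckpt.pt', map_location='cpu', weights_only=True)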
From 08991ac41c167661c6897f0011f68f0028237237 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 23 Oct 2024 11:53:00 -0700
Subject: [PATCH 3/3] Improve cache_dir behaviour

* make use of cache_dir for HF tokenizer wrapper explicit
* add missing use of cache_dir for an _get_hf_config call
* add cache_dir as argument to train/val script
---
 src/open_clip/factory.py      | 40 +++++++++++++++++++++++++++++------
 src/open_clip/pretrained.py   |  6 +++---
 src/open_clip/tokenizer.py    |  6 +++++-
 src/open_clip_train/main.py   |  4 +++-
 src/open_clip_train/params.py |  6 ++++++
 5 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/src/open_clip/factory.py b/src/open_clip/factory.py
index 88dc03aef..358b51fb0 100644
--- a/src/open_clip/factory.py
+++ b/src/open_clip/factory.py
@@ -67,14 +67,25 @@ def add_model_config(path):
 
 
 def get_model_config(model_name):
+    """ Fetch model config from builtin (local library) configs.
+    """
     if model_name in _MODEL_CONFIGS:
         return deepcopy(_MODEL_CONFIGS[model_name])
     else:
         return None
 
 
-def _get_hf_config(model_id, cache_dir=None):
-    config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir)
+def _get_hf_config(
+        model_id: str,
+        cache_dir: Optional[str] = None,
+):
+    """ Fetch model config from HuggingFace Hub.
+    """
+    config_path = download_pretrained_from_hf(
+        model_id,
+        filename='open_clip_config.json',
+        cache_dir=cache_dir,
+    )
     with open(config_path, 'r', encoding='utf-8') as f:
         config = json.load(f)
     return config
@@ -83,16 +94,18 @@ def _get_hf_config(model_id, cache_dir=None):
 def get_tokenizer(
         model_name: str = '',
         context_length: Optional[int] = None,
+        cache_dir: Optional[str] = None,
         **kwargs,
 ):
     if model_name.startswith(HF_HUB_PREFIX):
         model_name = model_name[len(HF_HUB_PREFIX):]
         try:
-            config = _get_hf_config(model_name)['model_cfg']
+            config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg']
         except Exception:
             tokenizer = HFTokenizer(
                 model_name,
                 context_length=context_length or DEFAULT_CONTEXT_LENGTH,
+                cache_dir=cache_dir,
                 **kwargs,
             )
             return tokenizer
@@ -113,6 +126,7 @@ def get_tokenizer(
         tokenizer = HFTokenizer(
             text_config['hf_tokenizer_name'],
             context_length=context_length,
+            cache_dir=cache_dir,
             **tokenizer_kwargs,
         )
     else:
@@ -265,7 +279,7 @@ def create_model(
     if has_hf_hub_prefix:
         model_id = model_name[len(HF_HUB_PREFIX):]
         checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
-        config = _get_hf_config(model_id, cache_dir)
+        config = _get_hf_config(model_id, cache_dir=cache_dir)
         preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg'])
         model_cfg = config['model_cfg']
         pretrained_hf = False  # override, no need to load original HF text weights
@@ -456,10 +470,16 @@ def create_model_and_transforms(
         pretrained_hf: bool = True,
         cache_dir: Optional[str] = None,
         output_dict: Optional[bool] = None,
+        load_weights_only: bool = True,
         **model_kwargs,
 ):
     force_preprocess_cfg = merge_preprocess_kwargs(
-        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
+        {},
+        mean=image_mean,
+        std=image_std,
+        interpolation=image_interpolation,
+        resize_mode=image_resize_mode,
+    )
 
     model = create_model(
         model_name,
@@ -476,6 +496,7 @@ def create_model_and_transforms(
         pretrained_hf=pretrained_hf,
         cache_dir=cache_dir,
         output_dict=output_dict,
+        load_weights_only=load_weights_only,
         **model_kwargs,
     )
 
@@ -509,10 +530,16 @@ def create_model_from_pretrained(
         image_resize_mode: Optional[str] = None,  # only effective for inference
         return_transform: bool = True,
         cache_dir: Optional[str] = None,
+        load_weights_only: bool = True,
         **model_kwargs,
 ):
     force_preprocess_cfg = merge_preprocess_kwargs(
-        {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode)
+        {},
+        mean=image_mean,
+        std=image_std,
+        interpolation=image_interpolation,
+        resize_mode=image_resize_mode,
+    )
 
     model = create_model(
         model_name,
@@ -526,6 +553,7 @@ def create_model_from_pretrained(
         force_preprocess_cfg=force_preprocess_cfg,
         cache_dir=cache_dir,
         require_pretrained=True,
+        load_weights_only=load_weights_only,
         **model_kwargs,
     )
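A short usage sketch of the cache_dir plumbing added above (the cache path is
illustrative; the hub repo is one referenced in pretrained.py):

    import open_clip

    cache = '/data/hf-cache'
    model, _, preprocess = open_clip.create_model_and_transforms(
        'ViT-B-32', pretrained='laion2b_s34b_b79k', cache_dir=cache)
    # the HF tokenizer wrapper now receives the same cache location explicitly
    tokenizer = open_clip.get_tokenizer(
        'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K', cache_dir=cache)

On the training side, the equivalent is the new --cache-dir argument wired
through params.py and main.py below, e.g.
python -m open_clip_train.main --model ViT-B-32 --cache-dir /data/hf-cache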
+ """ + config_path = download_pretrained_from_hf( + model_id, + filename='open_clip_config.json', + cache_dir=cache_dir, + ) with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) return config @@ -83,16 +94,18 @@ def _get_hf_config(model_id, cache_dir=None): def get_tokenizer( model_name: str = '', context_length: Optional[int] = None, + cache_dir: Optional[str] = None, **kwargs, ): if model_name.startswith(HF_HUB_PREFIX): model_name = model_name[len(HF_HUB_PREFIX):] try: - config = _get_hf_config(model_name)['model_cfg'] + config = _get_hf_config(model_name, cache_dir=cache_dir)['model_cfg'] except Exception: tokenizer = HFTokenizer( model_name, context_length=context_length or DEFAULT_CONTEXT_LENGTH, + cache_dir=cache_dir, **kwargs, ) return tokenizer @@ -113,6 +126,7 @@ def get_tokenizer( tokenizer = HFTokenizer( text_config['hf_tokenizer_name'], context_length=context_length, + cache_dir=cache_dir, **tokenizer_kwargs, ) else: @@ -265,7 +279,7 @@ def create_model( if has_hf_hub_prefix: model_id = model_name[len(HF_HUB_PREFIX):] checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - config = _get_hf_config(model_id, cache_dir) + config = _get_hf_config(model_id, cache_dir=cache_dir) preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) model_cfg = config['model_cfg'] pretrained_hf = False # override, no need to load original HF text weights @@ -456,10 +470,16 @@ def create_model_and_transforms( pretrained_hf: bool = True, cache_dir: Optional[str] = None, output_dict: Optional[bool] = None, + load_weights_only: bool = True, **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + {}, + mean=image_mean, + std=image_std, + interpolation=image_interpolation, + resize_mode=image_resize_mode, + ) model = create_model( model_name, @@ -476,6 +496,7 @@ def create_model_and_transforms( pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict, + load_weights_only=load_weights_only, **model_kwargs, ) @@ -509,10 +530,16 @@ def create_model_from_pretrained( image_resize_mode: Optional[str] = None, # only effective for inference return_transform: bool = True, cache_dir: Optional[str] = None, + load_weights_only: bool = True, **model_kwargs, ): force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) + {}, + mean=image_mean, + std=image_std, + interpolation=image_interpolation, + resize_mode=image_resize_mode, + ) model = create_model( model_name, @@ -526,6 +553,7 @@ def create_model_from_pretrained( force_preprocess_cfg=force_preprocess_cfg, cache_dir=cache_dir, require_pretrained=True, + load_weights_only=load_weights_only, **model_kwargs, ) diff --git a/src/open_clip/pretrained.py b/src/open_clip/pretrained.py index aac87619d..24c27ef3a 100644 --- a/src/open_clip/pretrained.py +++ b/src/open_clip/pretrained.py @@ -651,7 +651,7 @@ def get_pretrained_url(model: str, tag: str): def download_pretrained_from_url( url: str, - cache_dir: Union[str, None] = None, + cache_dir: Optional[str] = None, ): if not cache_dir: cache_dir = os.path.expanduser("~/.cache/clip") @@ -712,7 +712,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]: if filename == HF_WEIGHTS_NAME: yield HF_SAFE_WEIGHTS_NAME - if filename not in (HF_WEIGHTS_NAME,) and filename.endswith(".bin") or filename.endswith(".pth"): + if filename 
not in (HF_WEIGHTS_NAME,) and (filename.endswith(".bin") or filename.endswith(".pth")): yield filename[:-4] + ".safetensors" @@ -750,7 +750,7 @@ def download_pretrained_from_hf( ) return cached_file # Return the path to the downloaded file if successful except Exception as e: - raise FileNotFoundError(f"Failed to download any files for {model_id}. Last error: {e}") + raise FileNotFoundError(f"Failed to download file ({filename}) for {model_id}. Last error: {e}") def download_pretrained( diff --git a/src/open_clip/tokenizer.py b/src/open_clip/tokenizer.py index 3b762c2fa..872c1833b 100644 --- a/src/open_clip/tokenizer.py +++ b/src/open_clip/tokenizer.py @@ -410,10 +410,11 @@ def __init__( clean: str = 'whitespace', strip_sep_token: bool = False, language: Optional[str] = None, + cache_dir: Optional[str] = None, **kwargs ): from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **kwargs) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=cache_dir, **kwargs) set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) if callable(set_lang_fn): self.set_lang_fn = set_lang_fn @@ -462,6 +463,9 @@ def set_language(self, src_lang): class SigLipTokenizer: """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs + + NOTE: this is not needed in normal library use, but is used to import new sentencepiece tokenizers + into OpenCLIP. Leaving code here in case future models use new tokenizers. """ VOCAB_FILES = { # english, vocab_size=32_000 diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py index 1aa0750fc..7c244ae35 100644 --- a/src/open_clip_train/main.py +++ b/src/open_clip_train/main.py @@ -236,6 +236,7 @@ def main(args): aug_cfg=args.aug_cfg, pretrained_image=args.pretrained_image, output_dict=True, + cache_dir=args.cache_dir, **model_kwargs, ) if args.distill: @@ -246,6 +247,7 @@ def main(args): device=device, precision=args.precision, output_dict=True, + cache_dir=args.cache_dir, ) if args.use_bnb_linear is not None: print('=> using a layer from bitsandbytes.\n' @@ -357,7 +359,7 @@ def main(args): logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})") # initialize datasets - tokenizer = get_tokenizer(args.model) + tokenizer = get_tokenizer(args.model, cache_dir=args.cache_dir) data = get_data( args, (preprocess_train, preprocess_val), diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py index b36ae7bec..2d94b7e21 100644 --- a/src/open_clip_train/params.py +++ b/src/open_clip_train/params.py @@ -101,6 +101,12 @@ def parse_args(args): default=None, help="Path to imagenet v2 for conducting zero shot evaluation.", ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Override system default cache path for model & tokenizer file downloads.", + ) parser.add_argument( "--logs", type=str,