From 9265d54a3ac1ede3f5b3c5b4f62daeedb6e5da46 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 16 Jan 2025 11:08:35 -0800
Subject: [PATCH] LeViT safetensors load is broken by conversion code that wasn't deactivated

---
 timm/models/levit.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/timm/models/levit.py b/timm/models/levit.py
index 16186cae7a..577fc5f2d7 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -763,17 +763,18 @@ def checkpoint_filter_fn(state_dict, model):
     # filter out attn biases, should not have been persistent
     state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
 
-    D = model.state_dict()
-    out_dict = {}
-    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
-        if va.ndim == 4 and vb.ndim == 2:
-            vb = vb[:, :, None, None]
-        if va.shape != vb.shape:
-            # head or first-conv shapes may change for fine-tune
-            assert 'head' in ka or 'stem.conv1.linear' in ka
-        out_dict[ka] = vb
-
-    return out_dict
+    # NOTE: old weight conversion code, disabled
+    # D = model.state_dict()
+    # out_dict = {}
+    # for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+    #     if va.ndim == 4 and vb.ndim == 2:
+    #         vb = vb[:, :, None, None]
+    #     if va.shape != vb.shape:
+    #         # head or first-conv shapes may change for fine-tune
+    #         assert 'head' in ka or 'stem.conv1.linear' in ka
+    #     out_dict[ka] = vb
+
+    return state_dict
 
 
 model_cfgs = dict(