worked (simple) example of loading model and transforms? #20

ColinConwell opened this issue Nov 20, 2022 · 1 comment


ColinConwell commented Nov 20, 2022

Thank you for this exciting repository. Can you provide a simple example of how I might be able to load the models you provide in your model zoo?

Something along the lines of what is provided by the timm (pytorch-image-models) model repository:

import timm
model_name = 'ghostnet_100'
model = timm.create_model(model_name, pretrained=True)
model.eval()

from timm.data.transforms_factory import create_transform
from timm.data import resolve_data_config
    
# pass the model object (not its name) so its pretrained config is picked up
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

Ideally, this would allow us to use the models in a jupyter notebook or other interactive context.

Thanks in advance!


ColinConwell commented Nov 20, 2022

By way of example, here's a little script I worked out. If this looks incorrect, let me know!

import os, sys, torch
from PIL import Image

from torchvision import transforms

if not os.path.exists('DeCLIP'):
    !git clone https://github.com/Sense-GVT/DeCLIP/
    
sys.path.append('DeCLIP')
    
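# convert to RGB so the 3-channel Normalize below also works for grayscale or RGBA files
sample_image = Image.open('dog.jpg').convert('RGB')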

from prototype.utils.misc import parse_config

config_path = 'DeCLIP/experiments/declip_experiments/declip88m/declip88m_r50_declip/config.yaml'
config = parse_config(config_path)

from prototype.model.declip import declip_res50

bpe_path = 'DeCLIP/prototype/text_info/bpe_simple_vocab_16e6.txt.gz'
config['model']['kwargs']['text_encode']['bpe_path'] = bpe_path
config['model']['kwargs']['clip']['text_mask_type'] = None

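# map_location='cpu' lets the checkpoint load on machines without a GPU
weights = torch.load('DeCLIP/weights/declip_88m/r50.pth.tar', map_location='cpu')['model']
# drop the 'module.' prefix from the checkpoint keys
weights = {k.replace('module.', ''): v for k, v in weights.items()}
# add a leading dimension to logit_scale
weights['logit_scale'] = weights['logit_scale'].unsqueeze(0)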

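model = declip_res50(**config['model']['kwargs'])
# strict=False skips mismatched keys silently, so print what was skipped
print(model.load_state_dict(weights, strict=False))
model.eval()  # inference mode, so BatchNorm uses its running statistics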

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocess = transforms.Compose([transforms.Resize(256), transforms.ToTensor(), normalize])

inputs = preprocess(sample_image).unsqueeze(0)  # add a batch dimension
with torch.no_grad():  # gradients are not needed for inference
    output = model.visual(inputs)
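
One thing that may be worth checking in the script above: because `load_state_dict` is called with `strict = False`, any key that does not line up after the `module.` prefix is removed is skipped without an error. The snippet below is a minimal, torch-free sketch of that check. The helper names (`strip_prefix`, `compare_keys`) and the toy key names are hypothetical and are not taken from the DeCLIP codebase. It removes the prefix only when it is at the start of a key, which is slightly stricter than `k.replace('module.','')`.

```python
def strip_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading `prefix` removed from each key."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys in the checkpoint that the model does not have
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy stand-ins for a real checkpoint and model (hypothetical key names).
checkpoint = {"module.visual.conv1.weight": 1, "module.logit_scale": 2}
expected = ["visual.conv1.weight", "logit_scale", "visual.bn1.weight"]

cleaned = strip_prefix(checkpoint)
missing, unexpected = compare_keys(expected, cleaned)
print("missing:", missing)
print("unexpected:", unexpected)
```

With the real model, the same information is available directly from PyTorch: `model.load_state_dict(weights, strict=False)` returns an object with `missing_keys` and `unexpected_keys` fields, so printing the return value shows which weights were not loaded.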
