Skip to content

Commit

Permalink
Hub models map_location=device (#3894)
Browse files Browse the repository at this point in the history
* Hub models `map_location=device`

* cleanup
  • Loading branch information
glenn-jocher authored Jul 5, 2021
1 parent 8930e22 commit 6a3ee7c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
7 changes: 4 additions & 3 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo

fname = Path(name).with_suffix('.pt') # checkpoint filename
try:
device = select_device(('0' if torch.cuda.is_available() else 'cpu') if device is None else device)

if pretrained and channels == 3 and classes == 80:
model = attempt_load(fname, map_location=torch.device('cpu')) # download/load FP32 model
model = attempt_load(fname, map_location=device) # download/load FP32 model
else:
cfg = list((Path(__file__).parent / 'models').rglob(f'{name}.yaml'))[0] # model.yaml path
model = Model(cfg, channels, classes) # create model
if pretrained:
ckpt = torch.load(attempt_download(fname), map_location=torch.device('cpu')) # load
ckpt = torch.load(attempt_download(fname), map_location=device) # load
msd = model.state_dict() # model state_dict
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = {k: v for k, v in csd.items() if msd[k].shape == v.shape} # filter
Expand All @@ -51,7 +53,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo
model.names = ckpt['model'].names # set class names attribute
if autoshape:
model = model.autoshape() # for file/URI/PIL/cv2/np inputs and NMS
device = select_device('0' if torch.cuda.is_available() else 'cpu') if device is None else torch.device(device)
return model.to(device)

except Exception as e:
Expand Down
5 changes: 3 additions & 2 deletions utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import datetime
import logging
import math
import os
import platform
import subprocess
Expand All @@ -11,6 +10,7 @@
from copy import deepcopy
from pathlib import Path

import math
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
Expand Down Expand Up @@ -64,7 +64,8 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
def select_device(device='', batch_size=None):
# device = 'cpu' or '0' or '0,1,2,3'
s = f'YOLOv5 🚀 {git_describe() or date_modified()} torch {torch.__version__} ' # string
cpu = device.lower() == 'cpu'
device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
cpu = device == 'cpu'
if cpu:
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
elif device: # non-cpu device requested
Expand Down

0 comments on commit 6a3ee7c

Please sign in to comment.