Skip to content

Commit

Permalink
Fix import bug in DDP code
Browse files — Browse the repository at this point in the history
  • Loading branch information
nv-jeff committed Aug 14, 2024
1 parent cc88c0f commit 432bd7a
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion train/train_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import random
from tensorboardX import SummaryWriter

import sys
sys.path.insert(1, '../common')
from models import *
from utils import *

Expand Down Expand Up @@ -219,8 +221,10 @@ def main(args):
net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[args.gpu])


start_epoch = 1
if args.pretrained:
if args.net_path is not None:
start_epoch = int(os.path.splitext(os.path.basename(args.net_path).split('_')[2])[0]) + 1
net.load_state_dict(torch.load(args.net_path))
else:
print("Error: Did not specify path to pretrained weights.")
Expand All @@ -231,7 +235,7 @@ def main(args):

nb_update_network = 0

for epoch in range(args.epochs): # loop over the dataset multiple times
for epoch in range(start_epoch, args.epochs+1): # loop over the dataset multiple times
train_sampler.set_epoch(epoch)
_runnetwork(args.gpu, local_rank, epoch, nb_update_network, net, train_loader, optimizer,
writer)
Expand Down

0 comments on commit 432bd7a

Please sign in to comment.