diff --git a/nni/retiarii/oneshot/pytorch/darts.py b/nni/retiarii/oneshot/pytorch/darts.py index edcb6d7b86..fd6b988c29 100644 --- a/nni/retiarii/oneshot/pytorch/darts.py +++ b/nni/retiarii/oneshot/pytorch/darts.py @@ -9,7 +9,7 @@ import torch.nn.functional as F from ..interface import BaseOneShotTrainer -from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice +from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device _logger = logging.getLogger(__name__) @@ -160,8 +160,8 @@ def _train_one_epoch(self, epoch): self.model.train() meters = AverageMeterGroup() for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): - trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) - val_X, val_y = val_X.to(self.device), val_y.to(self.device) + trn_X, trn_y = to_device(trn_X, self.device), to_device(trn_y, self.device) + val_X, val_y = to_device(val_X, self.device), to_device(val_y, self.device) # phase 1. architecture step self.ctrl_optim.zero_grad() diff --git a/nni/retiarii/oneshot/pytorch/proxyless.py b/nni/retiarii/oneshot/pytorch/proxyless.py index b44d883711..90759f9240 100644 --- a/nni/retiarii/oneshot/pytorch/proxyless.py +++ b/nni/retiarii/oneshot/pytorch/proxyless.py @@ -8,7 +8,7 @@ import torch.nn.functional as F from ..interface import BaseOneShotTrainer -from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice +from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device _logger = logging.getLogger(__name__) @@ -181,8 +181,8 @@ def _train_one_epoch(self, epoch): self.model.train() meters = AverageMeterGroup() for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)): - trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device) - val_X, val_y = val_X.to(self.device), val_y.to(self.device) + trn_X, trn_y = to_device(trn_X, self.device), to_device(trn_y, self.device) + val_X, val_y = to_device(val_X, self.device), to_device(val_y, self.device) if epoch >= self.warmup_epochs: # 1) train architecture parameters diff --git a/nni/retiarii/oneshot/pytorch/random.py b/nni/retiarii/oneshot/pytorch/random.py index a82ddada10..0944ad7c72 100644 --- a/nni/retiarii/oneshot/pytorch/random.py +++ b/nni/retiarii/oneshot/pytorch/random.py @@ -8,7 +8,7 @@ import torch.nn as nn from ..interface import BaseOneShotTrainer -from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice +from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device _logger = logging.getLogger(__name__) @@ -160,7 +160,7 @@ def _train_one_epoch(self, epoch): self.model.train() meters = AverageMeterGroup() for step, (x, y) in enumerate(self.train_loader): - x, y = x.to(self.device), y.to(self.device) + x, y = to_device(x, self.device), to_device(y, self.device) self.optimizer.zero_grad() self._resample() logits = self.model(x) @@ -180,7 +180,7 @@ def _validate_one_epoch(self, epoch): meters = AverageMeterGroup() with torch.no_grad(): for step, (x, y) in enumerate(self.valid_loader): - x, y = x.to(self.device), y.to(self.device) + x, y = to_device(x, self.device), to_device(y, self.device) self._resample() logits = self.model(x) loss = self.loss(logits, y)