[Retiarii] add validation in base trainers #3184
nni/retiarii/trainer/pytorch/base.py
Outdated
self._val_dataset = getattr(datasets, dataset_cls)(train=False,
                                                   transform=get_default_transform(dataset_cls),
                                                   **(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(
    model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}

# TODO: we will need at least two (maybe three) data loaders in future.
Remove TODO
Removed.
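The setup snippet under review builds the validation dataset and optimizer by resolving class names with getattr and tolerating kwargs=None. A minimal sketch of that instantiation-by-name pattern; _Optimizers and build_optimizer below are hypothetical stand-ins (the real code resolves names against torchvision.datasets and torch.optim):

```python
# Sketch of the getattr-based instantiation used in the base trainer.
# _Optimizers is a stand-in namespace, not the real torch.optim module.

class _Optimizers:
    class SGD:
        def __init__(self, params, lr=0.01, momentum=0.0):
            self.params = list(params)
            self.lr = lr
            self.momentum = momentum

def build_optimizer(optimizer_cls, params, optimizer_kwargs=None):
    # `optimizer_kwargs or {}` lets callers pass None to mean "use defaults",
    # mirroring the `**(optimizer_kwargs or {})` idiom in the diff.
    return getattr(_Optimizers, optimizer_cls)(params, **(optimizer_kwargs or {}))

opt = build_optimizer('SGD', [0.1, 0.2], {'lr': 0.05})
print(opt.lr)  # 0.05
```

The same idiom covers the dataset: the class name is looked up in a module, then called with the merged keyword arguments.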
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
    x, y = self.training_step_before_model(batch, batch_idx)
    y_hat = self.model(x)
    return self.training_step_after_model(x, y, y_hat)

-def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device = None):
+def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
Suggest using self.device
In MultiModel, different models' inputs may need to be placed on different devices (this is called in _train). Currently the trainer simply hard-codes one GPU per model.
BTW, train_step and validation_step are not used in PyTorchImageClassificationTrainer, so they have been removed.
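The before/after-model split discussed above is the hook that lets a subclass like MultiModel move each model's inputs to its own device before the forward pass. A hedged sketch of that hook structure, with plain-Python stand-ins (no torch; the loss is a placeholder, and TrainerSketch is not the real base class):

```python
class TrainerSketch:
    """Illustrates the training_step_before_model / _after_model hook split."""

    def __init__(self, model, device=None):
        self.model = model
        self.device = device  # a MultiModel-style subclass could pick this per model

    def training_step_before_model(self, batch, batch_idx, device=None):
        # Hook point: a multi-model trainer would move x, y onto `device` here.
        x, y = batch
        return x, y

    def training_step_after_model(self, x, y, y_hat):
        # Placeholder loss; the real trainer computes e.g. cross-entropy.
        return {'loss': abs(y - y_hat)}

    def training_step(self, batch, batch_idx):
        x, y = self.training_step_before_model(batch, batch_idx)
        y_hat = self.model(x)
        return self.training_step_after_model(x, y, y_hat)

trainer = TrainerSketch(model=lambda x: x * 2)
print(trainer.training_step((3, 6), 0))  # {'loss': 0}
```

Because the forward pass sits between the two hooks, a subclass only overrides the hooks and never the forward call itself.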
nni/retiarii/trainer/pytorch/base.py
Outdated
summed_loss = sum(losses)
summed_loss.backward()
for opt in self._optimizers:
    opt.step()
-if batch_idx % 50 == 0:
-    nni.report_intermediate_result(report_loss)
+# if batch_idx % 50 == 0:
Why is this commented out?
It was for debugging; training_loss is not reported. Removed.
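The loop under discussion sums the per-model losses, backpropagates once through the combined loss, and steps every optimizer; the periodic report was the debug line that got commented out. A hedged sketch of that control flow with dummy optimizers (no torch autograd; DummyOpt, multi_model_update, and report_every are illustrative names, not NNI API):

```python
class DummyOpt:
    """Stand-in for a torch optimizer; just counts step() calls."""
    def __init__(self):
        self.steps = 0
    def step(self):
        self.steps += 1

def multi_model_update(losses, optimizers, batch_idx, reports, report_every=50):
    summed_loss = sum(losses)       # one combined loss across all models
    # summed_loss.backward()        # torch autograd in the real trainer
    for opt in optimizers:          # every model's optimizer takes a step
        opt.step()
    if batch_idx % report_every == 0:
        # stands in for nni.report_intermediate_result(...)
        reports.append(summed_loss)
    return summed_loss

opts = [DummyOpt(), DummyOpt()]
reports = []
loss = multi_model_update([0.5, 0.25], opts, batch_idx=0, reports=reports)
print(loss, [o.steps for o in opts], reports)  # 0.75 [1, 1] [0.75]
```

Summing the losses before the single backward pass is what lets gradients flow to every model even though there is only one backward call.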
NNI's line limit is 140 characters. You might need to configure your autopep8 to avoid unwanted line breaks. :)