This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Fix PyTorchImageClassificationTrainer's training (#3339)
* Fix PyTorchImageClassificationTrainer's training

The current training loop computes the loss and the gradients but never calls the optimizer, so the model's weights are never updated. As a result the model is not actually trained, and its accuracy on the Web UI remains unchanged.

* Add intermediate reports for every epoch

Currently the intermediate reports and the final report are identical, and they appear only once, after all epochs have finished. This is probably not the intended behavior: intermediate reports should provide the validation results after each epoch.
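The first point can be illustrated with a minimal sketch in plain Python. It does not use PyTorch: `ToySGD`, `backward`, and `train` are hypothetical stand-ins that only mimic the `zero_grad()` / `backward()` / `step()` pattern on a single scalar weight, under the assumption that gradients accumulate until they are cleared.

```python
class ToySGD:
    """Hypothetical stand-in for an optimizer; this is not PyTorch's API."""

    def __init__(self, param, lr):
        self.param = param
        self.lr = lr

    def zero_grad(self):
        # Clear the gradient left over from the previous batch.
        self.param["grad"] = 0.0

    def step(self):
        # Plain gradient descent update.
        self.param["value"] -= self.lr * self.param["grad"]


def backward(param, x, y):
    """Accumulate the gradient of (w * x - y) ** 2 with respect to w."""
    param["grad"] += 2.0 * (param["value"] * x - y) * x


def train(batches, use_optimizer, lr=0.1, epochs=20):
    param = {"value": 0.0, "grad": 0.0}
    opt = ToySGD(param, lr)
    for _ in range(epochs):
        for x, y in batches:
            if use_optimizer:
                opt.zero_grad()
            backward(param, x, y)
            if use_optimizer:
                opt.step()
    return param["value"]


batches = [(1.0, 2.0), (2.0, 4.0)]  # samples of y = 2 * x
w_before_fix = train(batches, use_optimizer=False)  # gradients only
w_after_fix = train(batches, use_optimizer=True)    # gradients + updates
```

Without the optimizer calls the weight keeps its initial value no matter how many epochs run; with them it converges towards 2, the slope of the toy data.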
tczhangzhi authored Feb 2, 2021
1 parent 8175f28 commit 533f2ef
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions nni/retiarii/trainer/pytorch/base.py
@@ -145,12 +145,15 @@ def _validate(self):

    def _train(self):
        for i, batch in enumerate(self._train_dataloader):
            self._optimizer.zero_grad()
            loss = self.training_step(batch, i)
            loss.backward()
            self._optimizer.step()

    def fit(self) -> None:
        for _ in range(self._trainer_kwargs['max_epochs']):
            self._train()
            self._validate()
        # assuming val_acc here
        nni.report_final_result(self._validate()['val_acc'])

@@ -204,6 +207,7 @@ def fit(self) -> None:
        max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
        for _ in range(max_epochs):
            self._train()
            self._validate()
        nni.report_final_result(self._validate())

    def _train(self):
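
The per-epoch reporting pattern behind the `fit` changes can be sketched on its own, as a rough illustration rather than the trainer's actual code. The `fit` function below and its callback parameters are hypothetical; in NNI the callbacks would presumably be `nni.report_intermediate_result` and `nni.report_final_result`, and whether `_validate` itself issues the intermediate report is not visible in this diff. Unlike the committed code, which calls `_validate()` once more for the final result, this variant reuses the last epoch's value.

```python
def fit(max_epochs, train_one_epoch, validate, report_intermediate, report_final):
    """Train for max_epochs and report the validation result after each epoch."""
    val_acc = None
    for _ in range(max_epochs):
        train_one_epoch()
        val_acc = validate()
        report_intermediate(val_acc)  # one report per epoch
    report_final(val_acc)             # one report after the last epoch


# Hypothetical stand-ins so the sketch runs without NNI installed.
intermediate, final = [], []
accuracies = iter([0.5, 0.7, 0.8])
fit(3, lambda: None, lambda: next(accuracies), intermediate.append, final.append)
```

After the call, `intermediate` holds one value per epoch and `final` holds only the value from the last epoch.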