Skip to content

Commit

Permalink
fix training logs
Browse files Browse the repository at this point in the history
  • Loading branch information
prajnan93 committed Apr 15, 2022
1 parent 5f8acfa commit 94b3686
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions ezflow/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _epoch_trainer(self, n_epochs=None, start_epoch=None):

if epoch % self.cfg.VALIDATE_INTERVAL == 0 and self._is_main_process():
new_avg_val_loss, new_avg_val_metric = self._validate_model()
- print("-" * 80)
+ print("\n", "-" * 80)
self.writer.add_scalar(
"avg_validation_loss", new_avg_val_loss, epoch + 1
)
Expand All @@ -166,7 +166,7 @@ def _epoch_trainer(self, n_epochs=None, start_epoch=None):
print(
f"Epoch {epoch+1}: Average validation metric = {new_avg_val_metric}\n"
)
- print("-" * 80)
+ print("-" * 80, "\n")
best_model = self._save_best_model(
best_model, new_avg_val_loss, new_avg_val_metric
)
Expand Down Expand Up @@ -361,7 +361,7 @@ def _save_best_model(self, best_model, new_avg_val_loss, new_avg_val_metric):
if new_avg_val_loss < self.min_avg_val_loss:

self.min_avg_val_loss = new_avg_val_loss
- print("New minimum average validation loss!")
+ print("\nNew minimum average validation loss!")

if self.cfg.VALIDATE_ON.lower() == "loss":
best_model = deepcopy(self.model)
Expand All @@ -372,14 +372,14 @@ def _save_best_model(self, best_model, new_avg_val_loss, new_avg_val_metric):
save_best_model.state_dict(),
os.path.join(self.cfg.CKPT_DIR, self.model_name + "_best.pth"),
)
- print(f"Saved new best model!")
+ print(f"Saved new best model!\n")

return best_model

if new_avg_val_metric < self.min_avg_val_metric:

self.min_avg_val_metric = new_avg_val_metric
- print("New minimum average validation metric!")
+ print("\nNew minimum average validation metric!")

if self.cfg.VALIDATE_ON.lower() == "metric":
best_model = deepcopy(self.model)
Expand All @@ -390,7 +390,7 @@ def _save_best_model(self, best_model, new_avg_val_loss, new_avg_val_metric):
save_best_model.state_dict(),
os.path.join(self.cfg.CKPT_DIR, self.model_name + "_best.pth"),
)
- print(f"Saved new best model!")
+ print(f"Saved new best model!\n")

return best_model

Expand Down Expand Up @@ -700,7 +700,7 @@ def _setup_ddp(self, rank):
print(f"{rank + 1}/{self.cfg.DISTRIBUTED.WORLD_SIZE} process initialized.")

def _is_main_process(self):
- return dist.get_rank() == 0
+ return torch.cuda.current_device() == 0

def _setup_model(self, rank):

Expand Down

0 comments on commit 94b3686

Please sign in to comment.