
Commit

small updates
Sherry-SR authored and Shen Rui committed Aug 12, 2019
1 parent c6d08c4 commit 0176f26
Showing 2 changed files with 16 additions and 18 deletions.
4 changes: 2 additions & 2 deletions resources/train_config_case3D.yaml
@@ -29,7 +29,7 @@ trainer:
# how many iterations before start level set alignment
align_start_iters: 5000
# how many iterations between level set alignment
- align_after_iters: 1000
+ align_after_iters: 2000
# max number of epochs
epochs: 100
# max number of iterations
@@ -94,7 +94,7 @@ loaders:
val_path:
- '/mnt/lustre/shenrui/data/pelvis_resampled/dataset_val.txt'
# how many subprocesses to use for data loading
- num_workers: 8
+ num_workers: 16
# batch size in training process
batch_size: 8
# data transformations/augmentations
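Note: the two values touched here are consumed at training time. Below is a minimal sketch of how they might be wired up, assuming the YAML is read with PyYAML and the loader settings are passed to a PyTorch DataLoader; build_train_loader, should_align and the train_dataset argument are illustrative names, not functions from this repository.

import yaml
from torch.utils.data import DataLoader

with open('resources/train_config_case3D.yaml') as f:
    config = yaml.safe_load(f)

loaders_cfg = config['loaders']
trainer_cfg = config['trainer']

def build_train_loader(train_dataset):
    # num_workers 8 -> 16 doubles the data-loading subprocesses per loader
    return DataLoader(train_dataset,
                      batch_size=loaders_cfg['batch_size'],
                      num_workers=loaders_cfg['num_workers'],
                      shuffle=True)

def should_align(num_iterations):
    # align_after_iters 1000 -> 2000 halves how often level set alignment runs
    # once align_start_iters has been reached (same test as in trainer.py below)
    return (num_iterations >= trainer_cfg['align_start_iters']
            and num_iterations % trainer_cfg['align_after_iters'] == 0)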
30 changes: 14 additions & 16 deletions utils/trainer.py
@@ -184,7 +184,7 @@ def train(self, train_loader):
input, target, weight = self._split_training_batch(t)
output, loss = self._forward_pass(input, target, weight)
train_losses.update(loss.item(), self._batch_size(input))

# if model contains final_activation layer for normalizing logits apply it, otherwise both
# the evaluation metric as well as images in tensorboard will be incorrectly computed
if hasattr(self.model, 'final_activation'):
@@ -199,8 +199,19 @@ def train(self, train_loader):
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

- if self.num_iterations % self.validate_after_iters == 0:

+ if (self.num_iterations == 1) or (self.num_iterations % self.log_after_iters == 0):
+ # log stats, params and images
+ self.logger.info(
+ f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. Batch [{i}/{len(train_loader) - 1}]. Epoch [{self.num_epoch}/{self.max_num_epochs - 1}]')
+ self.logger.info(
+ f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
+ self._log_stats('train', train_losses.avg, train_eval_scores.avg)
+
+ train_losses = RunningAverage()
+ train_eval_scores = RunningAverage()
+
+ if (self.num_iterations == 1) or (self.num_iterations % self.validate_after_iters == 0):
# evaluate on validation set
eval_score = self.validate(self.loaders['val'])
# adjust learning rate if necessary
@@ -212,24 +223,11 @@ def train(self, train_loader):
self._log_lr()
# remember best validation metric
is_best = self._is_best_eval_score(eval_score)

# save checkpoint
self._save_checkpoint(is_best)

self._log_params()
#self._log_images(input, target, output)

- if self.num_iterations % self.log_after_iters == 0:
- # log stats, params and images
- self.logger.info(
- f'Training iteration [{self.num_iterations}/{self.max_num_iterations}]. Batch [{i}/{len(train_loader) - 1}]. Epoch [{self.num_epoch}/{self.max_num_epochs - 1}]')
- self.logger.info(
- f'Training stats. Loss: {train_losses.avg}. Evaluation score: {train_eval_scores.avg}')
- self._log_stats('train', train_losses.avg, train_eval_scores.avg)
-
- train_losses = RunningAverage()
- train_eval_scores = RunningAverage()

if (self.num_iterations >= self.align_start_iters) and (self.num_iterations % self.align_after_iters == 0):
self.loaders['train'] = self.align(self.loaders['train'])
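Note: the train_losses / train_eval_scores trackers are re-created right after each log line, so every logged value averages only the iterations since the previous log. The RunningAverage class itself is not part of this diff; the sketch below is an assumption about the interface the code above relies on (update(value, n) and .avg), not a copy of the repository's implementation.

class RunningAverage:
    """Tracks a sample-weighted running mean (assumed interface)."""
    def __init__(self):
        self.count = 0
        self.sum = 0.0
        self.avg = 0.0

    def update(self, value, n=1):
        # value is a per-batch average (e.g. loss.item()), n the batch size
        self.count += n
        self.sum += value * n
        self.avg = self.sum / self.count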

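Note: taken together, the trainer.py change moves the stats logging ahead of validation and makes both fire on the very first iteration as well as on their usual intervals. Below is a condensed, illustrative view of the per-iteration schedule after this commit; the helper names are hypothetical, not the trainer's actual methods.

def end_of_iteration_hooks(trainer):
    it = trainer.num_iterations
    if it == 1 or it % trainer.log_after_iters == 0:
        trainer.log_training_stats()      # hypothetical helper: log stats, reset running averages
    if it == 1 or it % trainer.validate_after_iters == 0:
        trainer.run_validation()          # hypothetical helper: validate, adjust LR, save checkpoint
    if it >= trainer.align_start_iters and it % trainer.align_after_iters == 0:
        trainer.loaders['train'] = trainer.align(trainer.loaders['train'])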
