From 88b750a018fe89b3cdc21d9f68ec1fbdce7c21cc Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 14 Jan 2020 14:40:41 -0500 Subject: [PATCH] default logger is now tensorboard (#609) * refactor * refactor * refactor * made tensorboard the default not test-tube --- pytorch_lightning/trainer/callback_config.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 25 +++++++++++++++----- requirements.txt | 2 +- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py index 02db09c31cab1..1f23d3c19a093 100644 --- a/pytorch_lightning/trainer/callback_config.py +++ b/pytorch_lightning/trainer/callback_config.py @@ -2,7 +2,7 @@ from abc import ABC from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping -from pytorch_lightning.logging import TestTubeLogger +from pytorch_lightning.logging import TensorBoardLogger class TrainerCallbackConfigMixin(ABC): @@ -69,7 +69,7 @@ def configure_early_stopping(self, early_stop_callback, logger): # configure logger if logger is True: # default logger - self.logger = TestTubeLogger( + self.logger = TensorBoardLogger( save_dir=self.default_save_path, version=self.slurm_job_id, name='lightning_logs' diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 9852c8b7a4567..a1133004bc448 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -139,41 +139,53 @@ def __init__( """ # Transfer params - if nb_gpu_nodes is not None: # Backward compatibility + # Backward compatibility + if nb_gpu_nodes is not None: warnings.warn("`nb_gpu_nodes` has renamed to `num_nodes` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not num_nodes: # in case you did not set the proper value num_nodes = nb_gpu_nodes self.num_gpu_nodes = num_nodes + self.log_gpu_memory = log_gpu_memory - if gradient_clip is not None: # Backward compatibility + + # Backward 
compatibility + if gradient_clip is not None: warnings.warn("`gradient_clip` has renamed to `gradient_clip_val` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not gradient_clip_val: # in case you did not set the proper value gradient_clip_val = gradient_clip self.gradient_clip_val = gradient_clip_val + self.check_val_every_n_epoch = check_val_every_n_epoch self.track_grad_norm = track_grad_norm self.on_gpu = True if (gpus and torch.cuda.is_available()) else False self.process_position = process_position self.weights_summary = weights_summary - if max_nb_epochs is not None: # Backward compatibility + + # Backward compatibility + if max_nb_epochs is not None: warnings.warn("`max_nb_epochs` has renamed to `max_epochs` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not max_epochs: # in case you did not set the proper value max_epochs = max_nb_epochs self.max_epochs = max_epochs - if min_nb_epochs is not None: # Backward compatibility + + # Backward compatibility + if min_nb_epochs is not None: warnings.warn("`min_nb_epochs` has renamed to `min_epochs` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not min_epochs: # in case you did not set the proper value min_epochs = min_nb_epochs self.min_epochs = min_epochs - if nb_sanity_val_steps is not None: # Backward compatibility + + # Backward compatibility + if nb_sanity_val_steps is not None: warnings.warn("`nb_sanity_val_steps` has renamed to `num_sanity_val_steps` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not num_sanity_val_steps: # in case you did not set the proper value num_sanity_val_steps = nb_sanity_val_steps + self.num_sanity_val_steps = num_sanity_val_steps self.print_nan_grads = print_nan_grads self.truncated_bptt_steps = truncated_bptt_steps @@ -262,8 +274,9 @@ def __init__( # logging self.log_save_interval = log_save_interval self.val_check_interval = val_check_interval + + # backward compatibility if 
add_row_log_interval is not None: - # backward compatibility warnings.warn("`add_row_log_interval` has renamed to `row_log_interval` since v0.5.0" " and will be removed in v0.8.0", DeprecationWarning) if not row_log_interval: # in case you did not set the proper value diff --git a/requirements.txt b/requirements.txt index 1e59e7f7f967d..26d9084a5c78e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ torch>=1.1 torchvision>=0.4.0 pandas>=0.24 # lower version do not support py3.7 test-tube>=0.7.5 -future>=0.17.1 # required for builtins in setup.py +future>=0.17.1 # required for builtins in setup.py \ No newline at end of file