Skip to content

Commit

Permalink
fixed multiprocessing import
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon committed Jun 29, 2019
1 parent f2134a4 commit 0a03042
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 31 deletions.
1 change: 0 additions & 1 deletion examples/new_project_templates/trainer_cpu_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def main(hparams):
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_function=None,
save_best_only=True,
verbose=True,
monitor='val_acc',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# ---------------------
# DEFINE MODEL HERE
# ---------------------
from examples.new_project_templates.lightning_module_template import LightningTemplateModel
from lightning_module_template import LightningTemplateModel
# ---------------------

AVAILABLE_MODELS = {
Expand Down Expand Up @@ -58,9 +58,7 @@ def main(hparams, cluster, results_dict):
log_dir = os.path.dirname(os.path.realpath(__file__))
exp = Experiment(
name='test_tube_exp',
debug=True,
save_dir=log_dir,
version=0,
autosave=False,
description='test demo'
)
Expand All @@ -84,7 +82,6 @@ def main(hparams, cluster, results_dict):
model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version)
checkpoint = ModelCheckpoint(
filepath=model_save_path,
save_function=None,
save_best_only=True,
verbose=True,
monitor=hparams.model_save_monitor_value,
Expand All @@ -102,7 +99,7 @@ def main(hparams, cluster, results_dict):
cluster=cluster,
checkpoint_callback=checkpoint,
early_stop_callback=early_stop,
gpus=gpu_list
gpus=gpu_list,
)

# train model
Expand Down
20 changes: 18 additions & 2 deletions pytorch_lightning/models/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,15 @@ def __is_function_implemented(self, f_name):
def __tng_tqdm_dic(self):
tqdm_dic = {
'tng_loss': '{0:.3f}'.format(self.avg_loss),
'gpu': '{}'.format(self.current_gpu_name),
'v_nb': '{}'.format(self.experiment.version),
'epoch': '{}'.format(self.current_epoch),
'batch_nb':'{}'.format(self.batch_nb),
}
tqdm_dic.update(self.tqdm_metrics)

if self.on_gpu:
tqdm_dic['gpu'] = '{}'.format(self.current_gpu_name)

return tqdm_dic

def __layout_bookeeping(self, model):
Expand Down Expand Up @@ -371,7 +373,8 @@ def __train(self):
metrics.update(grad_norm_dic)

# log metrics
self.experiment.log(metrics)
scalar_metrics = self.__metrics_to_scalars(metrics)
self.experiment.log(scalar_metrics, global_step=self.global_step)
self.experiment.save()

# hook
Expand All @@ -398,6 +401,19 @@ def __train(self):
if stop:
return

def __metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():
if type(v) is torch.Tensor:
v = v.item()

if type(v) is dict:
v = self.__metrics_to_scalars(v)

new_metrics[k] = float(v)

return new_metrics


def __run_tng_batch(self, data_batch, batch_nb):
if data_batch is None:
Expand Down
61 changes: 39 additions & 22 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,27 +1,44 @@
atomicwrites==1.2.1
attrs==18.2.0
certifi==2018.11.29
cffi==1.11.5
absl-py==0.7.1
astor==0.8.0
bleach==3.1.0
certifi==2019.6.16
cffi==1.12.3
chardet==3.0.4
docutils==0.14
gast==0.2.2
google-pasta==0.1.7
grpcio==1.21.1
h5py==2.9.0
imageio==2.4.1
mkl-fft==1.0.6
idna==2.8
imageio==2.5.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.0
Markdown==3.1.1
mkl-fft==1.0.12
mkl-random==1.0.2
more-itertools==5.0.0
numpy==1.15.4
numpy==1.16.4
olefile==0.46
pandas==0.23.4
Pillow==5.3.0
pluggy==0.8.0
py==1.7.0
pandas==0.24.2
Pillow==6.0.0
pkginfo==1.5.0.1
protobuf==3.8.0
pycparser==2.19
pytest==4.0.2
python-dateutil==2.7.5
pytz==2018.7
scikit-learn==0.20.2
scipy==1.2.0
Pygments==2.4.1
python-dateutil==2.8.0
pytz==2019.1
readme-renderer==24.0
requests==2.22.0
requests-toolbelt==0.9.1
six==1.12.0
sklearn==0.0
test-tube==0.6282
torch==1.0.0
torchvision==0.2.1
tqdm==4.28.1
tensorboard==1.14.0
tensorboardX==1.7
tensorflow==1.14.0
tensorflow-estimator==1.14.0
termcolor==1.1.0
test-tube==0.643
tqdm==4.32.1
twine==1.13.0
urllib3==1.25.3
webencodings==0.5.1
Werkzeug==0.15.4
wrapt==1.11.2
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
install_requires=[
"torch>=1.0.0",
"tqdm",
"test-tube>=0.641",
"test-tube>=0.643",
"tensorflow>=1.14.0"
],
packages=find_packages(),
Expand Down

0 comments on commit 0a03042

Please sign in to comment.