Skip to content

Commit

Permalink
Merge pull request mmistakes#40 from pesser/hackday2
Browse files Browse the repository at this point in the history
fix hanging process
  • Loading branch information
jhaux authored May 2, 2019
2 parents 298ac70 + 20d59a5 commit b333207
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 36 deletions.
11 changes: 11 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ jobs:
- pip install black
script:
- black --check .
- stage: general_tests
name: "General tests"
install:
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION tensorflow=1.13.1
- source activate test-environment
- conda install pytorch-cpu torchvision-cpu -c pytorch
- pip install tensorboardX
- python setup.py install
script:
- cd tests
- python -m pytest
- stage: tf_example_tests
name: "Tensorflow examples test"
install:
Expand Down
70 changes: 36 additions & 34 deletions edflow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,48 +93,50 @@ def _train(config, root, checkpoint=None, retrain=False):
logger.info("Number of training samples: {}".format(len(dataset)))
n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
n_prefetch = config.get("n_prefetch", 1)
batches = make_batches(
with make_batches(
dataset,
batch_size=config["batch_size"],
shuffle=True,
n_processes=n_processes,
n_prefetch=n_prefetch,
)
# get them going
logger.info("Warm up batches.")
next(batches)
batches.reset()
logger.info("Reset batches.")

if "num_steps" in config:
# set number of epochs to perform at least num_steps steps
steps_per_epoch = len(dataset) / config["batch_size"]
num_epochs = config["num_steps"] / steps_per_epoch
config["num_epochs"] = math.ceil(num_epochs)
else:
steps_per_epoch = len(dataset) / config["batch_size"]
num_steps = config["num_epochs"] * steps_per_epoch
config["num_steps"] = math.ceil(num_steps)

logger.info("Instantiating model.")
Model = implementations["model"](config)
if not "hook_freq" in config:
config["hook_freq"] = 1
compat_kwargs = dict(hook_freq=config["hook_freq"], num_epochs=config["num_epochs"])
logger.info("Instantiating iterator.")
Trainer = implementations["iterator"](config, root, Model, **compat_kwargs)
) as batches:
# get them going
logger.info("Warm up batches.")
next(batches)
batches.reset()
logger.info("Reset batches.")

if "num_steps" in config:
# set number of epochs to perform at least num_steps steps
steps_per_epoch = len(dataset) / config["batch_size"]
num_epochs = config["num_steps"] / steps_per_epoch
config["num_epochs"] = math.ceil(num_epochs)
else:
steps_per_epoch = len(dataset) / config["batch_size"]
num_steps = config["num_epochs"] * steps_per_epoch
config["num_steps"] = math.ceil(num_steps)

logger.info("Instantiating model.")
Model = implementations["model"](config)
if not "hook_freq" in config:
config["hook_freq"] = 1
compat_kwargs = dict(
hook_freq=config["hook_freq"], num_epochs=config["num_epochs"]
)
logger.info("Instantiating iterator.")
Trainer = implementations["iterator"](config, root, Model, **compat_kwargs)

logger.info("Initializing model.")
if checkpoint is not None:
Trainer.initialize(checkpoint_path=checkpoint)
else:
Trainer.initialize()
logger.info("Initializing model.")
if checkpoint is not None:
Trainer.initialize(checkpoint_path=checkpoint)
else:
Trainer.initialize()

if retrain:
Trainer.reset_global_step()
if retrain:
Trainer.reset_global_step()

logger.info("Iterating.")
Trainer.iterate(batches)
logger.info("Iterating.")
Trainer.iterate(batches)


def _test(config, root, nogpu=False, bar_position=0):
Expand Down
4 changes: 2 additions & 2 deletions edflow/project_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, base=None, given_directory=None, code_root=".", postfix=None)
self.setup()
self.setup_new_eval()

if given_directory is None:
if given_directory is None and ProjectManager.code_root is not None:
self.copy_code()

ProjectManager.exists = True
Expand Down Expand Up @@ -111,7 +111,7 @@ def ignore(directory, files):
print(directory, filtered)
return filtered

shutil.copytree(src, dst, symlinks=True, ignore=ignore)
shutil.copytree(src, dst, symlinks=False, ignore=ignore)

except shutil.Error as err:
print(err)
Expand Down
23 changes: 23 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import subprocess, os
from edflow.main import _train
from edflow.custom_logging import init_project


def fullname(o):
    """Return the dotted ``module.ClassName`` path for *o*.

    This is the string form edflow configs use to locate a class
    (``o.__module__`` joined with ``o.__name__`` by a dot).
    """
    return "{}.{}".format(o.__module__, o.__name__)


def run_edflow_cmdline(command):
    """Run *command* in a shell and fail loudly on a non-zero exit.

    Parameters
    ----------
    command : str
        Full shell command line; executed with ``shell=True``.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits non-zero (``check=True``).
    """
    env = os.environ.copy()
    # Restrict to one GPU unless the caller already set visibility.
    env.setdefault("CUDA_VISIBLE_DEVICES", "0")
    subprocess.run(command, shell=True, check=True, env=env)


def run_edflow(name, config):
    """Train edflow in-process from *config*.

    Creates a fresh log project under ``logs`` (postfixed with *name*,
    no code copying) and runs ``_train`` against its root directory.
    """
    project = init_project("logs", code_root=None, postfix=name)
    _train(config, project.root)
44 changes: 44 additions & 0 deletions tests/test_hanging_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest, subprocess
from common import run_edflow, fullname
from edflow.iterators.model_iterator import PyHookedModelIterator
from edflow.iterators.batches import DatasetMixin


class Model(object):
    """Minimal stand-in model used by the hanging-process test.

    It only needs to be constructible from a config dict; no actual
    model is built.
    """

    def __init__(self, config):
        # Keep the config around; nothing else is required of a model here.
        self.config = config


class Iterator(PyHookedModelIterator):
    """Iterator that aborts immediately during initialization.

    Raising here simulates a failing run; the enclosing test only
    checks that the process terminates afterwards instead of hanging.
    """

    def initialize(self, *args, **kwargs):
        # Deliberate abort before any iteration happens.
        error = Exception("TestAbort")
        raise error


class Dataset(DatasetMixin):
    """Deterministic dummy dataset: 1000 identical trivial examples."""

    def __init__(self, config):
        # Config is stored but unused; kept for the edflow dataset contract.
        self.config = config

    def __len__(self):
        # Fixed length so batch/step bookkeeping in training is well defined.
        return 1000

    def get_example(self, i):
        # Every index maps to the same minimal example.
        return {"foo": 0}


def run_test():
    """Entry point executed in a child interpreter by ``test``.

    Builds a minimal training config pointing at the dummy model,
    aborting iterator and dummy dataset, then launches edflow run
    "0024" with it.
    """
    config = {
        "model": fullname(Model),
        "iterator": fullname(Iterator),
        "dataset": fullname(Dataset),
        "batch_size": 16,
        "num_steps": 100,
    }
    run_edflow("0024", config)


def test():
    """The child process must terminate within 60 seconds, not hang.

    ``run_test`` is executed in a separate interpreter; its iterator
    aborts during initialization.  The exit status is irrelevant
    (``check=False``) — the assertion is the ``timeout=60``, which makes
    ``subprocess.run`` raise ``TimeoutExpired`` if worker processes keep
    the interpreter alive.
    """
    command = 'python -c "from test_hanging_process import run_test; run_test()"'
    subprocess.run(command, shell=True, check=False, timeout=60)

0 comments on commit b333207

Please sign in to comment.