Skip to content

Commit

Permalink
make tests pass again
Browse files Browse the repository at this point in the history
  • Loading branch information
pesser committed Feb 14, 2020
1 parent c1df06d commit d504dc3
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
8 changes: 5 additions & 3 deletions edflow/edsetup
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def create_edflow_project(project_name, replace: bool = False, **kwargs):
with open(source_config, "r+") as source_config_file:
source_config_dict = yaml.load(source_config_file, Loader=yaml.FullLoader)

path_keys = ["model_path", "dataset_path", "iterator_path"]
path_defaults = ["model.py", "dataset.py", "iterator.py"]
path_keys = ["model_path", "dataset_path", "dataset_path", "iterator_path"]
path_defaults = ["model.py", "dataset.py", "dataset.py", "iterator.py"]

destination_training_files = list()
for key, default in zip(path_keys, path_defaults):
Expand All @@ -61,8 +61,10 @@ def create_edflow_project(project_name, replace: bool = False, **kwargs):
for file, class_name in zip(training_files_to_module, training_classes)
]
training_parameters_dict = dict(
zip(["model", "dataset", "iterator"], full_address_to_class)
zip(["model", "train_dataset", "validation_dataset", "iterator"], full_address_to_class)
)
source_config_dict["datasets"]["train"] = training_parameters_dict.pop("train_dataset")
source_config_dict["datasets"]["validation"] = training_parameters_dict.pop("validation_dataset")
source_config_dict.update(training_parameters_dict)

with open(destination_config, "w+") as new_config_file:
Expand Down
4 changes: 3 additions & 1 deletion edflow/edsetup_files/config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
dataset: dataset.Dataset
datasets:
train: dataset.Dataset
validation: dataset.Dataset
iterator: iterator.Iterator
model: model.Model

Expand Down
30 changes: 24 additions & 6 deletions tests/test_edflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ def test_1(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator_checkpoint)
config["dataset"] = "tmptest." + fullname(Dataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
config["n_processes"] = 1
Expand Down Expand Up @@ -193,7 +196,10 @@ def test_2(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator_checkpoint)
config["dataset"] = "tmptest." + fullname(Dataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
import yaml
Expand Down Expand Up @@ -238,7 +244,10 @@ def test_3(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator_checkpoint_latest)
config["dataset"] = "tmptest." + fullname(Dataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
import yaml
Expand Down Expand Up @@ -283,7 +292,10 @@ def test_4(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator4)
config["dataset"] = "tmptest." + fullname(Dataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
config["eval_all"] = True
Expand Down Expand Up @@ -333,7 +345,10 @@ def test_5(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator_no_checkpoint)
config["dataset"] = "tmptest." + fullname(Dataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
config["eval_all"] = True
Expand Down Expand Up @@ -374,7 +389,10 @@ def test_6(self, tmpdir):
config = dict()
config["model"] = "tmptest." + fullname(Model)
config["iterator"] = "tmptest." + fullname(Iterator_no_checkpoint)
config["dataset"] = "tmptest." + fullname(LLDataset)
config["datasets"] = {
"train": "tmptest." + fullname(Dataset),
"validation": "tmptest." + fullname(Dataset),
}
config["batch_size"] = 16
config["num_steps"] = 100
config["eval_all"] = True
Expand Down
4 changes: 3 additions & 1 deletion tests/test_eval/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
dataset: edflow.debug.ConfigDebugDataset
datasets:
train: edflow.debug.ConfigDebugDataset
validation: edflow.debug.ConfigDebugDataset
size: 10

0 comments on commit d504dc3

Please sign in to comment.