Skip to content

Commit

Permalink
create dir per runner
Browse files Browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jan 15, 2024
1 parent 337b6af commit 13f40dd
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
6 changes: 4 additions & 2 deletions pipelines/experimenter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

from torch import nn

from core.constants import DEVICE
Expand Down Expand Up @@ -52,7 +54,7 @@ def __train(self):
accuracy_criterion=self._training_params["accuracy_criterion"],
optimizer=self._training_params["optimizer"],
early_stopping=self._training_params["early_stopping"],
artifact_dir=self._artifact_dir,
artifact_dir=os.path.join(self._artifact_dir, "train"),
metrics_filename=self._training_params.get("metrics_filename")
or "metrics.csv",
)
Expand All @@ -63,6 +65,6 @@ def __evaluate(self):
evaluator = Evaluator(
model=self._model,
test_dataloader=self._data_loaders.test_dataloader,
artifact_dir=self._artifact_dir,
artifact_dir=os.path.join(self._artifact_dir, "evaluation"),
)
evaluator.run()
10 changes: 6 additions & 4 deletions tests/pipelines/test_experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_run(mocked_save_seq2seq_model):
patience=30,
verbose=True,
delta=0.0001,
model_save_path=os.path.join(tempdirpath, "model.pt"),
model_save_path=os.path.join(tempdirpath, "train", "model.pt"),
),
"metrics_filename": "metrics.csv",
}
Expand All @@ -44,8 +44,10 @@ def test_run(mocked_save_seq2seq_model):
experimenter.run()

# testing trainer artifacts
assert os.path.exists(os.path.join(tempdirpath, "model.pt"))
assert os.path.exists(os.path.join(tempdirpath, "metrics.csv"))
assert os.path.exists(os.path.join(tempdirpath, "train", "model.pt"))
assert os.path.exists(os.path.join(tempdirpath, "train", "metrics.csv"))
# testing evaluator artifacts
for i in range(dataset_length):
assert os.path.exists(os.path.join(tempdirpath, f"test-case{i}.png"))
assert os.path.exists(
os.path.join(tempdirpath, "evaluation", f"test-case{i}.png")
)

0 comments on commit 13f40dd

Please sign in to comment.