Skip to content

Commit

Permalink
save learning plot
Browse files: browse the repository at this point in the history
  • Loading branch information
tsugumi-sys committed Jan 15, 2024
1 parent 02a6da5 commit b18d6ae
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 6 deletions.
10 changes: 8 additions & 2 deletions pipelines/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from core.constants import DEVICE
from pipelines.base import BaseRunner
from pipelines.utils.early_stopping import EarlyStopping
from pipelines.utils.visualize_utils import save_learning_curve_plot


class TrainingParams(TypedDict):
Expand Down Expand Up @@ -83,7 +84,7 @@ def run(self) -> None:
print(f"Early stopped at epoch {epoch}")
break

self.__save_metrics()
self._save_artifacts()

@property
def training_metrics(self) -> TrainingMetrics:
Expand Down Expand Up @@ -139,7 +140,12 @@ def __log_metric(
def __latest_training_metric(self) -> Dict[str, float]:
return {k: cast(List[float], v)[-1] for k, v in self._training_metrics.items()}

def __save_metrics(self) -> None:
def _save_artifacts(self) -> None:
pd.DataFrame(self._training_metrics).to_csv(
os.path.join(self.artifact_dir, self.metrics_filename)
)
save_learning_curve_plot(
os.path.join(self.artifact_dir, "learning_curve.png"),
self._training_metrics["train_loss"],
self._training_metrics["validation_loss"],
)
6 changes: 2 additions & 4 deletions pipelines/utils/visualize_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@


def save_learning_curve_plot(
save_dir_path: str,
model_name: str,
save_img_path: str,
training_losses: List,
validation_losses: List,
) -> None:
Expand Down Expand Up @@ -47,8 +46,7 @@ def save_learning_curve_plot(

ax.legend(loc="upper center")
plt.tight_layout()
save_path = os.path.join(save_dir_path, f"{model_name}_training_results.png")
plt.savefig(save_path)
plt.savefig(save_img_path)
plt.close()


Expand Down
1 change: 1 addition & 0 deletions tests/pipelines/test_experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_run(mocked_save_seq2seq_model):
# testing trainer artifacts
assert os.path.exists(os.path.join(tempdirpath, "train", "model.pt"))
assert os.path.exists(os.path.join(tempdirpath, "train", "metrics.csv"))
assert os.path.exists(os.path.join(tempdirpath, "train", "learning_curve.png"))
# testing evaluator artifacts
for i in range(dataset_length):
assert os.path.exists(
Expand Down
2 changes: 2 additions & 0 deletions tests/pipelines/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_run(mocked_save_seq2seq_model):

assert os.path.exists(os.path.join(tempdirpath, "checkpoint.pt"))
assert os.path.exists(os.path.join(tempdirpath, "example.csv"))
assert os.path.exists(os.path.join(tempdirpath, "learning_curve.png"))
for metrics in trainer.training_metrics.values():
assert len(metrics) == epochs

Expand Down Expand Up @@ -75,5 +76,6 @@ def test_run_early_stopping(mocked_save_seq2seq_model):

assert os.path.exists(os.path.join(tempdirpath, "checkpoint.pt"))
assert os.path.exists(os.path.join(tempdirpath, "example.csv"))
assert os.path.exists(os.path.join(tempdirpath, "learning_curve.png"))
for metrics in trainer.training_metrics.values():
assert len(metrics) == epochs - patience

0 comments on commit b18d6ae

Please sign in to comment.