diff --git a/experiment1/model.py b/experiment1/model.py index 98e798d..cb5ccca 100644 --- a/experiment1/model.py +++ b/experiment1/model.py @@ -3,7 +3,7 @@ from torch import nn from torch.nn.parameter import Parameter -from experiment1.modules import Adder, Substracter +from modules import Adder, Substracter class Model1(LightningModule): diff --git a/experiment1/trainer.py b/experiment1/trainer.py new file mode 100644 index 0000000..f7fded5 --- /dev/null +++ b/experiment1/trainer.py @@ -0,0 +1,17 @@ +from pytorch_lightning import Trainer +from torch.utils.data.dataloader import DataLoader +from dataset import NumberAdd +from model import Model1 + +# Data loaders +dl_train = DataLoader(NumberAdd(20000), batch_size=8, shuffle=True, num_workers=1) +dl_valid = DataLoader(NumberAdd(1000), batch_size=8, shuffle=True, num_workers=1) +dl_tests = DataLoader(NumberAdd(500), batch_size=8, shuffle=True, num_workers=1) + +# Model +model = Model1() + +# Trainer +trainer = Trainer(max_epochs=500, progress_bar_refresh_rate=20) +trainer.fit(model=model, train_dataloaders=dl_train, val_dataloaders=dl_valid) +