Skip to content

Commit

Permalink
Proper learning
Browse files Browse the repository at this point in the history
  • Loading branch information
JVGD committed Nov 1, 2021
1 parent 456b02d commit 022e9d4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 6 additions & 3 deletions experiment1/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def __init__(self, lr, optim_conf: dict={}):
def forward(self, x):
y0 = self.adder(x)
y1 = self.substracter(x)
y = self.weights_adder * y0 + self.weights_substracter * y1
y = (T.sigmoid(self.weights_adder) * y0 +
T.sigmoid(self.weights_substracter) * y1)
return y

def step(self, batch, batch_idx, *args, **kwargs) -> T.Tensor:
Expand Down Expand Up @@ -59,5 +60,7 @@ def configure_optimizers(self):
return optim.SGD(self.parameters(), lr=self.lr, **self.optim_conf)

def on_epoch_end(self) -> None:
self.log("weights/weights_adder", self.weights_adder)
self.log("weights/weights_substractor", self.weights_substracter)
self.log("weights/wa", self.weights_adder)
self.log("weights/ws", self.weights_substracter)
self.log("weights/sigma(wa)", T.sigmoid(self.weights_adder))
self.log("weights/sigma(ws)", T.sigmoid(self.weights_substracter))
4 changes: 3 additions & 1 deletion experiment1/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data.dataloader import DataLoader

from dataset import NumberAdd
from model import Model1

Expand All @@ -12,6 +14,6 @@
model = Model1(lr=1e-4, optim_conf={"momentum": 0.9})

# Trainer
trainer = Trainer(max_epochs=500)
trainer = Trainer(max_epochs=500, callbacks=[ModelCheckpoint(monitor="loss/valid")])
trainer.fit(model=model, train_dataloaders=dl_train, val_dataloaders=dl_valid)

0 comments on commit 022e9d4

Please sign in to comment.