Skip to content

Commit

Permalink
fix lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
kengz committed Dec 29, 2024
1 parent 7f54635 commit dabbf7b
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions torcharc/example/notebook/lightning_mnist.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import sys

import lightning as L
import torch
import yaml
from lightning.pytorch.loggers import TensorBoardLogger
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
Expand Down Expand Up @@ -43,12 +43,11 @@ def val_dataloader(self):
class MNISTClassifier(L.LightningModule):
def __init__(self, model_spec_path: str):
super().__init__()
spec = yaml.safe_load(open(model_spec_path))
self.model = torcharc.build(spec)
# NOTE optimize with torch.compile https://pytorch.org/docs/2.5/generated/torch.compile.html
self.model = torch.compile(self.model)
self.model = torcharc.build(model_spec_path)
# NOTE set this for log_graph and reporting params
self.example_input_array = torch.rand(4, 1, 28, 28)
# forward pass to init Lazy layers
self.model(self.example_input_array)

self.lr = 1e-3
self.accuracy = Accuracy(task="multiclass", num_classes=10, top_k=1)
Expand Down Expand Up @@ -79,9 +78,13 @@ def validation_step(self, val_batch, batch_idx):


if __name__ == "__main__":
# run: ARC=conv uv run torcharc/example/notebook/lightning_mnist.py
arc = os.getenv("ARC", "conv") # conv, mlp
dm = MNISTDataModule()
arc = os.getenv("ARC", "conv2d") # conv2d, mlp, rnn
model = MNISTClassifier(f"torcharc/example/spec/mnist/{arc}.yaml")
model = MNISTClassifier(torcharc.SPEC_DIR / "mnist" / f"{arc}.yaml")
# speed up with compile https://lightning.ai/docs/pytorch/stable/advanced/compile.html
if sys.platform != "darwin": # but breaks on macOS GPU (mps)
model = torch.compile(model)
# launch tensorboard with `tensorboard --logdir ./tb_logs`
logger = TensorBoardLogger("tb_logs", name=arc, log_graph=True)
trainer = L.Trainer(max_epochs=1, logger=logger)
Expand Down

0 comments on commit dabbf7b

Please sign in to comment.