-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
31 lines (26 loc) · 839 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.strategies import DDPStrategy
from lightning import AutoregressiveLM, DiffusionLM
from lightning.mock import MockData
from lightning.data import ConcatData
def cli_main(args=None):
    """Build and run the Lightning command-line interface for this project.

    A ``LightningCLI`` is constructed with project-wide Trainer defaults:
    GPU accelerator, DDP strategy, logging every step, a periodic
    ``ModelCheckpoint`` and a ``ModelSummary`` callback. No ``model_class``
    or ``datamodule_class`` is passed, so the model and data module are
    selected from the command line / config (e.g. via ``class_path``). The
    project classes imported at module top (``AutoregressiveLM``,
    ``DiffusionLM``, ``MockData``, ``ConcatData``) are presumably imported so
    they can be resolved there — confirm.

    Args:
        args: Optional argument source forwarded to ``LightningCLI(args=...)``
            (for example a list of command-line strings). ``None`` (the
            default) keeps the original behavior of parsing ``sys.argv``.

    Returns:
        The constructed ``LightningCLI`` instance, so programmatic callers
        can inspect e.g. ``cli.trainer`` or ``cli.model`` after the run.
    """
    cli = LightningCLI(
        trainer_defaults={
            'accelerator': 'gpu',
            # NOTE(review): DDPStrategy is imported at module top but unused;
            # the plain 'ddp' string is used here — confirm that is intended.
            'strategy': 'ddp',
            'log_every_n_steps': 1,
            # Per the LightningCLI docs, callbacks given via trainer_defaults
            # are always attached and are not configurable from a config
            # file — verify against the installed Lightning version.
            'callbacks': [
                # No `monitor` is set, so save_top_k=1 does not rank by any
                # metric; presumably this keeps only the most recent periodic
                # checkpoint alongside last.ckpt — confirm intent.
                ModelCheckpoint(
                    save_top_k=1,
                    save_last=True,
                    every_n_train_steps=10000,
                    filename='{epoch}-{step}',
                ),
                ModelSummary(max_depth=4),
            ],
        },
        args=args,
    )
    return cli
if __name__ == "__main__":
    # Launch the CLI only when this file is executed as a script,
    # not when it is imported as a module.
    cli_main()