-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest.py
36 lines (29 loc) · 1.1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import hydra
import lightning.pytorch as pl
import torch
from tricolo.data.data_module import DataModule
from tricolo.model.tricolo_net import TriCoLoNet
def _drop_disabled_encoder_weights(state_dict, model_cfg):
    """Remove checkpoint entries belonging to encoders disabled in the config.

    An encoder counts as disabled when ``model_cfg.image_encoder`` or
    ``model_cfg.voxel_encoder`` is ``None``; every key containing that
    encoder's name is then dropped. Presumably the model does not build a
    disabled encoder, so those keys would otherwise be unexpected for the
    strict ``load_state_dict`` call (TODO confirm against TriCoLoNet).

    Args:
        state_dict: Mapping of parameter names to tensors, as stored under
            ``"state_dict"`` in the checkpoint. Modified in place.
        model_cfg: The ``cfg.model`` config node.

    Returns:
        The same ``state_dict`` object, with the disabled encoders' keys removed.
    """
    disabled = [
        name
        for name in ("image_encoder", "voxel_encoder")
        if getattr(model_cfg, name) is None
    ]
    # Collect first, then delete: a mapping cannot shrink while being iterated.
    # Deleting in place (rather than building a new dict) keeps the original
    # mapping type and any attributes attached to it, e.g. ``_metadata``,
    # which ``load_state_dict`` reads.
    doomed = [key for key in state_dict if any(name in key for name in disabled)]
    for key in doomed:
        del state_dict[key]
    return state_dict


@hydra.main(version_base=None, config_path="config", config_name="config")
def main(cfg):
    """Evaluate a trained TriCoLoNet checkpoint with a Lightning test run.

    Seeds via ``pl.seed_everything(cfg.test_seed, workers=True)`` and creates
    ``cfg.inference.output_dir`` if needed. Restores the model weights from
    ``cfg.ckpt_path``, then runs ``Trainer.test`` on a single GPU with
    logging disabled.

    Args:
        cfg: Hydra config composed from ``config_path="config"`` /
            ``config_name="config"``; supplied by the ``hydra.main`` decorator.

    Raises:
        FileNotFoundError: If ``cfg.ckpt_path`` does not exist.
    """
    # fix the seed
    pl.seed_everything(cfg.test_seed, workers=True)
    os.makedirs(cfg.inference.output_dir, exist_ok=True)

    model = TriCoLoNet(cfg)

    # load checkpoints
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if not os.path.exists(cfg.ckpt_path):
        raise FileNotFoundError(
            f"Error: Checkpoint path does not exist: {cfg.ckpt_path}"
        )
    # Load onto CPU: the checkpoint may have been written on a different
    # device/GPU index; the Trainer moves the model to the GPU itself.
    # NOTE(review): torch>=2.6 defaults ``weights_only=True``; if this
    # checkpoint stores non-tensor objects (e.g. hyper-parameters) loading may
    # need adjusting — confirm with the torch version in use.
    ckpt = torch.load(cfg.ckpt_path, map_location="cpu")["state_dict"]
    ckpt = _drop_disabled_encoder_weights(ckpt, cfg.model)
    model.load_state_dict(ckpt)

    data_module = DataModule(cfg)
    trainer = pl.Trainer(accelerator="gpu", devices=1, max_epochs=1, logger=False)
    trainer.test(model=model, datamodule=data_module)
if __name__ == "__main__":
    # Called without arguments: the ``hydra.main`` decorator parses the
    # command line, composes the config and passes it to ``main`` as ``cfg``.
    main()