This repository has been archived by the owner on Jun 12, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
117 lines (95 loc) · 3.43 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# Copyright (c) 2021 Robert Bosch GmbH
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import argparse
import os
import pandas as pd
import torch
from tqdm import tqdm
from im2mesh import config, data
from im2mesh.checkpoints import CheckpointIO
# Command-line interface: a positional path to the experiment config file and
# an opt-out switch for CUDA.
parser = argparse.ArgumentParser(description='Evaluate mesh algorithms.')
parser.add_argument(
    'config',
    type=str,
    help='Path to config file.',
)
parser.add_argument(
    '--no-cuda',
    action='store_true',
    help='Do not use cuda.',
)
# Parse the command line and load the experiment config, handing
# 'configs/default.yaml' to config.load_config as the default config (how the
# two are combined is decided inside load_config).
args = parser.parse_args()
cfg = config.load_config(args.config, 'configs/default.yaml')

# Use the GPU only when one is available and the user did not pass --no-cuda.
is_cuda = torch.cuda.is_available() and not args.no_cuda
if is_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Shorthands: all evaluation artefacts are written into the training output
# directory -- the full per-sample table (pickle) and a summary table (CSV).
out_dir = cfg['training']['out_dir']
out_file, out_file_class = (
    os.path.join(out_dir, 'eval_full.pkl'),
    os.path.join(out_dir, 'eval.csv'),
)
# Dataset: the test split, built so that each sample also carries its index
# (return_idx=True); the loop below reads it back via batch['idx'].
dataset = config.get_dataset('test', cfg, return_idx=True)
model = config.get_model(cfg, device=device, dataset=dataset)

# Restore the trained weights from cfg['test']['model_file'], resolved by
# CheckpointIO relative to the training output directory.
checkpoint_io = CheckpointIO(out_dir, model=model)
try:
    checkpoint_io.load(cfg['test']['model_file'])
except (FileExistsError, FileNotFoundError):
    # NOTE(review): the original handler caught only FileExistsError --
    # presumably what CheckpointIO.load raises for a missing file (not
    # verifiable from this file). The standard FileNotFoundError is caught
    # as well in case load() raises that instead.
    print('Model file does not exist. Exiting.')
    # The bare exit() helper is meant for the interactive shell and, called
    # without an argument, exits with status 0. Exit with a non-zero status
    # so scripts and CI notice the failure.
    raise SystemExit(1)
# Trainer: the second argument (presumably the optimizer) is None because
# nothing is trained here -- TODO confirm against config.get_trainer.
trainer = config.get_trainer(model, None, cfg, device=device)

# Report the architecture and its total parameter count.
nparameters = sum(param.numel() for param in model.parameters())
print(model)
print(f'Total number of parameters: {nparameters:d}')
# Evaluate: put the model into evaluation mode and iterate over the test
# split one sample at a time, in dataset order.
model.eval()
print('Evaluating networks...')
# One result record per successfully evaluated sample is collected here.
eval_dicts = []
# collate_remove_none / worker_init_fn come from im2mesh.data; judging by the
# name, the collate function drops unusable samples (the loop below still
# guards against an empty batch) -- TODO confirm.
test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=data.collate_remove_none,
    worker_init_fn=data.worker_init_fn,
)
# Handle each sample separately (the loader uses batch_size=1).
# The loop variable is named ``batch`` rather than ``data`` so it does not
# shadow the ``im2mesh.data`` module imported at the top of the file.
for batch in tqdm(test_loader):
    if batch is None:
        print('Invalid data.')
        continue

    # Get index etc.
    idx = batch['idx'].item()
    try:
        model_dict = dataset.get_model_dict(idx)
    except AttributeError:
        # Dataset without per-model info: fall back to the raw index.
        model_dict = {'model': str(idx), 'category': 'n/a'}
    modelname = model_dict['model']
    category_id = model_dict['category']
    try:
        category_name = dataset.metadata[category_id].get('name', 'n/a')
    except (AttributeError, KeyError):
        # No metadata on the dataset, or no entry for this category (e.g. the
        # 'n/a' fallback above), previously an uncaught KeyError.
        category_name = 'n/a'

    eval_dict = {
        'idx': idx,
        'class id': category_id,
        'class name': category_name,
        # Keep only the prefix before the first '_'; results are grouped by
        # this value when the summary table is built.
        'modelname': modelname.split('_')[0],
    }
    eval_dicts.append(eval_dict)
    # Merge the metrics returned by the trainer into the same record.
    eval_data = trainer.eval_step(batch)
    eval_dict.update(eval_data)
# Create pandas dataframe and save the full per-sample results.
if not eval_dicts:
    # pd.DataFrame([]) has no 'idx' column, so set_index would fail with an
    # opaque KeyError; stop with an explicit message instead.
    print('No valid data was evaluated. Exiting.')
    raise SystemExit(1)
eval_df = pd.DataFrame(eval_dicts)
eval_df.set_index(['idx'], inplace=True)
eval_df.to_pickle(out_file)

# Create CSV file with main statistics: the mean of every numeric column per
# 'modelname' prefix. numeric_only=True skips the string columns
# ('class id', 'class name'): pandas >= 2.0 raises TypeError on them, while
# older versions silently dropped them.
# To aggregate per object category instead, group by 'class name'.
eval_df_class = eval_df.groupby(by=['modelname']).mean(numeric_only=True)
eval_df_class.to_csv(out_file_class)

# Print results, with an extra 'mean' row. It is the unweighted average of the
# per-group means, not the mean over all samples. The row is added after the
# CSV is written, so it appears only in the printed table.
eval_df_class.loc['mean'] = eval_df_class.mean()
print(eval_df_class)