fix bug of evaluator_cfg
darkliang committed Mar 20, 2024
1 parent ec30085 commit eb0b102
Showing 1 changed file with 6 additions and 6 deletions.
opengait/modeling/base_model.py (12 changes: 6 additions & 6 deletions)
@@ -438,10 +438,10 @@ def run_train(model):
     @ staticmethod
     def run_test(model):
         """Accept the instance object(model) here, and then run the test loop."""
-
-        if torch.distributed.get_world_size() != model.engine_cfg['sampler']['batch_size']:
+        evaluator_cfg = model.cfgs['evaluator_cfg']
+        if torch.distributed.get_world_size() != evaluator_cfg['sampler']['batch_size']:
             raise ValueError("The batch size ({}) must be equal to the number of GPUs ({}) in testing mode!".format(
-                model.engine_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
+                evaluator_cfg['sampler']['batch_size'], torch.distributed.get_world_size()))
         rank = torch.distributed.get_rank()
         with torch.no_grad():
             info_dict = model.inference(rank)
@@ -454,13 +454,13 @@ def run_test(model):
             info_dict.update({
                 'labels': label_list, 'types': types_list, 'views': views_list})

-            if 'eval_func' in model.cfgs["evaluator_cfg"].keys():
-                eval_func = model.cfgs['evaluator_cfg']["eval_func"]
+            if 'eval_func' in evaluator_cfg.keys():
+                eval_func = evaluator_cfg["eval_func"]
             else:
                 eval_func = 'identification'
             eval_func = getattr(eval_functions, eval_func)
             valid_args = get_valid_args(
-                eval_func, model.cfgs["evaluator_cfg"], ['metric'])
+                eval_func, evaluator_cfg, ['metric'])
             try:
                 dataset_name = model.cfgs['data_cfg']['test_dataset_name']
             except:
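The change replaces reads through model.engine_cfg (and the repeated model.cfgs["evaluator_cfg"] lookups) with a single local evaluator_cfg. Below is a minimal, self-contained sketch of why reading model.engine_cfg can misbehave here, assuming engine_cfg is selected by the training flag (as in OpenGait's BaseModel.__init__), so it still points at trainer_cfg when run_test is called on a model built for training, e.g. when testing is interleaved with training via with_test. The TinyModel class and the concrete batch sizes are hypothetical, for illustration only.

# Minimal sketch (hypothetical TinyModel, illustrative values only).
class TinyModel:
    def __init__(self, cfgs, training):
        self.cfgs = cfgs
        # Assumption: engine_cfg follows the training flag, as in BaseModel.__init__.
        self.engine_cfg = cfgs['trainer_cfg'] if training else cfgs['evaluator_cfg']

cfgs = {
    'trainer_cfg':   {'sampler': {'batch_size': 8}},  # training sampler
    'evaluator_cfg': {'sampler': {'batch_size': 1}},  # one sequence per GPU at test time
}
world_size = 1  # stand-in for torch.distributed.get_world_size()

# A model built for training and then handed to run_test:
model = TinyModel(cfgs, training=True)

# Old check: reads the *training* sampler, so a valid test setup is rejected.
print(world_size != model.engine_cfg['sampler']['batch_size'])   # True  -> spurious ValueError

# Fixed check: always reads the evaluator sampler.
evaluator_cfg = model.cfgs['evaluator_cfg']
print(world_size != evaluator_cfg['sampler']['batch_size'])      # False -> test proceeds

Caching the dict in a local also keeps the later lookups ('eval_func' and the get_valid_args call) reading from one place, so the evaluator configuration cannot silently diverge within run_test.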
