Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

fix flops counter bug in auto_pruners_torch.py #3265

Merged
merged 1 commit into from
Jan 6, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/model_compress/pruning/auto_pruners_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_trained_model_optimizer(args, device, train_loader, val_loader, criterio

if args.save_model:
torch.save(state_dict, os.path.join(args.experiment_data_dir, 'model_trained.pth'))
print('Model trained saved to %s', args.experiment_data_dir)
print('Model trained saved to %s' % args.experiment_data_dir)

return model, optimizer

Expand Down Expand Up @@ -312,7 +312,7 @@ def evaluator(model):
if args.save_model:
pruner.export_model(
os.path.join(args.experiment_data_dir, 'model_masked.pth'), os.path.join(args.experiment_data_dir, 'mask.pth'))
print('Masked model saved to %s', args.experiment_data_dir)
print('Masked model saved to %s' % args.experiment_data_dir)

# model speed up
if args.speed_up:
Expand All @@ -336,7 +336,7 @@ def evaluator(model):
result['performance']['speedup'] = evaluation_result

torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_speed_up.pth'))
print('Speed up model saved to %s', args.experiment_data_dir)
print('Speed up model saved to %s' % args.experiment_data_dir)
flops, params, _ = count_flops_params(model, get_input_size(args.dataset))
result['flops']['speedup'] = flops
result['params']['speedup'] = params
Expand Down Expand Up @@ -367,7 +367,7 @@ def evaluator(model):
torch.save(model.state_dict(), os.path.join(args.experiment_data_dir, 'model_fine_tuned.pth'))

print('Evaluation result (fine tuned): %s' % best_acc)
print('Fined tuned model saved to %s', args.experiment_data_dir)
print('Fined tuned model saved to %s' % args.experiment_data_dir)
result['performance']['finetuned'] = best_acc

with open(os.path.join(args.experiment_data_dir, 'result.json'), 'w+') as f:
Expand Down