
support directly loading the mask #4144

Merged · 1 commit · Sep 6, 2021
16 changes: 10 additions & 6 deletions nni/compression/pytorch/speedup/compressor.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import queue
import logging
import copy
@@ -35,8 +35,8 @@ def __init__(self, model, dummy_input, masks_file, map_location=None,
Note: The first dimension of the dummy_input should be the batch size.
The dummy input is used for ```jit.trace```; users should put it on the right
device.
masks_file : str
The path of user provided mask file
masks_file : str/dict
The path of the user-provided mask file, or the mask object itself
map_location : str
The device on which masks are placed, the same as map_location in ```torch.load```
batch_dim : int
@@ -63,9 +63,13 @@ def __init__(self, model, dummy_input, masks_file, map_location=None,
# load the mask tensors to the same device as the dummy_input
# self.masks saves the mask tensors pruned by the user and the inferred
# masks of the other modules
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))

if isinstance(masks_file, str) and os.path.exists(masks_file):
self.masks = torch.load(
masks_file, map_location if map_location is not None else str(self.device))
elif isinstance(masks_file, dict):
self.masks = masks_file
else:
raise Exception('Please provide the mask or the path of the mask file')
self.constant = {}
# self.internal_result saves the internal outputs of the submodules
self.internal_result = {}
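To illustrate what this change enables, here is a minimal usage sketch (not taken from the PR). The model construction helper, the mask file name, and the mask dict shape are assumptions; `ModelSpeedup` and `speedup_model()` are the existing public API in `nni.compression.pytorch`.

```python
# Minimal sketch of the new behavior: ModelSpeedup now accepts either a mask
# file path or the mask dict itself. build_model() is a hypothetical helper
# returning the pruned nn.Module; the mask dict is assumed to be shaped like
# the ones NNI pruners export ({layer_name: {'weight': mask_tensor, ...}}).
import torch
from nni.compression.pytorch import ModelSpeedup

model = build_model()                      # hypothetical helper
dummy_input = torch.rand(8, 3, 224, 224)   # first dimension is the batch size

# Previously, only a path to a mask file saved on disk was accepted:
#     ModelSpeedup(model, dummy_input, 'masks.pth').speedup_model()
# With this change, the mask dict can be passed directly, e.g. one just
# produced by a pruner in memory or loaded manually:
masks = torch.load('masks.pth', map_location='cpu')
ModelSpeedup(model, dummy_input, masks).speedup_model()
```

This avoids a round trip through the filesystem when the masks are already in memory, while the file-path form keeps working unchanged.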