Skip to content

Commit

Permalink
Fix dataset length bug in mpa task
Browse files Browse the repository at this point in the history
Signed-off-by: Songki Choi <songki.choi@intel.com>
  • Loading branch information
goodsong81 committed Nov 7, 2022
1 parent 6b0c0ec commit a840c57
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
4 changes: 3 additions & 1 deletion external/model-preparation-algorithm/mpa_tasks/apis/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw
if dataset is not None:
train_data_cfg = Stage.get_train_data_cfg(self._data_cfg)
# if dataset size is smaller than batch size
if (len(train_data_cfg.get('ote_dataset', [])) < self._recipe_cfg.data.get('samples_per_gpu', 2)):
dataset = train_data_cfg.get('ote_dataset', None)
dataset_len = len(dataset) if dataset else 0
if 0 < dataset_len < self._recipe_cfg.data.get('samples_per_gpu', 2):
train_data_cfg.drop_last = False
train_data_cfg['data_classes'] = data_classes
new_classes = np.setdiff1d(data_classes, model_classes).tolist()
Expand Down

0 comments on commit a840c57

Please sign in to comment.