diff --git a/external/model-preparation-algorithm/mpa_tasks/apis/task.py b/external/model-preparation-algorithm/mpa_tasks/apis/task.py index fa8eb0e1b84..a102a847f45 100644 --- a/external/model-preparation-algorithm/mpa_tasks/apis/task.py +++ b/external/model-preparation-algorithm/mpa_tasks/apis/task.py @@ -85,7 +85,7 @@ def _run_task(self, stage_module, mode=None, dataset=None, parameters=None, **kw if dataset is not None: train_data_cfg = Stage.get_train_data_cfg(self._data_cfg) # if dataset size is smaller than batch size - if (len(train_data_cfg.get('ote_dataset', [])) < self._recipe_cfg.data.get('samples_per_gpu', 2)): + if 0 < len(dataset) < self._recipe_cfg.data.get('samples_per_gpu', 2): train_data_cfg.drop_last = False train_data_cfg['data_classes'] = data_classes new_classes = np.setdiff1d(data_classes, model_classes).tolist()