diff --git a/otx/cli/tools/eval.py b/otx/cli/tools/eval.py index 604f6718020..8172f32befc 100644 --- a/otx/cli/tools/eval.py +++ b/otx/cli/tools/eval.py @@ -56,8 +56,12 @@ def parse_args(): parser = argparse.ArgumentParser() if not os.path.exists("./template.yaml"): parser.add_argument("template") - parser.add_argument("--data", required=False, default="./data.yaml") - required = not os.path.exists("./data.yaml") + parser.add_argument("--data", required=False) + parsed, _ = parser.parse_known_args() + required = True + if parsed.data is not None: + assert os.path.exists(parsed.data) + required = False parser.add_argument( "--test-ann-files", diff --git a/otx/cli/tools/optimize.py b/otx/cli/tools/optimize.py index 9c8470e7dd3..fa484067af6 100644 --- a/otx/cli/tools/optimize.py +++ b/otx/cli/tools/optimize.py @@ -61,8 +61,12 @@ def parse_args(): parser = argparse.ArgumentParser() if not os.path.exists("./template.yaml"): parser.add_argument("template") - parser.add_argument("--data", required=False, default="./data.yaml") - required = not os.path.exists("./data.yaml") + parser.add_argument("--data", required=False) + parsed, _ = parser.parse_known_args() + required = True + if parsed.data is not None: + assert os.path.exists(parsed.data) + required = False parser.add_argument( "--train-ann-files", diff --git a/otx/cli/tools/train.py b/otx/cli/tools/train.py index d7b2ad023a0..05ecabbd1c6 100644 --- a/otx/cli/tools/train.py +++ b/otx/cli/tools/train.py @@ -69,8 +69,12 @@ def parse_args(): parser = argparse.ArgumentParser() if not os.path.exists("./template.yaml"): parser.add_argument("template") - parser.add_argument("--data", required=False, default="./data.yaml") - required = not os.path.exists("./data.yaml") + parser.add_argument("--data", required=False) + parsed, _ = parser.parse_known_args() + required = True + if parsed.data is not None: + assert os.path.exists(parsed.data) + required = False parser.add_argument( "--train-ann-files", diff --git a/otx/cli/utils/config.py b/otx/cli/utils/config.py index 060608e29d7..192bab49f10 100644 --- a/otx/cli/utils/config.py +++ b/otx/cli/utils/config.py @@ -43,7 +43,7 @@ def configure_dataset(args): data_subset_format = {"ann-files": None, "data-roots": None} data_config = {"data": {subset: data_subset_format.copy() for subset in ("train", "val", "test")}} data_config["data"]["unlabeled"] = {"file-list": None, "data-roots": None} - if os.path.exists(args.data): + if args.data is not None and os.path.exists(args.data): with open(args.data, "r", encoding="UTF-8") as stream: data_config = yaml.safe_load(stream) stream.close() diff --git a/otx/mpa/modules/models/segmentors/pixel_weights_mixin.py b/otx/mpa/modules/models/segmentors/pixel_weights_mixin.py index 79ff07e130c..b5ed556856f 100644 --- a/otx/mpa/modules/models/segmentors/pixel_weights_mixin.py +++ b/otx/mpa/modules/models/segmentors/pixel_weights_mixin.py @@ -5,6 +5,7 @@ import torch.nn as nn from mmseg.core import add_prefix from mmseg.models.builder import build_loss +from mmseg.ops import resize from ..losses.utils import LossEqualizer diff --git a/otx/mpa/seg/stage.py b/otx/mpa/seg/stage.py index 4469fc9d6f5..dd4c7cfcf8f 100644 --- a/otx/mpa/seg/stage.py +++ b/otx/mpa/seg/stage.py @@ -3,7 +3,6 @@ # from mmcv import ConfigDict -from mmcv.runner import load_checkpoint from otx.algorithms.segmentation.adapters.mmseg.utils.builder import build_segmentor from otx.mpa.stage import Stage diff --git a/tests/integration/cli/segmentation/test_segmentation.py b/tests/integration/cli/segmentation/test_segmentation.py index 34244aff6d4..b07afee6403 100644 --- a/tests/integration/cli/segmentation/test_segmentation.py +++ b/tests/integration/cli/segmentation/test_segmentation.py @@ -275,6 +275,7 @@ def test_otx_eval(self, template, tmp_dir_path): args_selfsl = { + "--data": "./data.yaml", "--train-ann-file": "data/segmentation/custom/annotations/detcon_masks", "--train-data-roots": "data/segmentation/custom/images/training", "--input": "data/segmentation/custom/images/training",