Skip to content

Commit

Permalink
Add medium and large reg tests for multiclass and multilabel classifi…
Browse files Browse the repository at this point in the history
…cations (#2770)

* Add reg tests for classification

* Fix typo

* Reflect reviews
  • Loading branch information
sungmanc authored Jan 11, 2024
1 parent 7843f3b commit 5b68f15
Showing 1 changed file with 45 additions and 18 deletions.
63 changes: 45 additions & 18 deletions tests/regression/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@ class TestMultiClassCls(BaseTest):
extra_overrides={"trainer.max_epochs": "20"},
)
for idx in range(1, 4)
] + [
DatasetTestCase(
name=f"multiclass_CUB_medium",
data_root=Path("multiclass_CUB_medium"),
data_format="imagenet_with_subset_dirs",
num_classes=67,
extra_overrides={"trainer.max_epochs": "20"},
),
DatasetTestCase(
name=f"multiclass_food101_large",
data_root=Path("multiclass_food101_large"),
data_format="imagenet_with_subset_dirs",
num_classes=20,
extra_overrides={"trainer.max_epochs": "20"},
)
]

@pytest.mark.parametrize(
Expand Down Expand Up @@ -159,6 +174,21 @@ class TestMultilabelCls(BaseTest):
extra_overrides={"trainer.max_epochs": "20"},
)
for idx in range(1, 4)
] + [
DatasetTestCase(
name=f"multilabel_CUB_medium",
data_root=Path("multilabel_CUB_medium"),
data_format="datumaro",
num_classes=68,
extra_overrides={"trainer.max_epochs": "20"},
),
DatasetTestCase(
name=f"multilabel_food101_large",
data_root=Path("multilabel_food101_large"),
data_format="datumaro",
num_classes=21,
extra_overrides={"trainer.max_epochs": "20"},
)
]

@pytest.mark.parametrize(
Expand Down Expand Up @@ -265,25 +295,22 @@ class TestObjectDetection(BaseTest):
extra_overrides={"trainer.max_epochs": "10"},
)
for idx in range(1, 4)
] + [
DatasetTestCase(
name="pothole_medium",
data_root="pothole_medium",
data_format="coco",
num_classes=1,
extra_overrides={"trainer.max_epochs": "10"}
),
DatasetTestCase(
name="vitens_large",
data_root="vitens_large",
data_format="coco",
num_classes=1,
extra_overrides={"trainer.max_epochs": "10"}
)
]
DATASET_TEST_CASES.extend(
[
DatasetTestCase(
name="pothole_medium",
data_root="pothole_medium",
data_format="coco",
num_classes=1,
extra_overrides={"trainer.max_epochs": "10"}
),
DatasetTestCase(
name="vitens_large",
data_root="vitens_large",
data_format="coco",
num_classes=1,
extra_overrides={"trainer.max_epochs": "10"}
),
]
)

@pytest.mark.parametrize(
"model_test_case",
Expand Down

0 comments on commit 5b68f15

Please sign in to comment.