From 88196675b5c3baa4dbeb2df534366aeb12fe4e0b Mon Sep 17 00:00:00 2001 From: Shay Aharon <80472096+shaydeci@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:28:02 +0200 Subject: [PATCH] added registries (#1862) Co-authored-by: Ofri Masad --- .../datasets/detection_datasets/coco_format_detection.py | 2 ++ .../datasets/detection_datasets/yolo_format_detection.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py index 97cd0fa763..5e78f4b7af 100644 --- a/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/coco_format_detection.py @@ -9,6 +9,7 @@ from super_gradients.common.abstractions.abstract_logger import get_logger from super_gradients.common.exceptions.dataset_exceptions import DatasetValidationException, ParameterMismatchException from super_gradients.common.deprecate import deprecated_parameter +from super_gradients.common.registry import register_dataset from super_gradients.training.datasets.data_formats.bbox_formats.xywh import xywh_to_xyxy_inplace from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset from super_gradients.training.datasets.data_formats.default_formats import XYXY_LABEL @@ -17,6 +18,7 @@ logger = get_logger(__name__) +@register_dataset("COCOFormatDetectionDataset") class COCOFormatDetectionDataset(DetectionDataset): """Base dataset to load ANY dataset that is with a similar structure to the COCO dataset. - Annotation file (.json). It has to respect the exact same format as COCO, for both the json schema and the bbox format (xywh). diff --git a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py index 987561c815..37750b5a64 100644 --- a/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py +++ b/src/super_gradients/training/datasets/detection_datasets/yolo_format_detection.py @@ -5,6 +5,7 @@ from typing import List, Optional, Tuple from super_gradients.common.abstractions.abstract_logger import get_logger +from super_gradients.common.registry import register_dataset from super_gradients.training.utils.media.image import is_image from super_gradients.training.datasets.detection_datasets.detection_dataset import DetectionDataset from super_gradients.training.datasets.data_formats import ConcatenatedTensorFormatConverter @@ -13,6 +14,7 @@ logger = get_logger(__name__) +@register_dataset("YoloDarknetFormatDetectionDataset") class YoloDarknetFormatDetectionDataset(DetectionDataset): """Base dataset to load ANY dataset that is with a similar structure to the Yolo/Darknet dataset.