diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py
index f0ac2d54d643..55b88186e7a5 100644
--- a/cvat-sdk/cvat_sdk/pytorch/__init__.py
+++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py
@@ -45,17 +45,52 @@ class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):
 
 @attrs.frozen
 class FrameAnnotations:
+    """
+    Contains annotations that pertain to a single frame.
+    """
+
     tags: List[LabeledImage] = attrs.Factory(list)
     shapes: List[LabeledShape] = attrs.Factory(list)
 
 
 @attrs.frozen
 class Target:
+    """
+    Non-image data for a dataset sample.
+    """
+
     annotations: FrameAnnotations
+    """Annotations for the frame corresponding to the sample."""
+
     label_id_to_index: Mapping[int, int]
+    """
+    A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
+    to an index in the range [0, num_labels), where num_labels is the number of labels
+    defined in the task. This mapping is consistent across all samples for a given task.
+    """
 
 
 class TaskVisionDataset(torchvision.datasets.VisionDataset):
+    """
+    Represents a task on a CVAT server as a PyTorch Dataset.
+
+    This dataset contains one sample for each frame in the task, in the same
+    order as the frames are in the task. Deleted frames are omitted.
+    Before transforms are applied, each sample is a tuple of
+    (image, target), where:
+
+    * image is a `PIL.Image.Image` object for the corresponding frame.
+    * target is a `Target` object containing annotations for the frame.
+
+    This class caches all data and annotations for the task on the local file system
+    during construction. If the task is updated on the server, the cache is updated.
+
+    Limitations:
+
+    * Only tasks with image (not video) data are supported at the moment.
+    * Track annotations are currently not accessible.
+    """
+
     def __init__(
         self,
         client: cvat_sdk.core.Client,
@@ -65,6 +100,15 @@ def __init__(
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
     ) -> None:
+        """
+        Creates a dataset corresponding to the task with ID `task_id` on the
+        server that `client` is connected to.
+
+        `transforms`, `transform` and `target_transform` are optional transformation
+        functions; see the documentation for `torchvision.datasets.VisionDataset` for
+        more information.
+        """
+
         self._logger = client.logger
 
         self._logger.info(f"Fetching task {task_id}...")
@@ -206,6 +250,12 @@ def _ensure_model(
         return model
 
     def __getitem__(self, sample_index: int):
+        """
+        Returns the sample with index `sample_index`.
+
+        `sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
+        """
+
         frame_index = self._active_frame_indexes[sample_index]
         chunk_index = frame_index // self._task.data_chunk_size
         member_index = frame_index % self._task.data_chunk_size
@@ -225,11 +275,22 @@ def __getitem__(self, sample_index: int):
         return sample_image, sample_target
 
     def __len__(self) -> int:
+        """Returns the number of samples in the dataset."""
         return len(self._active_frame_indexes)
 
 
 @attrs.frozen
 class ExtractSingleLabelIndex:
+    """
+    A target transform that takes a `Target` object and produces a single label index
+    based on the tag in that object.
+
+    This makes the dataset samples compatible with the image classification networks
+    in torchvision.
+
+    If the annotations contain no tags, or multiple tags, raises a `ValueError`.
+    """
+
    def __call__(self, target: Target) -> int:
        tags = target.annotations.tags
        if not tags:
@@ -251,11 +312,32 @@ class LabeledBoxes(TypedDict):
 
 @attrs.frozen
 class ExtractBoundingBoxes:
+    """
+    A target transform that takes a `Target` object and returns a dictionary compatible
+    with the object detection networks in torchvision.
+
+    The dictionary contains the following entries:
+
+    "boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
+        in the annotations.
+    "labels": a sequence of corresponding label indices.
+
+    Limitations:
+
+    * Only the following shape types are supported: rectangle, polygon, polyline,
+      points, ellipse.
+    * Rotated shapes are not supported.
+
+    Unsupported shapes will cause an `UnsupportedDatasetError` exception to be
+    raised unless they are filtered out by `include_shape_types`.
+    """
+
     include_shape_types: FrozenSet[str] = attrs.field(
         converter=frozenset,
         validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
         kw_only=True,
     )
+    """Shapes whose type is not in this set will be ignored."""
 
     def __call__(self, target: Target) -> LabeledBoxes:
         boxes = []