Skip to content

Commit

Permalink
cvat_sdk/pytorch: add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
SpecLad committed Dec 7, 2022
1 parent ccb6ae3 commit d19cea6
Showing 1 changed file with 82 additions and 0 deletions.
82 changes: 82 additions & 0 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,17 +45,52 @@ class UnsupportedDatasetError(cvat_sdk.core.exceptions.CvatSdkException):

@attrs.frozen
class FrameAnnotations:
"""
Contains annotations that pertain to a single frame.
"""

tags: List[LabeledImage] = attrs.Factory(list)
shapes: List[LabeledShape] = attrs.Factory(list)


@attrs.frozen
class Target:
"""
Non-image data for a dataset sample.
"""

annotations: FrameAnnotations
"""Annotations for the frame corresponding to the sample."""

label_id_to_index: Mapping[int, int]
"""
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an index in the range [0, num_labels), where num_labels is the number of labels
defined in the task. This mapping is consistent across all samples for a given task.
"""


class TaskVisionDataset(torchvision.datasets.VisionDataset):
"""
Represents a task on a CVAT server as a PyTorch Dataset.
This dataset contains one sample for each frame in the task, in the same
order as the frames are in the task. Deleted frames are omitted.
Before transforms are applied, each sample is a tuple of
(image, target), where:
* image is a `PIL.Image.Image` object for the corresponding frame.
* target is a `Target` object containing annotations for the frame.
This class caches all data and annotations for the task on the local file system
during construction. If the task is updated on the server, the cache is updated.
Limitations:
* Only tasks with image (not video) data are supported at the moment.
* Track annotations are currently not accessible.
"""

def __init__(
self,
client: cvat_sdk.core.Client,
Expand All @@ -65,6 +100,15 @@ def __init__(
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
server that `client` is connected to.
`transforms`, `transform` and `target_transforms` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
"""

self._logger = client.logger

self._logger.info(f"Fetching task {task_id}...")
Expand Down Expand Up @@ -206,6 +250,12 @@ def _ensure_model(
return model

def __getitem__(self, sample_index: int):
"""
Returns the sample with index `sample_index`.
`sample_index` must satisfy the condition `0 <= sample_index < len(self)`.
"""

frame_index = self._active_frame_indexes[sample_index]
chunk_index = frame_index // self._task.data_chunk_size
member_index = frame_index % self._task.data_chunk_size
Expand All @@ -225,11 +275,22 @@ def __getitem__(self, sample_index: int):
return sample_image, sample_target

def __len__(self) -> int:
"""Returns the number of samples in the dataset."""
return len(self._active_frame_indexes)


@attrs.frozen
class ExtractSingleLabelIndex:
"""
A target transform that takes a `Target` object and produces a single label index
based on the tag in that object.
This makes the dataset samples compatible with the image classification networks
in torchvision.
If the annotations contain no tags, or multiple tags, raises a `ValueError`.
"""

def __call__(self, target: Target) -> int:
tags = target.annotations.tags
if not tags:
Expand All @@ -251,11 +312,32 @@ class LabeledBoxes(TypedDict):

@attrs.frozen
class ExtractBoundingBoxes:
"""
A target transform that takes a `Target` object and returns a dictionary compatible
with the object detection networks in torchvision.
The dictionary contains the following entries:
"boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
in the annotations.
"labels": a sequence of corresponding label indices.
Limitations:
* Only the following shape types are supported: rectangle, polygon, polyline,
points, ellipse.
* Rotated shapes are not supported.
Unsupported shapes will cause a `UnsupportedDatasetError` exception to be
raised unless they are filtered out by `include_shape_types`.
"""

include_shape_types: FrozenSet[str] = attrs.field(
converter=frozenset,
validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
kw_only=True,
)
"""Shapes whose type is not in this set will be ignored."""

def __call__(self, target: Target) -> LabeledBoxes:
boxes = []
Expand Down

0 comments on commit d19cea6

Please sign in to comment.