Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK: Improve the PyTorch adapter layer #5455

Merged
merged 2 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions cvat-sdk/cvat_sdk/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import attrs
import attrs.validators
import PIL.Image
import torch
import torchvision.datasets
from typing_extensions import TypedDict

Expand Down Expand Up @@ -65,8 +66,7 @@ class Target:
label_id_to_index: Mapping[int, int]
"""
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an index in the range [0, num_labels), where num_labels is the number of labels
defined in the task. This mapping is consistent across all samples for a given task.
to an integer index. This mapping is consistent across all samples for a given task.
"""


Expand Down Expand Up @@ -99,6 +99,7 @@ def __init__(
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Mapping[str, int] = None,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
Expand All @@ -107,6 +108,17 @@ def __init__(
`transforms`, `transform` and `target_transform` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.

`label_name_to_index` affects the `label_id_to_index` member in `Target` objects
returned by the dataset. If it is specified, then it must contain an entry for
each label name in the task. The `label_id_to_index` mapping will be constructed
so that each label will be mapped to the index corresponding to the label's name
in `label_name_to_index`.

If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index`
will map each label ID to a distinct integer in the range [0, `num_labels`), where
`num_labels` is the number of labels defined in the task. This mapping will be
generally unpredictable, but consistent for a given task.
"""

self._logger = client.logger
Expand Down Expand Up @@ -162,12 +174,19 @@ def __init__(

self._logger.info("All chunks downloaded")

self._label_id_to_index = types.MappingProxyType(
{
label["id"]: label_index
for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
}
)
if label_name_to_index is None:
self._label_id_to_index = types.MappingProxyType(
{
label.id: label_index
for label_index, label in enumerate(
sorted(self._task.labels, key=lambda l: l.id)
)
}
)
else:
self._label_id_to_index = types.MappingProxyType(
{label.id: label_name_to_index[label.name] for label in self._task.labels}
)

annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
Expand Down Expand Up @@ -283,7 +302,7 @@ def __len__(self) -> int:
class ExtractSingleLabelIndex:
"""
A target transform that takes a `Target` object and produces a single label index
based on the tag in that object.
based on the tag in that object, as a 0-dimensional tensor.

This makes the dataset samples compatible with the image classification networks
in torchvision.
Expand All @@ -299,12 +318,12 @@ def __call__(self, target: Target) -> int:
if len(tags) > 1:
raise ValueError("sample has multiple tags")

return target.label_id_to_index[tags[0].label_id]
return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.long)


class LabeledBoxes(TypedDict):
    """
    The target format produced by `ExtractBoundingBoxes`.

    Keys:
        boxes: a float tensor of shape [N, 4]; each row is one shape's bounding
            box in (xmin, ymin, xmax, ymax) order.
        labels: an integer tensor of shape [N] with the label index of each box.
    """

    # NOTE(review): the diff left the removed Sequence-typed annotations in
    # place alongside the new ones; duplicate keys in a TypedDict body are
    # stale and misleading, so only the final torch.Tensor annotations remain.
    boxes: torch.Tensor
    labels: torch.Tensor


_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])
Expand All @@ -318,9 +337,9 @@ class ExtractBoundingBoxes:

The dictionary contains the following entries:

"boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
in the annotations.
"labels": a sequence of corresponding label indices.
"boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape
in the annotations in the (xmin, ymin, xmax, ymax) format.
"labels": a tensor with shape [N] containing corresponding label indices.

Limitations:

Expand Down Expand Up @@ -356,4 +375,7 @@ def __call__(self, target: Target) -> LabeledBoxes:
boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
labels.append(target.label_id_to_index[shape.label_id])

return LabeledBoxes(boxes=boxes, labels=labels)
return LabeledBoxes(
boxes=torch.tensor(boxes, dtype=torch.float),
labels=torch.tensor(labels, dtype=torch.long),
)
29 changes: 24 additions & 5 deletions tests/python/sdk/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,8 @@ def test_extract_single_label_index(self):
target_transform=cvatpt.ExtractSingleLabelIndex(),
)

assert dataset[5][1] == 0
assert dataset[6][1] == 1
assert torch.equal(dataset[5][1], torch.tensor(0))
assert torch.equal(dataset[6][1], torch.tensor(1))

with pytest.raises(ValueError):
# no tags
Expand All @@ -192,9 +192,15 @@ def test_extract_bounding_boxes(self):
target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
)

assert dataset[0][1] == {"boxes": [], "labels": []}
assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]}
assert dataset[7][1] == {"boxes": [], "labels": []} # points are filtered out
assert torch.equal(dataset[0][1]["boxes"], torch.tensor([]))
assert torch.equal(dataset[0][1]["labels"], torch.tensor([]))

assert torch.equal(dataset[6][1]["boxes"], torch.tensor([(1.0, 2.0, 3.0, 4.0)]))
assert torch.equal(dataset[6][1]["labels"], torch.tensor([1]))

# points are filtered out
assert torch.equal(dataset[7][1]["boxes"], torch.tensor([]))
assert torch.equal(dataset[7][1]["labels"], torch.tensor([]))

def test_transforms(self):
dataset = cvatpt.TaskVisionDataset(
Expand All @@ -205,3 +211,16 @@ def test_transforms(self):

assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)

def test_custom_label_mapping(self):
label_name_to_id = {label.name: label.id for label in self.task.labels}

dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
label_name_to_index={"person": 123, "car": 456},
)

_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456