Skip to content

Commit

Permalink
drop DetectionReference prefix
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Apr 3, 2023
1 parent c80dc73 commit a859c09
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,25 +139,25 @@ def detection_ssdlite_pipeline_builder(*, input_type, api_version):

pipeline = []
if api_version == "v1":
pipeline.append(DetectionReferenceConvertCocoPolysToMaskV1())
pipeline.append(ConvertCocoPolysToMaskV1())

if input_type == "Tensor":
pipeline.append(DetectionReferencePILToTensorV1())
pipeline.append(PILToTensorV1())

pipeline.extend(
[
DetectionReferenceRandomIoUCropV1(),
DetectionReferenceRandomHorizontalFlipV1(p=0.5),
RandomIoUCropV1(),
RandomHorizontalFlipV1(p=0.5),
]
)

if input_type == "PIL":
pipeline.append(DetectionReferencePILToTensorV1())
pipeline.append(PILToTensorV1())

pipeline.append(DetectionReferenceConvertImageDtypeV1(torch.float))
pipeline.append(ConvertImageDtypeV1(torch.float))

elif api_version == "v2":
pipeline.append(WrapCocoDetectionReferenceSampleForTransformsV2())
pipeline.append(WrapCocoSampleForTransformsV2())

if input_type == "Tensor":
pipeline.append(transforms_v2.PILToTensor())
Expand Down Expand Up @@ -197,7 +197,7 @@ def _transform(self, inpt, params):
return F_v2.crop(inpt, **params)


class WrapCocoDetectionReferenceSampleForTransformsV2:
class WrapCocoSampleForTransformsV2:
def __init__(self):
num_samples = 117_266
wrapper_factory = WRAPPER_FACTORIES[datasets.CocoDetection]
Expand Down Expand Up @@ -237,7 +237,7 @@ def convert_coco_poly_to_mask(segmentations, height, width):
return masks


class DetectionReferenceConvertCocoPolysToMaskV1:
class ConvertCocoPolysToMaskV1:
def __call__(self, image, target):
w, h = image.size

Expand Down Expand Up @@ -296,7 +296,7 @@ def __call__(self, image, target):
return image, target


class DetectionReferenceRandomHorizontalFlipV1(transforms_v1.RandomHorizontalFlip):
class RandomHorizontalFlipV1(transforms_v1.RandomHorizontalFlip):
def forward(
self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
Expand All @@ -314,15 +314,15 @@ def forward(
return image, target


class DetectionReferencePILToTensorV1(nn.Module):
class PILToTensorV1(nn.Module):
    """Convert a PIL image to an integer tensor for the v1 detection pipeline.

    The detection ``target`` dict is forwarded untouched; only the image is
    converted (via ``F_v1.pil_to_tensor``, i.e. no dtype rescaling).
    """

    def forward(
        self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        # Single-expression return: conversion + pass-through of the target.
        return F_v1.pil_to_tensor(image), target


class DetectionReferenceConvertImageDtypeV1(nn.Module):
class ConvertImageDtypeV1(nn.Module):
def __init__(self, dtype: torch.dtype) -> None:
super().__init__()
self.dtype = dtype
Expand All @@ -334,7 +334,7 @@ def forward(
return image, target


class DetectionReferenceRandomIoUCropV1(nn.Module):
class RandomIoUCropV1(nn.Module):
def __init__(
self,
min_scale: float = 0.3,
Expand Down

0 comments on commit a859c09

Please sign in to comment.