[Counting Optimize] support multi-batch in tile classifier #2153

3 changes: 3 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/datasets/dataset.py
@@ -332,6 +332,7 @@ class ImageTilingDataset(OTXDetDataset):
max_annotation (int, optional): Limit the number of ground-truth annotations by randomly selecting 5000 of them to avoid RAM OOM. Defaults to 5000.
sampling_ratio (float): Ratio for sampling the entire tile dataset.
include_full_img (bool): Whether to include the full image in the dataset. Defaults to False.
"""

def __init__(
@@ -347,6 +348,7 @@ def __init__(
filter_empty_gt=True,
test_mode=False,
sampling_ratio=1.0,
include_full_img=False,
):
self.dataset = build_dataset(dataset)
self.CLASSES = self.dataset.CLASSES
@@ -362,6 +364,7 @@ def __init__(
max_annotation=max_annotation,
filter_empty_gt=filter_empty_gt if self.dataset.otx_dataset[0].subset != Subset.TESTING else False,
sampling_ratio=sampling_ratio if self.dataset.otx_dataset[0].subset != Subset.TESTING else 1.0,
include_full_img=include_full_img if self.dataset.otx_dataset[0].subset != Subset.TESTING else True,
)
self.flag = np.zeros(len(self), dtype=np.uint8)
self.pipeline = Compose(pipeline)
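Note on the conditional arguments above: for the TESTING subset the dataset always disables empty-GT filtering, keeps the full tile set (sampling_ratio 1.0), and now also keeps the full-resolution image. A minimal sketch of that rule, assuming only the subset drives the overrides (the helper and the `Subset` stub below are illustrative, not from the repo):

```python
from enum import Enum

class Subset(Enum):  # stand-in for otx.api's Subset
    TRAINING = 1
    TESTING = 2

def tiling_overrides(subset, filter_empty_gt=True, sampling_ratio=1.0, include_full_img=False):
    """Test data is never filtered or subsampled and always keeps the full image."""
    is_test = subset == Subset.TESTING
    return {
        "filter_empty_gt": filter_empty_gt if not is_test else False,
        "sampling_ratio": sampling_ratio if not is_test else 1.0,
        "include_full_img": include_full_img if not is_test else True,
    }

print(tiling_overrides(Subset.TESTING))
# {'filter_empty_gt': False, 'sampling_ratio': 1.0, 'include_full_img': True}
```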
24 changes: 15 additions & 9 deletions otx/algorithms/detection/adapters/mmdet/datasets/tiling.py
@@ -62,6 +62,7 @@ class Tile:
during tests. Defaults to True.
nproc (int, optional): Processes used for processing masks. Default: 4.
sampling_ratio (float): Ratio for sampling the entire tile dataset. Default: 1.0 (no sampling).
include_full_img (bool): Whether to include the full-size image for inference or training. Default: False.
"""

def __init__(
@@ -77,6 +78,7 @@ def __init__(
filter_empty_gt: bool = True,
nproc: int = 2,
sampling_ratio: float = 1.0,
include_full_img: bool = False,
):
self.min_area_ratio = min_area_ratio
self.filter_empty_gt = filter_empty_gt
@@ -97,20 +99,21 @@ def __init__(
break

self.dataset = dataset
self.tiles_all, self.cached_results = self.gen_tile_ann()
self.tiles_all, self.cached_results = self.gen_tile_ann(include_full_img)
self.sample_num = max(int(len(self.tiles_all) * sampling_ratio), 1)
if sampling_ratio < 1.0:
self.tiles = sample(self.tiles_all, self.sample_num)
else:
self.tiles = self.tiles_all

@timeit
def gen_tile_ann(self) -> Tuple[List[Dict], List[Dict]]:
def gen_tile_ann(self, include_full_img) -> Tuple[List[Dict], List[Dict]]:
"""Generate tile annotations and cache the original image-level annotations.

Args:
include_full_img: whether to include the full-size image for inference or training.

Returns:
tiles: a list of tile annotations with some other useful information for data pipeline.
cache_result: a list of original image-level annotations.
"""
tiles = []
cache_result = []
@@ -119,7 +122,8 @@ def gen_tile_ann(self) -> Tuple[List[Dict], List[Dict]]:

pbar = tqdm(total=len(self.dataset) * 2, desc="Generating tile annotations...")
for idx, result in enumerate(cache_result):
tiles.append(self.gen_single_img(result, dataset_idx=idx))
if include_full_img:
tiles.append(self.gen_single_img(result, dataset_idx=idx))
pbar.update(1)

for idx, result in enumerate(cache_result):
@@ -137,6 +141,7 @@ def gen_single_img(self, result: Dict, dataset_idx: int) -> Dict:
Returns:
Dict: annotation with some other useful information for data pipeline.
"""
result["full_res_image"] = True
result["tile_box"] = (0, 0, result["img_shape"][1], result["img_shape"][0])
result["dataset_idx"] = dataset_idx
result["original_shape_"] = result["img_shape"]
@@ -170,20 +175,21 @@ def gen_tiles_single_img(self, result: Dict, dataset_idx: int) -> List[Dict]:
height, width = img_shape[:2]
_tile = self.prepare_result(result)

num_patches_h = int((height - self.tile_size) / self.stride) + 1
num_patches_w = int((width - self.tile_size) / self.stride) + 1
num_patches_h = int(height / self.stride) + 1
num_patches_w = int(width / self.stride) + 1
for (_, _), (loc_i, loc_j) in zip(
product(range(num_patches_h), range(num_patches_w)),
product(
range(0, height - self.tile_size + 1, self.stride),
range(0, width - self.tile_size + 1, self.stride),
range(0, height, self.stride),
range(0, width, self.stride),
),
):
x_1 = loc_j
x_2 = loc_j + self.tile_size
x_2 = min(loc_j + self.tile_size, width)
y_1 = loc_i
y_2 = loc_i + self.tile_size
y_2 = min(loc_i + self.tile_size, height)
tile = copy.deepcopy(_tile)
tile["full_res_image"] = False
tile["original_shape_"] = img_shape
tile["ori_shape"] = (y_2 - y_1, x_2 - x_1, 3)
tile["img_shape"] = tile["ori_shape"]
@@ -86,7 +86,7 @@ def simple_test(self, img: torch.Tensor) -> torch.Tensor:

out = self.forward(img)
with no_nncf_trace():
return self.sigmoid(out)[0][0]
return self.sigmoid(out).flatten()
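The previous `[0][0]` indexing returned only the first image's objectness score, so a multi-image (multi-tile) batch silently lost every other score; `.flatten()` returns one probability per image. A quick illustration, assuming the tile-classifier head emits one logit per image (shape `[N, 1]`):

```python
import torch

logits = torch.tensor([[2.0], [-3.0], [0.5]])  # pretend classifier output for a batch of 3 tiles
old = torch.sigmoid(logits)[0][0]              # tensor(0.8808) - only the first tile's score
new = torch.sigmoid(logits).flatten()          # tensor([0.8808, 0.0474, 0.6225]) - one score per tile
keep = new > 0.45                              # per-tile mask used downstream: tensor([ True, False,  True])
```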


# pylint: disable=too-many-ancestors
@@ -146,7 +146,21 @@ def forward_train(
losses.update(rcnn_loss)
return losses

def simple_test(self, img, img_metas, proposals=None, rescale=False):
@staticmethod
def make_fake_results(num_classes):
"""Make fake results.

Returns:
tuple: MaskRCNN output
"""
bbox_results = []
mask_results = []
for _ in range(num_classes):
bbox_results.append(np.empty((0, 5), dtype=np.float32))
mask_results.append([])
return bbox_results, mask_results

def simple_test(self, img, img_metas, proposals=None, rescale=False, full_res_image=False):
"""Simple test.

Tile classifier is used to filter out images without any objects.
@@ -161,28 +175,30 @@ def simple_test(self, img, img_metas, proposals=None, rescale=False):
Returns:
tuple: MaskRCNN output
"""
keep = self.tile_classifier.simple_test(img) > 0.45

if not keep:
tmp_results = []
num_classes = 1
bbox_results = []
mask_results = []
for _ in range(num_classes):
bbox_results.append(np.empty((0, 5), dtype=np.float32))
mask_results.append([])
tmp_results.append((bbox_results, mask_results))
return tmp_results

assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)

if proposals is None:
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
else:
proposal_list = proposals

return self.roi_head.simple_test(x, proposal_list, img_metas, rescale=rescale)
keep = self.tile_classifier.simple_test(img) > 0.45
keep = full_res_image[0] | keep

results = []
for _ in range(len(img)):
fake_result = CustomMaskRCNNTileOptimized.make_fake_results(self.roi_head.bbox_head.num_classes)
results.append(fake_result)

if any(keep):
img = img[keep]
img_metas = [item for keep, item in zip(keep, img_metas) if keep]
assert self.with_bbox, "Bbox head must be implemented."
x = self.extract_feat(img)

if proposals is None:
proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
else:
proposal_list = proposals
maskrcnn_results = self.roi_head.simple_test(x, proposal_list, img_metas, rescale=rescale)
for i, keep_flag in enumerate(keep):
if keep_flag:
results[i] = maskrcnn_results.pop(0)
return results


if is_mmdeploy_enabled():
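To summarize the new `simple_test` flow: a per-image `keep` mask is built from the tile-classifier scores, full-resolution images are forced through, the detector heads run only on the kept images, and their outputs are scattered back into a list pre-filled with empty per-class results, so batch size and ordering are preserved. A minimal sketch of that gather/scatter pattern with a stand-in `detect` callable (illustrative, not the repo's API):

```python
import numpy as np
import torch

def empty_result(num_classes):
    # one empty (0, 5) bbox array and one empty mask list per class
    return ([np.empty((0, 5), dtype=np.float32) for _ in range(num_classes)],
            [[] for _ in range(num_classes)])

def batched_simple_test(imgs, scores, full_res_flags, detect, num_classes, thr=0.45):
    keep = (scores > thr) | torch.as_tensor(full_res_flags)       # full-res images always pass
    results = [empty_result(num_classes) for _ in range(len(imgs))]
    if keep.any():
        detected = detect([im for im, k in zip(imgs, keep) if k])  # heavy heads on kept images only
        for i, k in enumerate(keep):
            if k:
                results[i] = detected.pop(0)                       # scatter back in original batch order
    return results
```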
5 changes: 5 additions & 0 deletions otx/algorithms/detection/adapters/mmdet/utils/config_utils.py
@@ -339,6 +339,11 @@ def patch_tiling(config, hparams, dataset=None):

if hparams.tiling_parameters.enable_tile_classifier:
logger.info("Tile classifier enabled")

for subset in ('val', 'test'):
if config.data[subset].pipeline[0]['transforms'][-1]['type'] == 'Collect':
config.data[subset].pipeline[0]['transforms'][-1]['keys'].append('full_res_image')

logger.info(f"Patch model from: {config.model.type} to CustomMaskRCNNTileOptimized")
config.model.type = "CustomMaskRCNNTileOptimized"
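The pipeline patch above exists because MMDetection's `Collect` transform drops every key it is not told to keep; without appending `full_res_image` to `keys`, the flag set by the tiling dataset would never reach `simple_test`. Roughly what a patched val/test pipeline ends up looking like (a hand-written sketch of a typical tiling pipeline, not copied from the repo config):

```python
test_pipeline = [
    dict(
        type="MultiScaleFlipAug",
        img_scale=(512, 512),
        flip=False,
        transforms=[
            dict(type="Resize", keep_ratio=False),
            dict(type="Normalize", mean=[0, 0, 0], std=[255, 255, 255], to_rgb=True),
            dict(type="ImageToTensor", keys=["img"]),
            dict(type="Collect", keys=["img", "full_res_image"]),  # "full_res_image" appended by patch_tiling
        ],
    )
]
```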

8 changes: 5 additions & 3 deletions otx/api/utils/tiler.py
@@ -63,10 +63,12 @@ def tile(self, image: np.ndarray) -> List[List[int]]:

coords = [[0, 0, width, height]]
for (loc_j, loc_i) in product(
range(0, width - self.tile_size + 1, int(self.tile_size * (1 - self.overlap))),
range(0, height - self.tile_size + 1, int(self.tile_size * (1 - self.overlap))),
range(0, width, int(self.tile_size * (1 - self.overlap))),
range(0, height, int(self.tile_size * (1 - self.overlap))),
):
coords.append([loc_j, loc_i, loc_j + self.tile_size, loc_i + self.tile_size])
x2 = min(loc_j + self.tile_size, width)
y2 = min(loc_i + self.tile_size, height)
coords.append([loc_j, loc_i, x2, y2])
return coords
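Same fix on the deployment side: the ranges no longer stop at `width - tile_size + 1` / `height - tile_size + 1`, and the second corner is clamped, so border strips get partial tiles instead of being skipped. A small worked example (numbers are illustrative):

```python
# 900 x 700 image, tile_size=400, overlap=0.5  ->  stride = int(400 * (1 - 0.5)) = 200
tile_size, width, height, stride = 400, 900, 700, 200
coords = [[0, 0, width, height]]                  # the full image stays the first entry
for x1 in range(0, width, stride):
    for y1 in range(0, height, stride):
        coords.append([x1, y1, min(x1 + tile_size, width), min(y1 + tile_size, height)])
# With the old bounds the loops stopped at x1=400 and y1=200, leaving a 100 px strip on
# the right and at the bottom untiled; now the last tile is a clamped 100 x 100 patch:
print(coords[-1])  # [800, 600, 900, 700]
```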

def filter_tiles_by_objectness(
12 changes: 8 additions & 4 deletions otx/core/data/adapter/base_dataset_adapter.py
@@ -329,15 +329,19 @@ def _get_original_bbox_entity(self, annotation: DatumAnnotation) -> Annotation:
labels=[ScoredLabel(label=self.label_entities[annotation.label])],
)

def _get_polygon_entity(self, annotation: DatumAnnotation, width: int, height: int) -> Annotation:
def _get_polygon_entity(self, annotation: DatumAnnotation, width: int, height: int, num_polygons: int = -1) -> Annotation:
"""Get polygon entity."""
return Annotation(
Polygon(
polygon = Polygon(
points=[
Point(x=annotation.points[i] / width, y=annotation.points[i + 1] / height)
for i in range(0, len(annotation.points), 2)
]
),
)
step = 1 if num_polygons == -1 else max(1, len(polygon.points) // num_polygons)  # guard against step == 0 for small polygons
points = [polygon.points[i] for i in range(0, len(polygon.points), step)]

return Annotation(
Polygon(points),
labels=[ScoredLabel(label=self.label_entities[annotation.label])],
)
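For context, `num_polygons` here effectively caps the number of polygon points kept (every `step`-th vertex is retained); the `max(1, ...)` guard matters because polygons with fewer vertices than the cap would otherwise produce `step == 0` and a `range()` error. A standalone sketch of the decimation, using plain tuples instead of otx `Point` objects (helper name is illustrative):

```python
def decimate(points, max_points=-1):
    """Keep roughly max_points vertices by taking every step-th one."""
    step = 1 if max_points == -1 else max(1, len(points) // max_points)
    return points[::step]

ring = [(i / 100.0, ((i * 7) % 100) / 100.0) for i in range(100)]  # a 100-vertex polygon
print(len(decimate(ring, 20)))      # 20  - every 5th vertex
print(len(decimate(ring)))          # 100 - unchanged when max_points is -1
print(len(decimate(ring[:3], 20)))  # 3   - small polygons are kept as-is
```

Note that step-based decimation can keep slightly more than the cap when the vertex count is not an exact multiple (e.g. 109 vertices with a cap of 20 keeps 22).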

2 changes: 1 addition & 1 deletion otx/core/data/adapter/detection_dataset_adapter.py
@@ -42,7 +42,7 @@ def get_otx_dataset(self) -> DatasetEntity:
and ann.type == DatumAnnotationType.polygon
):
if self._is_normal_polygon(ann):
shapes.append(self._get_polygon_entity(ann, image.width, image.height))
shapes.append(self._get_polygon_entity(ann, image.width, image.height, 20))
if self.task_type is TaskType.DETECTION and ann.type == DatumAnnotationType.bbox:
if self._is_normal_bbox(ann.points[0], ann.points[1], ann.points[2], ann.points[3]):
shapes.append(self._get_normalized_bbox_entity(ann, image.width, image.height))