Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OTX optimize for visual prompting task #2318

Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ All notable changes to this project will be documented in this file.
- Add new visual prompting task: train/eval (<https://github.com/openvinotoolkit/training_extensions/pull/2203>)
- Add new visual prompting task: export (<https://github.com/openvinotoolkit/training_extensions/pull/2274>)
- Add new visual prompting task: deploy (<https://github.com/openvinotoolkit/training_extensions/pull/2311>)
- Add new visual prompting task: optimize (PTQ) (<https://github.com/openvinotoolkit/training_extensions/pull/2318>)
- Add new object detector ResNeXt101-ATSS (<https://github.com/openvinotoolkit/training_extensions/pull/2309>)

### Enhancements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,16 @@ def __init__(
preload: bool = False,
):
super().__init__(model_adapter, configuration, preload)
self.output_blob_name = "low_res_masks"

@classmethod
def parameters(cls): # noqa: D102
parameters = super().parameters()
parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
return parameters

def _get_outputs(self):
return "low_res_masks"

def preprocess(self, inputs: Dict[str, Any], meta: Dict[str, Any]):
"""Preprocess prompts."""
processed_prompts = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def update_visual_prompting_config(
if groups:
for group in groups:
if group in ["learning_parameters", "nncf_optimization", "pot_parameters", "postprocessing"]:
if group in ["nncf_optimization", "pot_parameters"]:
# TODO (sungchul): Consider pot_parameters, nncf_optimization, and postprocessing
if group in ["nncf_optimization"]:
# TODO (sungchul): Consider nncf_optimization
logger.warning(f"{group} will be implemented.")
continue
update_visual_prompting_config(visual_prompting_config, getattr(otx_config, group))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from torchvision.transforms.functional import resize, to_pil_image # type: ignore


Expand All @@ -36,24 +35,26 @@ def __call__(self, item: Dict[str, Union[List, Tensor]]) -> Dict[str, Union[List
Dict[str, Union[List, Tensor]]: Dictionary of batch data.
"""
item["images"] = torch.as_tensor(
self.apply_image(item["images"]).transpose((2, 0, 1)), dtype=torch.get_default_dtype()
self.apply_image(item["images"], self.target_length).transpose((2, 0, 1)), dtype=torch.get_default_dtype()
)
item["gt_masks"] = [torch.as_tensor(gt_mask) for gt_mask in item["gt_masks"]]
item["bboxes"] = self.apply_boxes(item["bboxes"], item["original_size"])
if item["points"]:
item["points"] = self.apply_coords(item["points"], item["original_size"])
return item

def apply_image(self, image: np.ndarray) -> np.ndarray:
@classmethod
def apply_image(cls, image: np.ndarray, target_length: int) -> np.ndarray:
"""Expects a numpy array with shape HxWxC in uint8 format.

Args:
image (np.ndarray): Image array.
target_length (int): The length of the longest side of the image.

Returns:
np.ndarray: Resized image.
"""
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
target_size = cls.get_preprocess_shape(image.shape[0], image.shape[1], target_length)
return np.array(resize(to_pil_image(image), target_size))

def apply_coords(self, coords: np.ndarray, original_size: Union[List[Any], Tensor]) -> np.ndarray:
Expand Down Expand Up @@ -88,56 +89,6 @@ def apply_boxes(self, boxes: np.ndarray, original_size: Union[List[Any], Tensor]
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)

def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
"""Expects batched images with shape BxCxHxW and float format.

This transformation may not exactly match apply_image.
apply_image is the transformation expected by the model.

Args:
image (torch.Tensor): Image tensor.

Returns:
torch.Tensor: Resized image.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
return F.interpolate(image, target_size, mode="bilinear", align_corners=False, antialias=True)

def apply_coords_torch(self, coords: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""Expects a torch tensor with length 2 in the last dimension.

Requires the original image size in (H, W) format.

Args:
coords (torch.Tensor): Coordinates tensor.
original_size (Tuple[int, ...]): Original size of image.

Returns:
torch.Tensor: Resized coordinates.
"""
old_h, old_w = original_size
new_h, new_w = self.get_preprocess_shape(original_size[0], original_size[1], self.target_length)
coords = deepcopy(coords).to(torch.float)
coords[..., 0] = coords[..., 0] * (new_w / old_w)
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords

def apply_boxes_torch(self, boxes: torch.Tensor, original_size: Tuple[int, ...]) -> torch.Tensor:
"""Expects a torch tensor with shape Bx4.

Requires the original image size in (H, W) format.

Args:
boxes (torch.Tensor): Boxes tensor.
original_size (Tuple[int, ...]): Original size of image.

Returns:
torch.Tensor: Resized boxes.
"""
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
return boxes.reshape(-1, 4)

@staticmethod
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
"""Compute the output size given input size and target long side length.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ def replace_state_dict_keys(state_dict, revise_keys):
state_dict = replace_state_dict_keys(state_dict, revise_keys)
self.load_state_dict(state_dict)

#################################################
# forward for inference (export/deploy) #
#################################################
##########################################################
# forward for inference (export/deploy/optimize) #
##########################################################
@torch.no_grad()
def forward(
self,
Expand All @@ -185,7 +185,7 @@ def forward(
point_labels: Tensor,
mask_input: Tensor,
has_mask_input: Tensor,
orig_size: Tensor,
# orig_size: Tensor,
):
"""Forward method for SAM inference (export/deploy/optimize).

Expand Down Expand Up @@ -227,16 +227,18 @@ def forward(
if self.config.model.return_single_mask:
masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

upscaled_masks = self.mask_postprocessing(masks, orig_size[0])
return scores, masks
# TODO (sungchul): apply inner postprocessing
# upscaled_masks = self.mask_postprocessing(masks, orig_size[0])

if self.config.model.return_extra_metrics:
stability_scores = self.calculate_stability_score(
upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset
)
areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1)
return upscaled_masks, scores, stability_scores, areas, masks
# if self.config.model.return_extra_metrics:
# stability_scores = self.calculate_stability_score(
# upscaled_masks, self.config.model.mask_threshold, self.config.model.stability_score_offset
# )
# areas = (upscaled_masks > self.config.model.mask_threshold).sum(-1).sum(-1)
# return upscaled_masks, scores, stability_scores, areas, masks

return upscaled_masks, scores, masks
# return upscaled_masks, scores, masks
jaegukhyun marked this conversation as resolved.
Show resolved Hide resolved

def _embed_points(self, point_coords: Tensor, point_labels: Tensor) -> Tensor:
"""Embed sparse input prompts.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@

from attr import attrs

from otx.algorithms.common.configs import BaseConfig
from otx.algorithms.common.configs import BaseConfig, POTQuantizationPreset
from otx.api.configuration.elements import (
ParameterGroup,
add_parameter_group,
boolean_attribute,
configurable_boolean,
configurable_float,
configurable_integer,
selectable,
string_attribute,
)
from otx.api.configuration.model_lifecycle import ModelLifecycle
Expand Down Expand Up @@ -95,5 +97,20 @@ class __Postprocessing(ParameterGroup):
affects_outcome_of=ModelLifecycle.INFERENCE,
)

@attrs
class __POTParameter(BaseConfig.BasePOTParameter):
header = string_attribute("POT Parameters")
description = header
visible_in_ui = boolean_attribute(False)

preset = selectable(
default_value=POTQuantizationPreset.MIXED,
header="Preset",
description="Quantization preset that defines quantization scheme",
editable=True,
visible_in_ui=True,
)

learning_parameters = add_parameter_group(__LearningParameters)
postprocessing = add_parameter_group(__Postprocessing)
pot_parameters = add_parameter_group(__POTParameter)
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ pot_parameters:
affects_outcome_of: NONE
auto_hpo_state: not_possible
auto_hpo_value: null
default_value: Performance
default_value: Mixed
description: Quantization preset that defines quantization scheme
editable: true
enum_name: POTQuantizationPreset
Expand All @@ -162,7 +162,7 @@ pot_parameters:
operator: AND
rules: []
type: UI_RULES
value: Performance
value: Mixed
visible_in_ui: true
warning: null
stat_subset_size:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
dataset:
task: visual_prompting
train_batch_size: 2
train_batch_size: 4
val_batch_size: 1
test_batch_size: 1
num_workers: 4
Expand Down Expand Up @@ -35,7 +35,7 @@ model:

optimizer:
name: Adam
lr: 0.0001
lr: 0.000001

callback:
checkpoint: # arguments for ModelCheckpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,70 +85,14 @@ learning_parameters:
visible_in_ui: true
warning: null
auto_hpo_state: NOT_POSSIBLE
nncf_optimization:
description: Optimization by NNCF
enable_pruning:
affects_outcome_of: NONE
auto_hpo_state: not_possible
auto_hpo_value: null
default_value: false
description: Enable filter pruning algorithm
editable: true
header: Enable filter pruning algorithm
type: BOOLEAN
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
value: false
visible_in_ui: true
warning: null
enable_quantization:
affects_outcome_of: NONE
auto_hpo_state: not_possible
auto_hpo_value: null
default_value: true
description: Enable quantization algorithm
editable: true
header: Enable quantization algorithm
type: BOOLEAN
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
value: true
visible_in_ui: true
warning: null
header: Optimization by NNCF
pruning_supported:
affects_outcome_of: TRAINING
auto_hpo_state: not_possible
auto_hpo_value: null
default_value: false
description: Whether filter pruning is supported
editable: false
header: Whether filter pruning is supported
type: BOOLEAN
ui_rules:
action: DISABLE_EDITING
operator: AND
rules: []
type: UI_RULES
value: false
visible_in_ui: false
warning: null
type: PARAMETER_GROUP
visible_in_ui: true
pot_parameters:
description: POT Parameters
header: POT Parameters
preset:
affects_outcome_of: NONE
auto_hpo_state: not_possible
auto_hpo_value: null
default_value: Performance
default_value: Mixed
description: Quantization preset that defines quantization scheme
editable: true
enum_name: POTQuantizationPreset
Expand All @@ -162,7 +106,7 @@ pot_parameters:
operator: AND
rules: []
type: UI_RULES
value: Performance
value: Mixed
visible_in_ui: true
warning: null
stat_subset_size:
Expand Down
3 changes: 1 addition & 2 deletions src/otx/algorithms/visual_prompting/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,8 @@ def _export_to_onnx(self, onnx_path: Dict[str, str]):
"point_labels": torch.randint(low=0, high=4, size=(1, 2), dtype=torch.float),
"mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
"has_mask_input": torch.tensor([[1]], dtype=torch.float),
"orig_size": torch.tensor([[height, width]], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
output_names = ["iou_predictions", "low_res_masks"]
model_to_export = self.model

with warnings.catch_warnings():
Expand Down
Loading