Update visual prompting on 1.x (#3038)
* Refine v1 perf benchmark to align with v2 (#3006)

* Add --benchmark-type w/ accuracy|efficiency|all options

* Add perf-benchmark tox env

* Refine perf workflow to align with v2

* Add dummy perf tests for visual prompting

* Fix weekly workflow

---------
Signed-off-by: Songki Choi <songki.choi@intel.com>

* Update docstring

* Update overlapped region refinement

* Update templates

* Remove `PromptGetter` during ov inference

* Fix tests

* For unittest coverage

---------

Co-authored-by: Songki Choi <songki.choi@intel.com>
sungchul2 and goodsong81 authored Mar 8, 2024
1 parent 8c669aa commit d1a27ab
Showing 28 changed files with 680 additions and 679 deletions.
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions
# and limitations under the License.

-from .openvino_models import Decoder, ImageEncoder, PromptGetter # noqa: F401
+from .openvino_models import Decoder, ImageEncoder # noqa: F401
@@ -40,6 +40,7 @@ def parameters(cls) -> Dict[str, Any]: # noqa: D102
{
"resize_type": StringValue(default_value="fit_to_window"),
"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048),
"downsizing": NumericalValue(value_type=int, default_value=64, min=1, max=1024),
}
)
return parameters
@@ -57,38 +58,6 @@ def preprocess(
return dict_inputs, meta


-class PromptGetter(ImageModel):
-"""PromptGetter class for zero-shot visual prompting of openvino model wrapper."""
-
-__model__ = "prompt_getter"
-
-def __init__(self, inference_adapter, configuration=None, preload=False):
-super().__init__(inference_adapter, configuration, preload)
-
-@classmethod
-def parameters(cls) -> Dict[str, Any]: # noqa: D102
-parameters = super().parameters()
-parameters.update({"image_size": NumericalValue(value_type=int, default_value=1024, min=0, max=2048)})
-parameters.update({"sim_threshold": NumericalValue(value_type=float, default_value=0.5, min=0, max=1)})
-parameters.update({"num_bg_points": NumericalValue(value_type=int, default_value=1, min=0, max=1024)})
-parameters.update(
-{"default_threshold_reference": NumericalValue(value_type=float, default_value=0.3, min=-1.0, max=1.0)}
-)
-return parameters
-
-def _get_inputs(self):
-"""Defines the model inputs for images and additional info."""
-image_blob_names, image_info_blob_names = [], []
-for name, metadata in self.inputs.items():
-if len(metadata.shape) == 4:
-image_blob_names.append(name)
-else:
-image_info_blob_names.append(name)
-if not image_blob_names:
-self.raise_error("Failed to identify the input for the image: no 4D input layer found")
-return image_blob_names, image_info_blob_names


class Decoder(SegmentationModel):
"""Decoder class for visual prompting of openvino model wrapper."""

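With the `prompt_getter` wrapper removed, the OpenVINO inference path keeps only the `ImageEncoder` and `Decoder` wrappers, and the new `downsizing` knob lives on the encoder. A minimal sketch for inspecting the added default without loading an IR model — the absolute import path below is an assumption inferred from the package layout (the diff only shows the relative `from .openvino_models import ...`):

```python
# Hypothetical import path for illustration; the diff only shows the relative import.
from otx.algorithms.visual_prompting.adapters.openvino.model_wrappers.openvino_models import ImageEncoder

# parameters() is a classmethod, so no inference adapter or IR file is needed here.
params = ImageEncoder.parameters()
print(params["downsizing"].default_value)  # expected: 64
print(params["image_size"].default_value)  # expected: 1024
```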
@@ -323,8 +323,7 @@ def postprocess_masks(cls, masks: Tensor, input_size: int, orig_size: Tensor) ->
Args:
masks (Tensor): A batch of predicted masks with shape Bx1xHxW.
-input_size (int): The size of the image input to the model, in (H, W) format.
-Used to remove padding.
+input_size (int): The size of the image input to the model. Used to remove padding.
orig_size (Tensor): The original image size with shape Bx2.
Returns:
@@ -60,7 +60,7 @@ def get_prompt_candidates(

total_points_scores: Dict[int, Tensor] = {}
total_bg_coords: Dict[int, Tensor] = {}
-for label in map(int, used_indices[0]):
+for label in map(int, used_indices):
points_scores, bg_coords = self(
image_embeddings=image_embeddings,
reference_feat=reference_feats[label],
@@ -248,7 +248,7 @@ def set_default_config(self) -> DictConfig:
def set_empty_reference_info(self) -> None:
"""Set empty reference information."""
reference_feats: Parameter = Parameter(torch.tensor([], dtype=torch.float32), requires_grad=False)
-used_indices: Parameter = Parameter(torch.tensor([[]], dtype=torch.int64), requires_grad=False)
+used_indices: Parameter = Parameter(torch.tensor([], dtype=torch.int64), requires_grad=False)
self.reference_info = ParameterDict(
{
"reference_feats": reference_feats,
@@ -260,7 +260,7 @@ def set_empty_reference_info(self) -> None:
def initialize_reference_info(self) -> None:
"""Initialize reference information."""
self.reference_info["reference_feats"] = Parameter(torch.zeros(0, 1, 256), requires_grad=False)
self.reference_info["used_indices"] = Parameter(torch.tensor([[]], dtype=torch.int64), requires_grad=False)
self.reference_info["used_indices"] = Parameter(torch.tensor([], dtype=torch.int64), requires_grad=False)
self.is_reference_info_empty = False

def expand_reference_info(self, new_largest_label: int) -> None:
@@ -364,7 +364,7 @@ def learn(self, batch: List[Dict[str, Any]], reset_feat: bool = False) -> Union[

self.reference_info["reference_feats"][label] = ref_feat.detach().cpu()
self.reference_info["used_indices"] = Parameter(
torch.cat((self.reference_info["used_indices"], torch.tensor([[label]])), dim=1),
torch.cat((self.reference_info["used_indices"], torch.tensor([[label]]))),
requires_grad=False,
)
ref_masks[label] = ref_mask.detach().cpu()
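The reference bookkeeping above switches `used_indices` from a 1xN tensor to a flat tensor. A small self-contained sketch of the shape difference and the end-of-epoch deduplication (plain PyTorch with made-up labels, not the OTX code itself):

```python
import torch
from torch.nn import Parameter

# Old layout: a 1xN tensor, extended along dim=1.
old = torch.tensor([[]], dtype=torch.int64)         # shape (1, 0)
old = torch.cat((old, torch.tensor([[3]])), dim=1)   # shape (1, 1)

# New layout: a flat tensor, wrapped in a non-trainable Parameter as in
# set_empty_reference_info(), and extended along the default dim.
used_indices = Parameter(torch.tensor([], dtype=torch.int64), requires_grad=False)
used_indices = Parameter(torch.cat((used_indices, torch.tensor([3]))), requires_grad=False)
used_indices = Parameter(torch.cat((used_indices, torch.tensor([3]))), requires_grad=False)

# The epoch-end hook (see the training_epoch_end hunk below) now only needs
# .unique() to deduplicate the collected labels, with no extra unsqueeze.
print(used_indices.unique())  # tensor([3])
```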
@@ -479,11 +479,20 @@ def _calculate_mask_iou(mask1: Tensor, mask2: Tensor):
overlapped_label = []
overlapped_other_label = []
for (im, mask), (jm, other_mask) in product(enumerate(masks), enumerate(other_masks)):
-if _calculate_mask_iou(mask, other_mask) > threshold_iou:
+_mask_iou = _calculate_mask_iou(mask, other_mask)
+if _mask_iou > threshold_iou:
# compare overlapped regions between different labels and filter out the lower score
if used_points[label][im][2] > used_points[other_label][jm][2]:
overlapped_other_label.append(jm)
else:
overlapped_label.append(im)
+elif _mask_iou > 0:
+# refine the slightly overlapping region
+overlapped_coords = torch.where(torch.logical_and(mask, other_mask))
+if used_points[label][im][2] > used_points[other_label][jm][2]:
+other_mask[overlapped_coords] = 0.0
+else:
+mask[overlapped_coords] = 0.0

for im in sorted(list(set(overlapped_label)), reverse=True):
masks.pop(im)
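The updated loop distinguishes two cases: masks whose IoU exceeds `threshold_iou` still drop the lower-scoring prediction entirely, while slightly overlapping masks only have the shared pixels erased from the lower-scoring one. A rough standalone sketch of that decision, assuming `_calculate_mask_iou` is the usual intersection-over-union on binary masks (its body is not part of this diff):

```python
import torch
from torch import Tensor


def mask_iou(mask1: Tensor, mask2: Tensor) -> Tensor:
    """Plain IoU on binary masks (assumed to mirror _calculate_mask_iou)."""
    intersection = torch.logical_and(mask1, mask2).sum()
    union = torch.logical_or(mask1, mask2).sum()
    return intersection / union if union > 0 else torch.tensor(0.0)


def refine_pair(mask: Tensor, other_mask: Tensor, score: float, other_score: float, threshold_iou: float = 0.8):
    """Resolve one mask pair in place; returns which mask to drop ("mask"/"other"/None)."""
    iou = mask_iou(mask, other_mask)
    if iou > threshold_iou:
        # Heavy overlap: discard the lower-scoring mask entirely.
        return "other" if score > other_score else "mask"
    if iou > 0:
        # Slight overlap: zero out only the shared pixels of the lower-scoring mask.
        overlapped = torch.where(torch.logical_and(mask, other_mask))
        if score > other_score:
            other_mask[overlapped] = 0.0
        else:
            mask[overlapped] = 0.0
    return None


# Toy example: two 4x4 masks sharing a single pixel -> only that pixel is removed.
a = torch.zeros(4, 4)
a[:2, :2] = 1.0
b = torch.zeros(4, 4)
b[1:3, 1:3] = 1.0
print(refine_pair(a, b, score=0.9, other_score=0.4))  # None (IoU ~0.14, below threshold)
print(b[1, 1])  # tensor(0.) -- the shared pixel was erased from the lower-scoring mask
```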
@@ -746,7 +755,7 @@ def on_predict_start(self) -> None:
def training_epoch_end(self, outputs) -> None:
"""Called in the training loop at the very end of the epoch."""
self.reference_info["used_indices"] = Parameter(
self.reference_info["used_indices"].unique().unsqueeze(0), requires_grad=False
self.reference_info["used_indices"].unique(), requires_grad=False
)
if self.config.model.save_outputs:
path_reference_info = self.path_reference_info.format(time.strftime("%Y%m%d-%H%M%S"))
26 changes: 4 additions & 22 deletions src/otx/algorithms/visual_prompting/configs/base/configuration.py
@@ -114,33 +114,15 @@ class __Postprocessing(ParameterGroup):
affects_outcome_of=ModelLifecycle.INFERENCE,
)

-sim_threshold = configurable_float(
-default_value=0.65,
-header="Similarity threshold",
-description="The threshold to filter point candidates based on similarity scores.",
-min_value=0.0,
-max_value=1.0,
-affects_outcome_of=ModelLifecycle.INFERENCE,
-)

-num_bg_points = configurable_integer(
-default_value=1,
-header="The number of background points",
-description="The number of background points to be used as negative prompts.",
+downsizing = configurable_integer(
+default_value=64,
+header="The downsizing ratio",
+description="The downsizing ratio of image encoder.",
min_value=1,
max_value=1024,
affects_outcome_of=ModelLifecycle.INFERENCE,
)

-default_threshold_reference = configurable_float(
-default_value=0.3,
-header="Default reference threshold",
-description="The threshold to get target area in the mask for reference features.",
-min_value=-1.0,
-max_value=1.0,
-affects_outcome_of=ModelLifecycle.INFERENCE,
-)

@attrs
class __POTParameter(BaseConfig.BasePOTParameter):
header = string_attribute("POT Parameters")
@@ -0,0 +1,6 @@
"""Initialization of Configurable Parameters for SAM Visual Prompting Task."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from .configuration import VisualPromptingConfig # noqa: F401
@@ -0,0 +1,81 @@
dataset:
task: visual_prompting
train_batch_size: 1
val_batch_size: 1
test_batch_size: 1
num_workers: 4
image_size: 1024 # dimensions to which images are resized (mandatory)
normalize:
mean:
- 123.675
- 116.28
- 103.53
std:
- 58.395
- 57.12
- 57.375
offset_bbox: 0
use_point: false
use_bbox: false

model:
name: SAM
image_size: 1024
mask_threshold: 0.
return_logits: true
backbone: vit_b
freeze_image_encoder: true
freeze_prompt_encoder: true
freeze_mask_decoder: true
checkpoint: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
# just for inference
return_single_mask: false
use_stability_score: false
stability_score_offset: 1.
return_extra_metrics: false
# zero-shot
default_threshold_reference: 0.3
default_threshold_target: 0.65
save_outputs: True

# PL Trainer Args. Don't add extra parameter here.
trainer:
enable_checkpointing: false
gradient_clip_val: 0
gradient_clip_algorithm: norm
num_nodes: 1
devices: 1
enable_progress_bar: true
overfit_batches: 0.0
track_grad_norm: -1
check_val_every_n_epoch: 1 # Don't validate before extracting features.
fast_dev_run: false
accumulate_grad_batches: 1
max_epochs: 1
min_epochs: null
max_steps: -1
min_steps: null
max_time: null
limit_train_batches: 1.0
limit_val_batches: 0 # No validation
limit_test_batches: 1.0
limit_predict_batches: 1.0
val_check_interval: 1.0
log_every_n_steps: 10
accelerator: auto # <"cpu", "gpu", "tpu", "ipu", "hpu", "auto">
strategy: null
sync_batchnorm: false
precision: 32
enable_model_summary: true
num_sanity_val_steps: 0
profiler: null
benchmark: false
deterministic: false
reload_dataloaders_every_n_epochs: 0
auto_lr_find: false
replace_sampler_ddp: true
detect_anomaly: false
auto_scale_batch_size: false
plugins: null
move_metrics_to_cpu: false
multiple_trainloader_mode: max_size_cycle
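This new template config follows the usual OmegaConf layout used by the visual prompting task, and its model section now carries the zero-shot thresholds that were dropped from the configurable parameters above. A minimal loading sketch — the file path is illustrative, not the actual template location:

```python
from omegaconf import OmegaConf

config = OmegaConf.load("config.yaml")  # illustrative path

assert config.dataset.task == "visual_prompting"
print(config.model.backbone)                     # vit_b
print(config.model.default_threshold_target)     # 0.65
print(config.model.default_threshold_reference)  # 0.3
print(config.trainer.max_epochs)                 # 1
```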
@@ -0,0 +1,14 @@
"""Configuration file of OTX Visual Prompting."""

# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


from attr import attrs

from otx.algorithms.visual_prompting.configs.base import VisualPromptingBaseConfig


@attrs
class VisualPromptingConfig(VisualPromptingBaseConfig):
"""Configurable parameters for Visual Prompting task."""