Fix reference info path for zero-shot learning #3354

Merged
merged 10 commits on Apr 22, 2024
Updates for current develop
sungchul2 committed Apr 18, 2024

commit ff3b02eb2c7cdf96773d863f087dd1b1805ae89f
21 changes: 7 additions & 14 deletions src/otx/algo/visual_prompting/zero_shot_segment_anything.py
@@ -6,15 +6,14 @@
from __future__ import annotations

import logging as log
-import os
+import pickle
from collections import defaultdict
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import torch
-import pickle
import torchvision.transforms.v2 as tvt_v2
from datumaro import Polygon as dmPolygon
from torch import LongTensor, Tensor, nn
@@ -625,7 +624,7 @@ def _decide_cascade_results(
class OTXZeroShotSegmentAnything(OTXZeroShotVisualPromptingModel):
"""Zero-Shot Visual Prompting model."""

-def __init__(
+def __init__(  # noqa: PLR0913
self,
backbone: Literal["tiny_vit", "vit_b"],
num_classes: int = 0,
@@ -888,17 +887,11 @@ def save_reference_info(self, default_root_dir: Path | str) -> None:
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
-if isinstance(self, OTXZeroShotVisualPromptingModel):
-# with torch model
-torch.save(reference_info, path_reference_info)
-pickle.dump(
-{k: v.numpy() for k, v in reference_info.items()},
-Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"),
-)
-else:
-# with ov model
-torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
-pickle.dump(reference_info, Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"))
+torch.save(reference_info, path_reference_info)
+pickle.dump(
+{k: v.numpy() for k, v in reference_info.items()},
+Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"),
+)
log.info(f"Saved reference info at {path_reference_info}.")

def load_reference_info(self, default_root_dir: Path | str, device: str | torch.device = "cpu") -> bool:
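The two branches collapse into one code path here because the torch and OpenVINO variants now write the same two files. A minimal sketch of that layout (paths and shapes are illustrative, not taken from the PR):

```python
# Illustrative round-trip of the shared reference-info layout: a torch
# checkpoint for the torch pipeline plus a numpy pickle for OpenVINO.
import pickle
from pathlib import Path

import torch

reference_info = {
    "reference_feats": torch.zeros(3, 1, 256),  # assumed shape, for illustration
    "used_indices": torch.tensor([0, 2]),
}

path_pt = Path("outputs/reference_infos/reference_info.pt")  # hypothetical root
path_pt.parent.mkdir(parents=True, exist_ok=True)

torch.save(reference_info, path_pt)
with path_pt.with_suffix(".pickle").open("wb") as f:
    pickle.dump({k: v.numpy() for k, v in reference_info.items()}, f)

# The pickle side can be read back without torch at all:
with path_pt.with_suffix(".pickle").open("rb") as f:
    loaded = pickle.load(f)  # dict of numpy arrays
```

The OpenVINO class added later in this diff writes the same pair from numpy arrays, converting values with `torch.as_tensor` for the `.pt` side.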
241 changes: 216 additions & 25 deletions src/otx/core/model/visual_prompting.py
@@ -3,15 +3,10 @@
#
"""Class definition for visual prompting model entity used in OTX."""

-# TODO(vinnamki): There are so many mypy errors. Resolve them after refactoring visual prompting code.
-# mypy: ignore-errors

from __future__ import annotations

import logging as log
-import os
-import pickle
-import time
from collections import defaultdict
from copy import deepcopy
from functools import partial
@@ -342,7 +337,7 @@ def on_train_start(self) -> None:
def on_test_start(self) -> None:
"""Load previously saved reference info."""
super().on_test_start()
-if not self.model.load_reference_info(self.trainer.default_root_dir, self.device):
+if not self.load_reference_info(self.trainer.default_root_dir, self.device):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
@@ -351,11 +346,11 @@ def on_test_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
-self.model.load_reference_info(self.trainer.default_root_dir, self.device)
+self.load_reference_info(self.trainer.default_root_dir, self.device)

def on_predict_start(self) -> None:
"""Load previously saved reference info."""
-if not self.model.load_reference_info(self.trainer.default_root_dir, self.device):
+if not self.load_reference_info(self.trainer.default_root_dir, self.device):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
@@ -364,15 +359,15 @@ def on_predict_start(self) -> None:
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
-self.model.load_reference_info(self.trainer.default_root_dir, self.device)
+self.load_reference_info(self.trainer.default_root_dir, self.device)

def on_train_epoch_start(self) -> None:
"""Skip on_train_epoch_start unused in zero-shot visual prompting."""

def on_train_epoch_end(self) -> None:
"""Skip on_train_epoch_end unused in zero-shot visual prompting."""
if self.save_outputs:
-self.model.save_reference_info(self.trainer.default_root_dir)
+self.save_reference_info(self.trainer.default_root_dir)

def on_validation_epoch_start(self) -> None:
"""Skip on_validation_epoch_start unused in zero-shot visual prompting."""
@@ -393,7 +388,7 @@ def training_step(

def validation_step(
self,
-inputs: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity,
+inputs: ZeroShotVisualPromptingBatchDataEntity,
batch_idx: int,
) -> None:
"""Skip validation_step unused in zero-shot visual prompting."""
@@ -406,7 +401,7 @@ def test_step(
"""Perform a single test step on a batch of data from the test set.

Args:
-inputs (VisualPromptingBatchDataEntity): The input data for the test step.
+inputs (ZeroShotVisualPromptingBatchDataEntity): The input data for the test step.
batch_idx (int): The index of the current batch.

Raises:
@@ -566,7 +561,7 @@ def _customize_outputs(
self,
outputs: Any, # noqa: ANN401
inputs: VisualPromptingBatchDataEntity, # type: ignore[override]
-) -> VisualPromptingBatchPredEntity | OTXBatchLossEntity:
+) -> VisualPromptingBatchPredEntity:
"""Customize OTX output batch data entity if needed for model."""
masks: list[tv_tensors.Mask] = []
scores: list[torch.Tensor] = []
@@ -717,7 +712,12 @@ def _set_label_info(self, label_info: LabelInfo | list[str]) -> None:
return


-class OVZeroShotVisualPromptingModel(OVVisualPromptingModel):
+class OVZeroShotVisualPromptingModel(
+OVModel[
+ZeroShotVisualPromptingBatchDataEntity,
+ZeroShotVisualPromptingBatchPredEntity,
+],
+):
"""Zero-shot visual prompting model compatible for OpenVINO IR inference.

It can only consume OpenVINO IR model path and create the OTX zero-shot visual prompting model compatible
@@ -738,6 +738,18 @@ def __init__(
save_outputs: bool = True,
**kwargs,
) -> None:
if async_inference:
log.warning(
"Async inference is not supported for visual prompting models. Setting async_inference to False.",
)
async_inference = False

basename: str = Path(model_name).name
model_type_name: str = "_".join(basename.split("_")[:2])
self.model_names: dict[str, str] = {
module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
for module in ["image_encoder", "decoder"]
}
super().__init__(
model_name=model_name,
model_type=model_type,
@@ -756,6 +768,29 @@ def __init__(

self.initialize_reference_info()

def _create_model(self) -> dict[str, Model]:
"""Create a OV model with help of Model API."""
from openvino.model_api.adapters import OpenvinoAdapter, create_core, get_user_config
from openvino.model_api.models import Model

ov_models: dict[str, Model] = {}

plugin_config = get_user_config("AUTO", str(self.num_requests), "AUTO")
if self.use_throughput_mode:
plugin_config["PERFORMANCE_HINT"] = "THROUGHPUT"

model_parameters = {"decoder": {"input_layouts": "image_embeddings:NCHW"}}
for module in ["image_encoder", "decoder"]:
model_adapter = OpenvinoAdapter(
core=create_core(),
model=self.model_names.get(module),
model_parameters=model_parameters.get(module, {}),
max_num_requests=self.num_requests,
plugin_config=plugin_config,
)
ov_models[module] = Model.create_model(model_adapter, module, configuration=self.model_api_configuration)
return ov_models

def learn(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
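The path juggling in `__init__` above is easiest to see with a concrete value. A standalone check (the file name is hypothetical; only the derivation logic mirrors the diff): the first two underscore-separated tokens of the basename become the model prefix, and sibling IR paths are derived for both modules.

```python
from pathlib import Path

model_name = "outputs/exported_model_decoder.xml"      # hypothetical IR path
basename = Path(model_name).name                       # "exported_model_decoder.xml"
model_type_name = "_".join(basename.split("_")[:2])    # "exported_model"

model_names = {
    module: model_name.replace(basename, f"{model_type_name}_{module}.xml")
    for module in ["image_encoder", "decoder"]
}
print(model_names)
# {'image_encoder': 'outputs/exported_model_image_encoder.xml',
#  'decoder': 'outputs/exported_model_decoder.xml'}
```

So whichever module's IR path is passed in, both paths resolve to the same pair, which is what lets `_create_model` open each module by name.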
@@ -1004,6 +1039,90 @@ def _customize_outputs( # type: ignore[override]
labels=labels,
)

def optimize( # type: ignore[override]
self,
output_dir: Path,
data_module: OTXDataModule,
ptq_config: dict[str, Any] | None = None,
) -> dict[str, Path]:
"""Runs NNCF quantization."""
import nncf
import openvino

def check_if_quantized(model: openvino.Model) -> bool:
"""Checks if OpenVINO model is already quantized."""
nodes = model.get_ops()
return any(op.get_type_name() == "FakeQuantize" for op in nodes)

def transform_fn(
data_batch: VisualPromptingBatchDataEntity | ZeroShotVisualPromptingBatchDataEntity,
module: Literal["image_encoder", "decoder"],
) -> np.ndarray | dict[str, Any]:
images, _, prompts = self._customize_inputs(data_batch) # type: ignore[arg-type]

image = images[0]["images"] # use only the first image
if module == "image_encoder":
# resize
resized_image = self.model["image_encoder"].resize(
image[0],
(self.model["image_encoder"].w, self.model["image_encoder"].h),
)

# pad image if necessary because the Python `fit_to_window` resize in Model API doesn't support padding
pad_w = max(0, self.model["image_encoder"].w - resized_image.shape[1])
pad_h = max(0, self.model["image_encoder"].h - resized_image.shape[0])
resized_image = np.pad(
resized_image,
((0, pad_h), (0, pad_w), (0, 0)),
mode="constant",
constant_values=0,
)

# normalization
resized_image = self.model["image_encoder"].input_transform(resized_image)

# change layout from HWC to NCHW
return self.model["image_encoder"]._change_layout(resized_image) # noqa: SLF001

# obtain image embeddings from image encoder
image_embeddings = self.model["image_encoder"].infer_sync(image)
# use only the first prompt
prompt_for_optim = next(iter(prompts[0].values()))[0] if isinstance(prompts[0], dict) else prompts[0][0] # type: ignore[attr-defined]
prompt_for_optim.pop("label")
prompt_for_optim.update(**image_embeddings)
return prompt_for_optim

output_model_paths: dict[str, Path] = {}
for module in ["image_encoder", "decoder"]:
output_model_path = output_dir / (self._OPTIMIZED_MODEL_BASE_NAME + f"_{module}.xml")

ov_model = openvino.Core().read_model(self.model_names[module])
if check_if_quantized(ov_model):
msg = "Model is already optimized by PTQ"
raise RuntimeError(msg)

train_dataset = data_module.train_dataloader()

ptq_config_from_ir = self._read_ptq_config_from_ir(ov_model)
if ptq_config is not None:
ptq_config_from_ir.update(ptq_config)
ptq_config = ptq_config_from_ir
else:
ptq_config = ptq_config_from_ir

quantization_dataset = nncf.Dataset(train_dataset, partial(transform_fn, module=module)) # type: ignore[attr-defined]

compressed_model = nncf.quantize( # type: ignore[attr-defined]
ov_model,
quantization_dataset,
**ptq_config,
)

openvino.save_model(compressed_model, output_model_path)
output_model_paths[module] = output_model_path

return output_model_paths

######################################
# Preprocess #
######################################
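A hedged usage sketch for the `optimize` entry point above (`model` and `datamodule` are stand-ins, not objects defined in this PR). Per the loop in the method, each module is quantized separately and one IR path is returned per module:

```python
from pathlib import Path

# `model` is an OVZeroShotVisualPromptingModel, `datamodule` an OTXDataModule.
optimized: dict[str, Path] = model.optimize(
    output_dir=Path("outputs/ptq"),  # hypothetical output directory
    data_module=datamodule,
    ptq_config=None,  # merged with the PTQ config read from the IR, if any
)
# Keys follow the loop above: "image_encoder" and "decoder", each pointing to
# an "<_OPTIMIZED_MODEL_BASE_NAME>_<module>.xml" under output_dir.
```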
@@ -1124,6 +1243,19 @@ def expand_reference_info(self, new_largest_label: int) -> None:
diff = new_largest_label - cur_largest_label
self.reference_feats = np.pad(self.reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0)

def save_reference_info(self, default_root_dir: Path | str) -> None:
"""Save reference info."""
reference_info = {
"reference_feats": self.reference_feats,
"used_indices": self.used_indices,
}
# save reference info
path_reference_info: Path = Path(default_root_dir) / self.reference_info_dir / "reference_info.pt"
Path.mkdir(Path(path_reference_info).parent, parents=True, exist_ok=True)
torch.save({k: torch.as_tensor(v) for k, v in reference_info.items()}, path_reference_info)
pickle.dump(reference_info, Path.open(Path(str(path_reference_info).replace(".pt", ".pickle")), "wb"))
log.info(f"Saved reference info at {path_reference_info}.")

def _generate_masked_features(
self,
feats: np.ndarray,
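The hunk above shows only the tail of `expand_reference_info`; a small numpy illustration (shapes assumed) of what its `np.pad` call does, appending zero rows along the label axis:

```python
import numpy as np

reference_feats = np.ones((2, 1, 256))  # (num_labels, 1, embed_dim), assumed shape
cur_largest_label = reference_feats.shape[0] - 1   # 1
new_largest_label = 4

diff = new_largest_label - cur_largest_label       # 3 new label slots
reference_feats = np.pad(reference_feats, ((0, diff), (0, 0), (0, 0)), constant_values=0.0)
print(reference_feats.shape)  # (5, 1, 256); the appended rows are all zeros
```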
@@ -1371,6 +1503,76 @@ def _topk_numpy(self, x: np.ndarray, k: int, axis: int = -1, largest: bool = Tru
def _reset_prediction_layer(self, num_classes: int) -> None:
return

def _create_label_info_from_ov_ir(self) -> LabelInfo:
"""Create NullLabelInfo since Visual Prompting tasks has no use of label information."""
return NullLabelInfo()

def _set_label_info(self, label_info: LabelInfo | list[str]) -> None:
"""Visual prompting task does not check label_info equivalance.

This is because it always has NullLabelInfo.
"""
return

######################################
# Lit Module #
######################################
def on_train_start(self) -> None:
"""Initialize reference infos before learn."""
self.initialize_reference_info()

def on_test_start(self) -> None:
"""Load previously saved reference info."""
super().on_test_start()
if not self.load_reference_info(self.trainer.default_root_dir, self.device):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
# to use infer logic
self.training = False
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)

def on_predict_start(self) -> None:
"""Load previously saved reference info."""
if not self.load_reference_info(self.trainer.default_root_dir, self.device):
log.warning("No reference info found. `Learn` will be automatically executed first.")
self.trainer.lightning_module.automatic_optimization = False
self.trainer.fit_loop.run()
# to use infer logic
self.training = False
# to set _combined_loader
self.trainer._evaluation_loop.setup_data() # noqa: SLF001
self.trainer._evaluation_loop.reset() # noqa: SLF001
self.load_reference_info(self.trainer.default_root_dir, self.device)

def on_train_epoch_start(self) -> None:
"""Skip on_train_epoch_start unused in zero-shot visual prompting."""

def on_train_epoch_end(self) -> None:
"""Skip on_train_epoch_end unused in zero-shot visual prompting."""
if self.save_outputs:
self.save_reference_info(self.trainer.default_root_dir)

def on_validation_epoch_start(self) -> None:
"""Skip on_validation_epoch_start unused in zero-shot visual prompting."""

def on_validation_epoch_end(self) -> None:
"""Skip on_validation_epoch_end unused in zero-shot visual prompting."""

def configure_optimizers(self) -> None: # type: ignore[override]
"""Skip configure_optimizers unused in zero-shot visual prompting."""

def training_step(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity, # type: ignore[override]
batch_idx: int,
) -> Tensor:
"""Skip training_step unused in zero-shot visual prompting."""
self.forward(inputs)

def validation_step(
self,
inputs: ZeroShotVisualPromptingBatchDataEntity,
@@ -1401,14 +1603,3 @@ def _convert_pred_entity_to_compute_metric(
) -> MetricInput:
"""Convert the prediction entity to the format required by the compute metric function."""
return _convert_pred_entity_to_compute_metric(preds=preds, inputs=inputs)

-def _create_label_info_from_ov_ir(self) -> LabelInfo:
-"""Create NullLabelInfo since Visual Prompting tasks has no use of label information."""
-return NullLabelInfo()
-
-def _set_label_info(self, label_info: LabelInfo | list[str]) -> None:
-"""Visual prompting task does not check label_info equivalance.
-
-This is because it always has NullLabelInfo.
-"""
-return