diff --git a/src/otx/algo/classification/backbones/vision_transformer.py b/src/otx/algo/classification/backbones/vision_transformer.py index 94a72a0a50b..27acb2549e7 100644 --- a/src/otx/algo/classification/backbones/vision_transformer.py +++ b/src/otx/algo/classification/backbones/vision_transformer.py @@ -87,6 +87,7 @@ class VisionTransformer(BaseModule): norm_layer: Normalization layer. act_layer: MLP activation layer. block_fn: Transformer block layer. + lora: Enable LoRA training. """ arch_zoo = { # noqa: RUF012 diff --git a/src/otx/algo/classification/vit.py b/src/otx/algo/classification/vit.py index 877adb5e743..86d05b71218 100644 --- a/src/otx/algo/classification/vit.py +++ b/src/otx/algo/classification/vit.py @@ -577,6 +577,22 @@ def _customize_outputs( labels=preds, ) + @property + def _exporter(self) -> OTXModelExporter: + """Create the OTXNativeModelExporter for this model (fixed 1x3x224x224 input, exported via ONNX).""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + via_onnx=True, # NOTE: the export should be routed through ONNX + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + ) + class VisionTransformerForHLabelCls(ForwardExplainMixInForViT, OTXHlabelClsModel): """DeitTiny Model for hierarchical label classification task.""" @@ -735,3 +751,19 @@ def _convert_pred_entity_to_compute_metric( "preds": pred_result, "target": torch.stack(inputs.labels), } + + @property + def _exporter(self) -> OTXModelExporter: + """Create the OTXNativeModelExporter for this model (fixed 1x3x224x224 input, exported via ONNX).""" + return OTXNativeModelExporter( + task_level_export_parameters=self._export_parameters, + input_size=(1, 3, 224, 224), + mean=(123.675, 116.28, 103.53), + std=(58.395, 57.12, 57.375), + resize_mode="standard", + pad_value=0, + swap_rgb=False, + 
via_onnx=True, # NOTE: the export should be routed through ONNX + onnx_export_configuration=None, + output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, + )