Skip to content

Commit

Permalink
restore
Browse files Browse the repository at this point in the history
  • Loading branch information
wonjuleee committed Jul 15, 2024
1 parent 92ef6d7 commit cad9861
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class VisionTransformer(BaseModule):
norm_layer: Normalization layer.
act_layer: MLP activation layer.
block_fn: Transformer block layer.
lora: Enable LoRA training.
"""

arch_zoo = { # noqa: RUF012
Expand Down
32 changes: 32 additions & 0 deletions src/otx/algo/classification/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,22 @@ def _customize_outputs(
labels=preds,
)

@property
def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, 224, 224),
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
resize_mode="standard",
pad_value=0,
swap_rgb=False,
via_onnx=True, # NOTE: This should be done via onnx
onnx_export_configuration=None,
output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None,
)


class VisionTransformerForHLabelCls(ForwardExplainMixInForViT, OTXHlabelClsModel):
"""DeitTiny Model for hierarchical label classification task."""
Expand Down Expand Up @@ -735,3 +751,19 @@ def _convert_pred_entity_to_compute_metric(
"preds": pred_result,
"target": torch.stack(inputs.labels),
}

@property
def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
return OTXNativeModelExporter(
task_level_export_parameters=self._export_parameters,
input_size=(1, 3, 224, 224),
mean=(123.675, 116.28, 103.53),
std=(58.395, 57.12, 57.375),
resize_mode="standard",
pad_value=0,
swap_rgb=False,
via_onnx=True, # NOTE: This should be done via onnx
onnx_export_configuration=None,
output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None,
)

0 comments on commit cad9861

Please sign in to comment.