Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce the accuracy and change meta_info -> label_info #2994

Merged
merged 14 commits into from
Feb 28, 2024
Prev Previous commit
Next Next commit
Edit the NamedConfusionMatrix
  • Loading branch information
sungmanc committed Feb 28, 2024
commit 1b3ff689e2b28f59cc011324a952d5b405d88424
50 changes: 37 additions & 13 deletions src/otx/core/metrics/accuracy.py
Original file line number Diff line number Diff line change
@@ -19,24 +19,48 @@
from otx.core.data.dataset.base import LabelInfo


class NamedConfusionMatrix(nn.Module):
class NamedConfusionMatrix(ConfusionMatrix):
"""Named Confusion Matrix to add row, col label names."""

def __init__(
self,
task: str,
num_classes: int,
def __new__(
cls,
col_names: list[str],
row_names: list[str],
):
super().__init__()
self.conf_matrix = ConfusionMatrix(task=task, num_classes=num_classes)
self.col_names = col_names
self.row_names = row_names
task: Literal["binary", "multiclass", "multilabel"],
threshold: float = 0.5,
num_classes: int | None = None,
num_labels: int | None = None,
normalize: Literal["true", "pred", "all", "none"] | None = None,
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: Any, # noqa: ANN401
) -> NamedConfusionMatrix:
"""Construct the NamedConfusionMatrix."""
confusion_metric = super().__new__(
cls,
task=task,
threshold=threshold,
num_classes=num_classes,
num_labels=num_labels,
normalize=normalize,
ignore_index=ignore_index,
validate_args=validate_args,
**kwargs,
)

def __call__(self, *args: object, **kwargs: object):
"""Call function of the Named Confusion Matrix."""
return self.conf_matrix(*args, **kwargs)
confusion_metric.col_names = col_names
confusion_metric.row_names = row_names
return confusion_metric

@property
def col_names(self) -> list[str]:
"""The names of colum."""
return self.col_names

@property
def row_names(self) -> list[str]:
"""The names of row."""
return self.row_names


class CustomAccuracy(Metric):
Loading