Merge pull request #5279 from voxel51/custom-metrics
Add support for custom evaluation metrics
brimoor authored Jan 23, 2025
2 parents dad1c09 + e05cc17 commit e7f6d33
Showing 14 changed files with 710 additions and 99 deletions.
@@ -1,3 +1,4 @@
import _ from "lodash";
import { Dialog } from "@fiftyone/components";
import { editingFieldAtom, view } from "@fiftyone/state";
import {
@@ -465,6 +466,7 @@ export default function Evaluation(props: EvaluationProps) {
: false,
hide: !showTpFpFn,
},
...formatCustomMetricRows(evaluation, compareEvaluation),
];

const perClassPerformance = {};
@@ -1756,3 +1758,54 @@ type PLOT_CONFIG_TYPE = {
type PLOT_CONFIG_DIALOG_TYPE = PLOT_CONFIG_TYPE & {
  open?: boolean;
};

type CustomMetric = {
  label: string;
  key: any;
  value: any;
  lower_is_better: boolean;
};

type CustomMetrics = {
  [operatorUri: string]: CustomMetric;
};

type SummaryRow = {
  id: string;
  property: string;
  value: any;
  compareValue: any;
  lesserIsBetter: boolean;
  filterable: boolean;
  active: boolean;
  hide: boolean;
};

function formatCustomMetricRows(evaluationMetrics, comparisonMetrics) {
  const results = [] as SummaryRow[];
  const customMetrics = _.get(
    evaluationMetrics,
    "custom_metrics",
    {}
  ) as CustomMetrics;
  for (const [operatorUri, customMetric] of Object.entries(customMetrics)) {
    const compareValue = _.get(
      comparisonMetrics,
      `custom_metrics.${operatorUri}.value`,
      null
    );
    const hasOneValue = customMetric.value !== null || compareValue !== null;

    results.push({
      id: operatorUri,
      property: customMetric.label,
      value: customMetric.value,
      compareValue,
      lesserIsBetter: customMetric.lower_is_better,
      filterable: false,
      active: false,
      hide: !hasOneValue,
    });
  }
  return results;
}
1 change: 1 addition & 0 deletions fiftyone/operators/__init__.py
@@ -12,6 +12,7 @@
list_operators,
operator_exists,
)
from .evaluation_metric import EvaluationMetricConfig, EvaluationMetric
from .executor import (
execute_operator,
ExecutionContext,
133 changes: 133 additions & 0 deletions fiftyone/operators/evaluation_metric.py
@@ -0,0 +1,133 @@
"""
Evaluation metric operators.
| Copyright 2017-2025, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""

from .operator import Operator, OperatorConfig


class EvaluationMetricConfig(OperatorConfig):
"""Configuration class for evaluation metrics.
Args:
name: the name of the evaluation metric
label (name): a label for the evaluation metric
description (None): a description of the evaluation metric
eval_types (None): an optional list of evaluation method types that
this metric supports
aggregate_key (None): an optional key under which to store the metric's
aggregate value. This is used, for example, by
:meth:`metrics() <fiftyone.utils.eval.base.BaseEvaluationResults.metrics>`.
By default, the metric's ``name`` is used as its key
lower_is_better (True): whether lower values of the metric are better
**kwargs: other kwargs for :class:`fiftyone.operators.OperatorConfig`
"""

def __init__(
self,
name,
label=None,
description=None,
eval_types=None,
aggregate_key=None,
lower_is_better=True,
**kwargs,
):
super().__init__(name, label=label, description=description, **kwargs)
self.eval_types = eval_types
self.aggregate_key = aggregate_key
self.lower_is_better = lower_is_better


class EvaluationMetric(Operator):
"""Base class for evaluation metric operators."""

def get_parameters(self, ctx, inputs):
"""Defines any necessary properties to collect this evaluation metric's
parameters from a user during prompting.
Args:
ctx: an :class:`fiftyone.operators.ExecutionContext`
inputs: a :class:`fiftyone.operators.types.Property`
"""
pass

def parse_parameters(self, ctx, params):
"""Performs any necessary execution-time formatting to this evaluation
metric's parameters.
Args:
ctx: an :class:`fiftyone.operators.ExecutionContext`
params: a params dict
"""
pass

def compute(self, samples, results, **kwargs):
"""Computes the evaluation metric for the given collection.
Args:
samples: a :class:`fiftyone.core.collections.SampleCollection`
results: an :class:`fiftyone.core.evaluation.EvaluationResults`
**kwargs: arbitrary metric-specific parameters
Returns:
an optional aggregate metric value to store on the results
"""
raise NotImplementedError("Subclass must implement compute()")

def get_fields(self, samples, config, eval_key):
"""Lists the fields that were populated by the evaluation metric with
the given key, if any.
Args:
samples: a :class:`fiftyone.core.collections.SampleCollection`
config: an :class:`fiftyone.core.evaluation.EvaluationMethodConfig`
eval_key: an evaluation key
Returns:
a list of fields
"""
return []

def rename(self, samples, config, eval_key, new_eval_key):
"""Performs any necessary operations required to rename this evaluation
metric's key.
Args:
samples: a :class:`fiftyone.core.collections.SampleCollection`
config: an :class:`fiftyone.core.evaluation.EvaluationMethodConfig`
eval_key: an evaluation key
new_eval_key: a new evaluation key
"""
dataset = samples._dataset
for metric_field in self.get_fields(samples, config, eval_key):
metric_field, is_frame_field = samples._handle_frame_field(
metric_field
)
new_metric_field = metric_field.replace(eval_key, new_eval_key, 1)
if is_frame_field:
dataset.rename_frame_field(metric_field, new_metric_field)
else:
dataset.rename_sample_field(metric_field, new_metric_field)

def cleanup(self, samples, config, eval_key):
"""Cleans up the results of the evaluation metric with the given key
from the collection.
Args:
samples: a :class:`fiftyone.core.collections.SampleCollection`
config: an :class:`fiftyone.core.evaluation.EvaluationMethodConfig`
eval_key: an evaluation key
"""
dataset = samples._dataset
for metric_field in self.get_fields(samples, config, eval_key):
metric_field, is_frame_field = samples._handle_frame_field(
metric_field
)
if is_frame_field:
dataset.delete_frame_field(metric_field, error_level=1)
else:
dataset.delete_sample_field(metric_field, error_level=1)
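
For context, a custom metric is just another operator: it subclasses EvaluationMetric and exposes an EvaluationMetricConfig via the usual config property. The following is a minimal sketch of a per-sample metric for a regression evaluation; the AbsoluteErrorMetric name, the "<eval_key>_absolute_error" field naming, and the assumption that the results object exposes key/config attributes are illustrative, not taken from this PR.

import fiftyone.operators as foo


class AbsoluteErrorMetric(foo.EvaluationMetric):
    @property
    def config(self):
        return foo.EvaluationMetricConfig(
            name="absolute_error",
            label="Absolute error",
            description="Per-sample absolute regression error",
            eval_types=["regression"],  # assumption: regression evals use this type string
            aggregate_key="absolute_error",
            lower_is_better=True,
        )

    def compute(self, samples, results, **kwargs):
        # Assumption: the results object exposes its eval key and config
        eval_key = results.key
        config = results.config

        # Compute a per-sample value and store it in an eval_key-prefixed field
        # so that get_fields()/rename()/cleanup() can manage it
        ytrue = samples.values(f"{config.gt_field}.value")
        ypred = samples.values(f"{config.pred_field}.value")
        errors = [
            abs(yt - yp) if yt is not None and yp is not None else None
            for yt, yp in zip(ytrue, ypred)
        ]

        metric_field = f"{eval_key}_absolute_error"
        samples.set_values(metric_field, errors)

        # The return value is stored as the metric's aggregate value
        valid = [e for e in errors if e is not None]
        return sum(valid) / len(valid) if valid else None

    def get_fields(self, samples, config, eval_key):
        # Declaring the field lets the base class rename/cleanup it automatically
        return [f"{eval_key}_absolute_error"]

If a metric like this were registered via a plugin, its aggregate value would be stored under the operator's URI and surfaced by the formatCustomMetricRows() helper shown earlier in the Model Evaluation panel.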
15 changes: 10 additions & 5 deletions fiftyone/operators/registry.py
@@ -6,6 +6,8 @@
|
"""

import eta.core.utils as etau

from fiftyone.operators.panel import Panel
import fiftyone.plugins.context as fopc

@@ -41,7 +43,8 @@ def list_operators(enabled=True, builtin="all", type=None):
builtin ("all"): whether to include only builtin operators (True) or
only non-builtin operators (False) or all operators ("all")
type (None): whether to include only ``"panel"`` or ``"operator"`` type
operators
operators, or a specific :class:`fiftyone.operators.Operator`
subclass to restrict to
Returns:
a list of :class:`fiftyone.operators.Operator` instances
@@ -85,7 +88,8 @@ def list_operators(self, builtin=None, type=None):
builtin (None): whether to include only builtin operators (True) or
only non-builtin operators (False)
type (None): whether to include only ``"panel"`` or ``"operator"``
type operators
type operators, or a specific
:class:`fiftyone.operators.Operator` subclass to restrict to
Returns:
a list of :class:`fiftyone.operators.Operator` instances
@@ -104,9 +108,10 @@ def list_operators(self, builtin=None, type=None):
elif type == "operator":
operators = [op for op in operators if not isinstance(op, Panel)]
elif type is not None:
raise ValueError(
f"Unsupported type='{type}'. The supported values are ('panel', 'operator')"
)
if etau.is_str(type):
type = etau.get_class(type)

operators = [op for op in operators if isinstance(op, type)]

return operators

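With this registry change, type is no longer limited to "panel"/"operator": it can also be an Operator subclass or, per the etau.get_class() branch, its fully-qualified class path as a string. A sketch of how this might be used to list just the registered evaluation metrics, assuming the module-level list_operators() forwards type to the registry method:

import fiftyone.operators as foo

# List only operators that are evaluation metrics
metrics = foo.list_operators(type=foo.EvaluationMetric)

# Equivalent, passing the class path as a string (resolved via etau.get_class())
metrics = foo.list_operators(type="fiftyone.operators.EvaluationMetric")

print([m.uri for m in metrics])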
13 changes: 12 additions & 1 deletion fiftyone/utils/eval/activitynet.py
@@ -40,6 +40,8 @@ class ActivityNetEvaluationConfig(DetectionEvaluationConfig):
that mAP and PR curves can be generated
iou_threshs (None): a list of IoU thresholds to use when computing mAP
and PR curves. Only applicable when ``compute_mAP`` is True
custom_metrics (None): an optional list of custom metrics to compute
or dict mapping metric names to kwargs dicts
"""

def __init__(
@@ -50,10 +52,16 @@ def __init__(
classwise=None,
compute_mAP=False,
iou_threshs=None,
custom_metrics=None,
**kwargs,
):
super().__init__(
pred_field, gt_field, iou=iou, classwise=classwise, **kwargs
pred_field,
gt_field,
iou=iou,
classwise=classwise,
custom_metrics=custom_metrics,
**kwargs,
)

if compute_mAP and iou_threshs is None:
@@ -323,6 +331,7 @@ class ActivityNetDetectionResults(DetectionResults):
``num_iou_threshs x num_classes x num_recall``
missing (None): a missing label string. Any unmatched segments are
given this label for evaluation purposes
custom_metrics (None): an optional dict of custom metrics
backend (None): a :class:`ActivityNetEvaluation` backend
"""

@@ -339,6 +348,7 @@ def __init__(
classes,
thresholds=None,
missing=None,
custom_metrics=None,
backend=None,
):
super().__init__(
@@ -348,6 +358,7 @@
matches,
classes=classes,
missing=missing,
custom_metrics=custom_metrics,
backend=backend,
)

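Only activitynet.py is shown here, but the same custom_metrics argument is threaded through the other evaluation configs touched by this PR. Below is a hedged sketch of how a user might request custom metrics when running an evaluation, assuming that kwargs are forwarded from evaluate_detections() to the config and that "@voxel51/metric-examples/example_metric" is the URI of a registered metric (the URI and its kwargs are illustrative only):

import fiftyone as fo
import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# Pass custom metrics as a list of operator URIs, or as a dict mapping
# URIs to kwargs for each metric's compute() method
results = dataset.evaluate_detections(
    "predictions",
    gt_field="ground_truth",
    eval_key="eval",
    custom_metrics={
        # hypothetical metric URI and kwargs, for illustration
        "@voxel51/metric-examples/example_metric": {"value": "spam"},
    },
)

# Custom metric aggregates are stored on the results and included in metrics()
print(results.metrics())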