Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Model Evaluation panel callbacks for segmentation tasks #5332

Merged
merged 6 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1783,11 +1783,8 @@ type SummaryRow = {

function formatCustomMetricRows(evaluationMetrics, comparisonMetrics) {
const results = [] as SummaryRow[];
const customMetrics = _.get(
evaluationMetrics,
"custom_metrics",
{}
) as CustomMetrics;
const customMetrics = (_.get(evaluationMetrics, "custom_metrics", null) ||
{}) as CustomMetrics;
for (const [operatorUri, customMetric] of Object.entries(customMetrics)) {
const compareValue = _.get(
comparisonMetrics,
Expand Down
62 changes: 62 additions & 0 deletions fiftyone/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,68 @@ def is_rgb_target(target):
)


def hex_to_int(hex_str):
"""Converts a hex string like `"#ff6d04"` to a hex integer.

Args:
hex_str: a hex string

Returns:
an integer
"""
r = int(hex_str[1:3], 16)
g = int(hex_str[3:5], 16)
b = int(hex_str[5:7], 16)
return (r << 16) + (g << 8) + b


def int_to_hex(value):
"""Converts an RRGGBB integer value to hex string like `"#ff6d04"`.

Args:
value: an integer value

Returns:
a hex string
"""
r = (value >> 16) & 255
g = (value >> 8) & 255
b = value & 255
return "#%02x%02x%02x" % (r, g, b)


def rgb_array_to_int(mask):
"""Converts an RGB mask array to a 2D hex integer mask array.

Args:
mask: an RGB mask array

Returns:
a 2D integer mask array
"""
return (
np.left_shift(mask[:, :, 0], 16, dtype=int)
+ np.left_shift(mask[:, :, 1], 8, dtype=int)
+ mask[:, :, 2]
)


def int_array_to_rgb(mask):
"""Converts a 2D hex integer mask array to an RGB mask array.

Args:
mask: a 2D integer mask array

Returns:
an RGB mask array
"""
return np.stack(
[(mask >> 16) & 255, (mask >> 8) & 255, mask & 255],
axis=2,
dtype=np.uint8,
)


class EmbeddedDocumentField(mongoengine.fields.EmbeddedDocumentField, Field):
"""A field that stores instances of a given type of
:class:`fiftyone.core.odm.BaseEmbeddedDocument` object.
Expand Down
15 changes: 15 additions & 0 deletions fiftyone/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ def load_and_cache_dataset(name):
return dataset


def cache_dataset(dataset):
"""Caches the given dataset.

This method ensures that subsequent calls to
:func:`fiftyone.core.dataset.load_dataset` in async calls will return this
dataset singleton.

See :meth:`load_and_cache_dataset` for additional details.

Args:
dataset: a :class:`fiftyone.core.dataset.Dataset`
"""
_cache[dataset.name] = dataset


def change_sample_tags(sample_collection, changes):
"""Applies the changes to tags to all samples of the collection, if
necessary.
Expand Down
94 changes: 55 additions & 39 deletions fiftyone/utils/eval/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
import logging
import inspect
import itertools
import warnings

import numpy as np
Expand Down Expand Up @@ -369,7 +370,7 @@ def evaluate_samples(
if mask_targets is not None:
if fof.is_rgb_mask_targets(mask_targets):
mask_targets = {
_hex_to_int(k): v for k, v in mask_targets.items()
fof.hex_to_int(k): v for k, v in mask_targets.items()
}

values, classes = zip(*sorted(mask_targets.items()))
Expand All @@ -385,6 +386,7 @@ def evaluate_samples(

nc = len(values)
confusion_matrix = np.zeros((nc, nc), dtype=int)
matches = []

bandwidth = self.config.bandwidth
average = self.config.average
Expand Down Expand Up @@ -427,6 +429,17 @@ def evaluate_samples(
)
sample_conf_mat += image_conf_mat

for i, j in zip(*np.nonzero(image_conf_mat)):
matches.append(
(
classes[i],
classes[j],
int(image_conf_mat[i, j]),
gt_seg.id,
pred_seg.id,
)
)

if processing_frames and save:
facc, fpre, frec = _compute_accuracy_precision_recall(
image_conf_mat, values, average
Expand Down Expand Up @@ -460,6 +473,7 @@ def evaluate_samples(
eval_key,
confusion_matrix,
classes,
matches=matches,
missing=missing,
backend=self,
)
Expand All @@ -474,6 +488,9 @@ class SegmentationResults(BaseClassificationResults):
eval_key: the evaluation key
pixel_confusion_matrix: a pixel value confusion matrix
classes: a list of class labels corresponding to the confusion matrix
matches (None): a list of
``(gt_label, pred_label, pixel_count, gt_id, pred_id)``
matches
missing (None): a missing (background) class
custom_metrics (None): an optional dict of custom metrics
backend (None): a :class:`SegmentationEvaluation` backend
Expand All @@ -486,14 +503,23 @@ def __init__(
eval_key,
pixel_confusion_matrix,
classes,
matches=None,
missing=None,
custom_metrics=None,
backend=None,
):
pixel_confusion_matrix = np.asarray(pixel_confusion_matrix)
ytrue, ypred, weights = self._parse_confusion_matrix(
pixel_confusion_matrix, classes
)

if matches is None:
ytrue, ypred, weights = self._parse_confusion_matrix(
pixel_confusion_matrix, classes
)
ytrue_ids = None
ypred_ids = None
elif matches:
ytrue, ypred, weights, ytrue_ids, ypred_ids = zip(*matches)
else:
ytrue, ypred, weights, ytrue_ids, ypred_ids = [], [], [], [], []

super().__init__(
samples,
Expand All @@ -502,6 +528,8 @@ def __init__(
ytrue,
ypred,
weights=weights,
ytrue_ids=ytrue_ids,
ypred_ids=ypred_ids,
classes=classes,
missing=missing,
custom_metrics=custom_metrics,
Expand All @@ -510,15 +538,6 @@ def __init__(

self.pixel_confusion_matrix = pixel_confusion_matrix

def attributes(self):
return [
"cls",
"pixel_confusion_matrix",
"classes",
"missing",
"custom_metrics",
]

def dice_score(self):
"""Computes the Dice score across all samples in the evaluation.

Expand All @@ -529,12 +548,31 @@ def dice_score(self):

@classmethod
def _from_dict(cls, d, samples, config, eval_key, **kwargs):
ytrue = d.get("ytrue", None)
ypred = d.get("ypred", None)
weights = d.get("weights", None)
ytrue_ids = d.get("ytrue_ids", None)
ypred_ids = d.get("ypred_ids", None)

if ytrue is not None and ypred is not None and weights is not None:
if ytrue_ids is None:
ytrue_ids = itertools.repeat(None)

if ypred_ids is None:
ypred_ids = itertools.repeat(None)

matches = list(zip(ytrue, ypred, weights, ytrue_ids, ypred_ids))
else:
# Legacy format segmentations
matches = None

return cls(
samples,
config,
eval_key,
d["pixel_confusion_matrix"],
d["classes"],
matches=matches,
missing=d.get("missing", None),
custom_metrics=d.get("custom_metrics", None),
**kwargs,
Expand Down Expand Up @@ -599,10 +637,10 @@ def _compute_pixel_confusion_matrix(
pred_mask, gt_mask, values, bandwidth=None
):
if pred_mask.ndim == 3:
pred_mask = _rgb_array_to_int(pred_mask)
pred_mask = fof.rgb_array_to_int(pred_mask)

if gt_mask.ndim == 3:
gt_mask = _rgb_array_to_int(gt_mask)
gt_mask = fof.rgb_array_to_int(gt_mask)

if pred_mask.shape != gt_mask.shape:
msg = (
Expand Down Expand Up @@ -675,37 +713,15 @@ def _get_mask_values(samples, pred_field, gt_field, progress=None):
mask = seg.get_mask()
if mask.ndim == 3:
is_rgb = True
mask = _rgb_array_to_int(mask)
mask = fof.rgb_array_to_int(mask)

values.update(mask.ravel())

values = sorted(values)

if is_rgb:
classes = [_int_to_hex(v) for v in values]
classes = [fof.int_to_hex(v) for v in values]
else:
classes = [str(v) for v in values]

return values, classes


def _rgb_array_to_int(mask):
return (
np.left_shift(mask[:, :, 0], 16, dtype=int)
+ np.left_shift(mask[:, :, 1], 8, dtype=int)
+ mask[:, :, 2]
)


def _hex_to_int(hex_str):
r = int(hex_str[1:3], 16)
g = int(hex_str[3:5], 16)
b = int(hex_str[5:7], 16)
return (r << 16) + (g << 8) + b


def _int_to_hex(value):
r = (value >> 16) & 255
g = (value >> 8) & 255
b = value & 255
return "#%02x%02x%02x" % (r, g, b)
Loading
Loading