Skip to content

Commit

Permalink
Merge pull request #5332 from voxel51/segmentation-callbacks2
Browse files Browse the repository at this point in the history
Add Model Evaluation panel callbacks for segmentation tasks
  • Loading branch information
brimoor authored Jan 24, 2025
2 parents 448177d + 3d11a02 commit 2ac97ad
Show file tree
Hide file tree
Showing 5 changed files with 340 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1783,11 +1783,8 @@ type SummaryRow = {

function formatCustomMetricRows(evaluationMetrics, comparisonMetrics) {
const results = [] as SummaryRow[];
const customMetrics = _.get(
evaluationMetrics,
"custom_metrics",
{}
) as CustomMetrics;
const customMetrics = (_.get(evaluationMetrics, "custom_metrics", null) ||
{}) as CustomMetrics;
for (const [operatorUri, customMetric] of Object.entries(customMetrics)) {
const compareValue = _.get(
comparisonMetrics,
Expand Down
62 changes: 62 additions & 0 deletions fiftyone/core/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -1621,6 +1621,68 @@ def is_rgb_target(target):
)


def hex_to_int(hex_str):
"""Converts a hex string like `"#ff6d04"` to a hex integer.
Args:
hex_str: a hex string
Returns:
an integer
"""
r = int(hex_str[1:3], 16)
g = int(hex_str[3:5], 16)
b = int(hex_str[5:7], 16)
return (r << 16) + (g << 8) + b


def int_to_hex(value):
"""Converts an RRGGBB integer value to hex string like `"#ff6d04"`.
Args:
value: an integer value
Returns:
a hex string
"""
r = (value >> 16) & 255
g = (value >> 8) & 255
b = value & 255
return "#%02x%02x%02x" % (r, g, b)


def rgb_array_to_int(mask):
"""Converts an RGB mask array to a 2D hex integer mask array.
Args:
mask: an RGB mask array
Returns:
a 2D integer mask array
"""
return (
np.left_shift(mask[:, :, 0], 16, dtype=int)
+ np.left_shift(mask[:, :, 1], 8, dtype=int)
+ mask[:, :, 2]
)


def int_array_to_rgb(mask):
"""Converts a 2D hex integer mask array to an RGB mask array.
Args:
mask: a 2D integer mask array
Returns:
an RGB mask array
"""
return np.stack(
[(mask >> 16) & 255, (mask >> 8) & 255, mask & 255],
axis=2,
dtype=np.uint8,
)


class EmbeddedDocumentField(mongoengine.fields.EmbeddedDocumentField, Field):
"""A field that stores instances of a given type of
:class:`fiftyone.core.odm.BaseEmbeddedDocument` object.
Expand Down
15 changes: 15 additions & 0 deletions fiftyone/server/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,21 @@ def load_and_cache_dataset(name):
return dataset


def cache_dataset(dataset):
"""Caches the given dataset.
This method ensures that subsequent calls to
:func:`fiftyone.core.dataset.load_dataset` in async calls will return this
dataset singleton.
See :meth:`load_and_cache_dataset` for additional details.
Args:
dataset: a :class:`fiftyone.core.dataset.Dataset`
"""
_cache[dataset.name] = dataset


def change_sample_tags(sample_collection, changes):
"""Applies the changes to tags to all samples of the collection, if
necessary.
Expand Down
94 changes: 55 additions & 39 deletions fiftyone/utils/eval/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from copy import deepcopy
import logging
import inspect
import itertools
import warnings

import numpy as np
Expand Down Expand Up @@ -369,7 +370,7 @@ def evaluate_samples(
if mask_targets is not None:
if fof.is_rgb_mask_targets(mask_targets):
mask_targets = {
_hex_to_int(k): v for k, v in mask_targets.items()
fof.hex_to_int(k): v for k, v in mask_targets.items()
}

values, classes = zip(*sorted(mask_targets.items()))
Expand All @@ -385,6 +386,7 @@ def evaluate_samples(

nc = len(values)
confusion_matrix = np.zeros((nc, nc), dtype=int)
matches = []

bandwidth = self.config.bandwidth
average = self.config.average
Expand Down Expand Up @@ -427,6 +429,17 @@ def evaluate_samples(
)
sample_conf_mat += image_conf_mat

for i, j in zip(*np.nonzero(image_conf_mat)):
matches.append(
(
classes[i],
classes[j],
int(image_conf_mat[i, j]),
gt_seg.id,
pred_seg.id,
)
)

if processing_frames and save:
facc, fpre, frec = _compute_accuracy_precision_recall(
image_conf_mat, values, average
Expand Down Expand Up @@ -460,6 +473,7 @@ def evaluate_samples(
eval_key,
confusion_matrix,
classes,
matches=matches,
missing=missing,
backend=self,
)
Expand All @@ -474,6 +488,9 @@ class SegmentationResults(BaseClassificationResults):
eval_key: the evaluation key
pixel_confusion_matrix: a pixel value confusion matrix
classes: a list of class labels corresponding to the confusion matrix
matches (None): a list of
``(gt_label, pred_label, pixel_count, gt_id, pred_id)``
matches
missing (None): a missing (background) class
custom_metrics (None): an optional dict of custom metrics
backend (None): a :class:`SegmentationEvaluation` backend
Expand All @@ -486,14 +503,23 @@ def __init__(
eval_key,
pixel_confusion_matrix,
classes,
matches=None,
missing=None,
custom_metrics=None,
backend=None,
):
pixel_confusion_matrix = np.asarray(pixel_confusion_matrix)
ytrue, ypred, weights = self._parse_confusion_matrix(
pixel_confusion_matrix, classes
)

if matches is None:
ytrue, ypred, weights = self._parse_confusion_matrix(
pixel_confusion_matrix, classes
)
ytrue_ids = None
ypred_ids = None
elif matches:
ytrue, ypred, weights, ytrue_ids, ypred_ids = zip(*matches)
else:
ytrue, ypred, weights, ytrue_ids, ypred_ids = [], [], [], [], []

super().__init__(
samples,
Expand All @@ -502,6 +528,8 @@ def __init__(
ytrue,
ypred,
weights=weights,
ytrue_ids=ytrue_ids,
ypred_ids=ypred_ids,
classes=classes,
missing=missing,
custom_metrics=custom_metrics,
Expand All @@ -510,15 +538,6 @@ def __init__(

self.pixel_confusion_matrix = pixel_confusion_matrix

def attributes(self):
return [
"cls",
"pixel_confusion_matrix",
"classes",
"missing",
"custom_metrics",
]

def dice_score(self):
"""Computes the Dice score across all samples in the evaluation.
Expand All @@ -529,12 +548,31 @@ def dice_score(self):

@classmethod
def _from_dict(cls, d, samples, config, eval_key, **kwargs):
ytrue = d.get("ytrue", None)
ypred = d.get("ypred", None)
weights = d.get("weights", None)
ytrue_ids = d.get("ytrue_ids", None)
ypred_ids = d.get("ypred_ids", None)

if ytrue is not None and ypred is not None and weights is not None:
if ytrue_ids is None:
ytrue_ids = itertools.repeat(None)

if ypred_ids is None:
ypred_ids = itertools.repeat(None)

matches = list(zip(ytrue, ypred, weights, ytrue_ids, ypred_ids))
else:
# Legacy format segmentations
matches = None

return cls(
samples,
config,
eval_key,
d["pixel_confusion_matrix"],
d["classes"],
matches=matches,
missing=d.get("missing", None),
custom_metrics=d.get("custom_metrics", None),
**kwargs,
Expand Down Expand Up @@ -599,10 +637,10 @@ def _compute_pixel_confusion_matrix(
pred_mask, gt_mask, values, bandwidth=None
):
if pred_mask.ndim == 3:
pred_mask = _rgb_array_to_int(pred_mask)
pred_mask = fof.rgb_array_to_int(pred_mask)

if gt_mask.ndim == 3:
gt_mask = _rgb_array_to_int(gt_mask)
gt_mask = fof.rgb_array_to_int(gt_mask)

if pred_mask.shape != gt_mask.shape:
msg = (
Expand Down Expand Up @@ -675,37 +713,15 @@ def _get_mask_values(samples, pred_field, gt_field, progress=None):
mask = seg.get_mask()
if mask.ndim == 3:
is_rgb = True
mask = _rgb_array_to_int(mask)
mask = fof.rgb_array_to_int(mask)

values.update(mask.ravel())

values = sorted(values)

if is_rgb:
classes = [_int_to_hex(v) for v in values]
classes = [fof.int_to_hex(v) for v in values]
else:
classes = [str(v) for v in values]

return values, classes


def _rgb_array_to_int(mask):
return (
np.left_shift(mask[:, :, 0], 16, dtype=int)
+ np.left_shift(mask[:, :, 1], 8, dtype=int)
+ mask[:, :, 2]
)


def _hex_to_int(hex_str):
r = int(hex_str[1:3], 16)
g = int(hex_str[3:5], 16)
b = int(hex_str[5:7], 16)
return (r << 16) + (g << 8) + b


def _int_to_hex(value):
r = (value >> 16) & 255
g = (value >> 8) & 255
b = value & 255
return "#%02x%02x%02x" % (r, g, b)
Loading

0 comments on commit 2ac97ad

Please sign in to comment.