Skip to content

Commit

Permalink
filtering comparison field as well
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Dec 15, 2024
1 parent 981c42d commit a9ea1c3
Showing 1 changed file with 87 additions and 48 deletions.
135 changes: 87 additions & 48 deletions fiftyone/operators/builtins/panels/model_evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ def load_evaluation(self, ctx):
"confusion_matrices": self.get_confusion_matrices(results),
"per_class_metrics": per_class_metrics,
}
ctx.panel.set_state("missing", results.missing)

if ENABLE_CACHING:
# Cache the evaluation data
try:
Expand Down Expand Up @@ -406,88 +408,125 @@ def load_view(self, ctx):
return

view_state = ctx.panel.get_state("view") or {}
view_options = ctx.params.get("options", {})

eval_key = view_state.get("key")
eval_key = view_options.get("key", eval_key)
eval_view = ctx.dataset.load_evaluation_view(eval_key)
info = ctx.dataset.get_evaluation_info(eval_key)
pred_field = info.config.pred_field
gt_field = info.config.gt_field
view_options = ctx.params.get("options", {})

eval_key2 = view_state.get("compareKey", None)
pred_field2 = None
gt_field2 = None
if eval_key2 is not None:
info2 = ctx.dataset.get_evaluation_info(eval_key2)
pred_field2 = info2.config.pred_field
if info2.config.gt_field != gt_field:
gt_field2 = info2.config.gt_field

x = view_options.get("x", None)
y = view_options.get("y", None)
field = view_options.get("field", None)
computed_eval_key = view_options.get("key", eval_key)
eval_view = ctx.dataset.load_evaluation_view(eval_key)
missing = ctx.panel.get_state("missing", "(none)")

view = None
if info.config.type == "classification":
if view_type == "class":
view = eval_view.match(
(F(f"{gt_field}.label") == x)
| (F(f"{pred_field}.label") == x)
)
# All GT/predictions of class `x`
expr = F(f"{gt_field}.label") == x
expr |= F(f"{pred_field}.label") == x
if gt_field2 is not None:
expr |= F(f"{gt_field2}.label") == x
if pred_field2 is not None:
expr |= F(f"{pred_field2}.label") == x
view = eval_view.match(expr)
elif view_type == "matrix":
view = eval_view.match(
(F(f"{gt_field}.label") == y)
& (F(f"{pred_field}.label") == x)
)
# Specific confusion matrix cell (including FP/FN)
expr = F(f"{gt_field}.label") == y
expr &= F(f"{pred_field}.label") == x
view = eval_view.match(expr)
elif view_type == "field":
if field == "fn":
view = eval_view.match(
F(f"{gt_field}.{computed_eval_key}") == field
)
if info.config.method == "binary":
# All TP/FP/FN
expr = F(f"{eval_key}") == field.upper()
view = eval_view.match(expr)
else:
view = eval_view.match(
F(f"{pred_field}.{computed_eval_key}") == field
)
# Correct/incorrect
expr = F(f"{eval_key}") == field
view = eval_view.match(expr)
elif info.config.type == "detection":
_, pred_root = ctx.dataset._get_label_field_path(pred_field)
_, gt_root = ctx.dataset._get_label_field_path(gt_field)
_, pred_root = ctx.dataset._get_label_field_path(pred_field)
if gt_field2 is not None:
_, gt_root2 = ctx.dataset._get_label_field_path(gt_field2)
if pred_field2 is not None:
_, pred_root2 = ctx.dataset._get_label_field_path(pred_field2)

if view_type == "class":
view = (
eval_view.filter_labels(
pred_field, F("label") == x, only_matches=False
)
.filter_labels(
gt_field, F("label") == x, only_matches=False
# All GT/predictions of class `x`
view = eval_view.filter_labels(
gt_field, F("label") == x, only_matches=False
)
expr = F(gt_root).length() > 0
view = view.filter_labels(
pred_field, F("label") == x, only_matches=False
)
expr |= F(pred_root).length() > 0
if gt_field2 is not None:
view = view.filter_labels(
gt_field2, F("label") == x, only_matches=False
)
.match(
(F(pred_root).length() > 0) | (F(gt_root).length() > 0)
expr |= F(gt_root2).length() > 0
if pred_field2 is not None:
view = view.filter_labels(
pred_field2, F("label") == x, only_matches=False
)
)
expr |= F(pred_root2).length() > 0
view = view.match(expr)
elif view_type == "matrix":
view = (
eval_view.filter_labels(
if y == missing:
# False positives of class `x`
expr = (F("label") == x) & (F(eval_key) == "fp")
view = eval_view.filter_labels(
pred_field, expr, only_matches=True
)
elif x == missing:
# False negatives of class `y`
expr = (F("label") == y) & (F(eval_key) == "fn")
view = eval_view.filter_labels(
gt_field, expr, only_matches=True
)
else:
# All class `y` GT and class `x` predictions in same sample
view = eval_view.filter_labels(
gt_field, F("label") == y, only_matches=False
)
.filter_labels(
expr = F(gt_root).length() > 0
view = view.filter_labels(
pred_field, F("label") == x, only_matches=False
)
.match(
(F(pred_root).length() > 0) & (F(gt_root).length() > 0)
)
)
expr &= F(pred_root).length() > 0
view = view.match(expr)
elif view_type == "field":
if field == "tp":
# All true positives
view = eval_view.filter_labels(
gt_field,
F(computed_eval_key) == field,
only_matches=False,
).filter_labels(
pred_field,
F(computed_eval_key) == field,
only_matches=True,
gt_field, F(eval_key) == field, only_matches=False
)
view = view.filter_labels(
pred_field, F(eval_key) == field, only_matches=True
)
elif field == "fn":
# All false negatives
view = eval_view.filter_labels(
gt_field,
F(computed_eval_key) == field,
only_matches=True,
gt_field, F(eval_key) == field, only_matches=True
)
else:
# All false positives
view = eval_view.filter_labels(
pred_field,
F(computed_eval_key) == field,
only_matches=True,
pred_field, F(eval_key) == field, only_matches=True
)

if view is not None:
Expand Down

0 comments on commit a9ea1c3

Please sign in to comment.