diff --git a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html index 7bfd600c83..107f83d7b8 100644 --- a/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html +++ b/tensorboard/plugins/interactive_inference/tf_interactive_inference_dashboard/tf-interactive-inference-dashboard.html @@ -221,9 +221,9 @@ } .pr-line-chart { - margin: 6px; + margin: 0; height: 200px; - width: 300px; + width: 280px; display: inline-block; } @@ -536,19 +536,25 @@ font-weight: 500; } - .roc-holder { + .curves-holder { + display: flex; + flex-wrap: wrap; + margin-top: 20px; + position: relative; + } + + .curve-holder { width: 300px; height: 235px; - margin-top: 20px; margin-bottom: 20px; - margin-right: 50px; + margin-right: 20px; position: relative; } .roc-x-label { position: absolute; bottom: 0; - left: 140px; + left: 120px; font-size: 12px; color: #5f6368; padding: 0px; @@ -556,7 +562,26 @@ .roc-y-label { position: absolute; - left: -30px; + left: -36px; + bottom: 110px; + transform: rotate(270deg); + font-size: 12px; + color: #5f6368; + padding: 0px; + } + + .pr-x-label { + position: absolute; + bottom: 0; + left: 140px; + font-size: 12px; + color: #5f6368; + padding: 0px; + } + + .pr-y-label { + position: absolute; + left: -14px; bottom: 110px; transform: rotate(270deg); font-size: 12px; @@ -876,6 +901,7 @@ .roc-text { color: #3C4043; font-size: 16px; + margin-left: 44px; } .conf-text { @@ -1112,6 +1138,11 @@ text-align: right; margin-right: 20px; } + .perf-table-f1 { + width: 10%; + text-align: right; + margin-right: 20px; + } .perf-button { margin-top: 10px; } @@ -1771,6 +1802,7 @@

Show similarity to selected datapoint

False Positives (%)
False Negatives (%)
Accuracy (%)
+
F1
+
+ +
@@ -1824,23 +1861,43 @@

Show similarity to selected datapoint

-
-
ROC curve - - - -
ROC curve
-
A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the - false positive rate (FPR) at various classification thresholds. -
-
+
+
+
ROC curve + + + +
ROC curve
+
A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the + false positive rate (FPR) at various classification thresholds. +
+
+
+
False positive rate
+
True positive rate
+ + +
+
+
PR curve + + + +
PR curve
+
A precision-recall (PR) curve plots precision against + recall at various classification thresholds. +
+
+
+
Recall
+
Precision
+ +
-
False positive rate
-
True positive rate
- -
@@ -1881,6 +1938,11 @@

Show similarity to selected datapoint

[[getAccuracyModelIndex(inferenceStats_, overallThresholds, index)]]
+
+ +
@@ -1896,23 +1958,43 @@

Show similarity to selected datapoint

-
-
ROC curve - - - -
ROC curve
-
A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the - false positive rate (FPR) at various classification thresholds. -
-
+
+
+
ROC curve + + + +
ROC curve
+
A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the + false positive rate (FPR) at various classification thresholds. +
+
+
+
False positive rate
+
True positive rate
+ + +
+
+
PR curve + + + +
PR curve
+
A precision-recall (PR) curve plots precision against + recall at various classification thresholds. +
+
+
+
Recall
+
Precision
+ +
-
False positive rate
-
True positive rate
- -
@@ -3262,11 +3344,18 @@

Show similarity to selected datapoint

plotStats.push(inferenceStats.faceted[key]); plotThresholds.push(modelThresholds[modelInd].threshold) } - this.plotPr( + this.plotChart( + this.$$('#' + this.getRocChartId(i)), + plotStats, + plotThresholds, + regenInferenceStats, + true); + this.plotChart( this.$$('#' + this.getPrChartId(i)), plotStats, plotThresholds, - regenInferenceStats); + regenInferenceStats, + false); } const plotStats = []; const plotThresholds = []; @@ -3275,9 +3364,12 @@

Show similarity to selected datapoint

plotStats.push(this.inferenceStats_[modelInd].thresholds); plotThresholds.push(this.overallThresholds[modelInd].threshold); } - this.plotPr( + this.plotChart( + this.$$('#rocchart'), plotStats, + plotThresholds, regenInferenceStats, true); + this.plotChart( this.$$('#prchart'), plotStats, - plotThresholds, regenInferenceStats); + plotThresholds, regenInferenceStats, false); } this.updateCorrectness_(); }, @@ -3377,16 +3469,20 @@

Show similarity to selected datapoint

}, /** - * Plots a PR purve given data to plot. + * Plots a PR or ROC purve given data to plot. * thresholdstats and thresholds are arrays, indexed by model number */ - plotPr: function(chart, thresholdStats, thresholds, - regenInferenceStats) { + plotChart: function(chart, thresholdStats, thresholds, + regenInferenceStats, isRoc) { if (!thresholdStats || !thresholdStats[0] || !chart) { return; } const visibleCharts = []; const seriesColors = []; + const xAxis = isRoc ? 'FPR' : 'TPR'; + const yAxis = isRoc ? 'TPR' : 'PPV'; + const xAxisLabel = isRoc ? 'FPR' : 'Recall'; + const yAxisLabel = isRoc ? 'TPR' : 'Precision'; for (let modelInd = 0; modelInd < thresholdStats.length; modelInd++) { let currentThresholdData = null; const data = thresholdStats[modelInd].map((thresh, i) => { @@ -3394,14 +3490,14 @@

Show similarity to selected datapoint

// current threshold if (i - thresholds[modelInd] * 100 < 0.5) { currentThresholdData = { - 'step': thresh['FPR'], - 'scalar': thresh['TPR'], + 'step': thresh[xAxis], + 'scalar': thresh[yAxis], 'threshold': i / 100 }; } return { - 'step': thresh['FPR'], - 'scalar': thresh['TPR'], + 'step': thresh[xAxis], + 'scalar': thresh[yAxis], 'threshold': i / 100 }; }).reverse(); @@ -3429,15 +3525,17 @@

Show similarity to selected datapoint

}, }, { - title: 'TPR', + title: yAxisLabel, evaluate: function (d) { - return percentageFormatter(d.datum.scalar); + return isRoc ? percentageFormatter(d.datum.scalar) : + valueFormatter(d.datum.scalar); }, }, { - title: 'FPR', + title: xAxisLabel, evaluate: function (d) { - return percentageFormatter(d.datum.step); + return isRoc ? percentageFormatter(d.datum.step) : + valueFormatter(d.datum.step); }, }, ]; @@ -3462,7 +3560,7 @@

Show similarity to selected datapoint

}, /** - * Calculates TPR and FPR given binary confusion matrix counts. + * Calculates TPR, FPR, and PPV given binary confusion matrix counts. */ calcThresholdStats: function(stats) { for (let i = 0; i < stats.length; i++) { @@ -3478,6 +3576,12 @@

Show similarity to selected datapoint

} else { stats[i]['FPR'] = 0; } + if (stats[i]['TP'] + stats[i]['FP'] > 0) { + stats[i]['PPV'] = + stats[i]['TP'] / (stats[i]['TP'] + stats[i]['FP']); + } else { + stats[i]['PPV'] = 0; + } } }, @@ -3779,6 +3883,35 @@

Show similarity to selected datapoint

val / this.getTotalEntriesInConfCounts(confCounts) * 100); }, + /** + * Returns the F1 score for a given confusion matrix + * from the threshold selected, model index, and facet to view. + */ + getF1ModelIndex: function(inferenceStats, thresholds, modelInd, item) { + // TODO(jameswex): This unnecessarily recalculates confusion matrix. + // Can speed this up. + const formatter = d3.format(",.2f"); + const confCounts = this.getConfusionCountsModelIndex( + inferenceStats, thresholds, modelInd, item); + if (Object.keys(confCounts).length == 0) { + return 0; + } + const truePositives = confCounts['Yes']['Yes']; + const falsePositives = confCounts['No']['Yes']; + const falseNegatives = confCounts['Yes']['No']; + if (truePositives == 0) { + if (falsePositives != 0 || falseNegatives != 0) { + return formatter(0); + } else { + return formatter(1); + } + } + const precision = truePositives / (truePositives + falsePositives); + const recall = truePositives / (truePositives + falseNegatives); + return formatter( + 2 * (precision * recall) / (precision + recall)); + }, + /** * Returns the number of examples in a given facet. */ @@ -3971,6 +4104,10 @@

Show similarity to selected datapoint

}, 0), null); }, + getRocChartId: function(index) { + return 'rocchart' + index; + }, + getPrChartId: function(index) { return 'prchart' + index; },