-
ROC curve
-
-
-
- ROC curve
- A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the
- false positive rate (FPR) at various classification thresholds.
-
-
+
+
+
ROC curve
+
+
+
+ ROC curve
+ A receiver operating characteristic (ROC) curve plots the true positive rate (TPR) against the
+ false positive rate (FPR) at various classification thresholds.
+
+
+
+
False positive rate
+
True positive rate
+
+
+
+
+
PR curve
+
+
+
+ PR curve
+ A precision-recall (PR) curve plots precision against
+ recall at various classification thresholds.
+
+
+
+
Recall
+
Precision
+
+
-
False positive rate
-
True positive rate
-
-
@@ -3262,11 +3344,18 @@
Show similarity to selected datapoint
plotStats.push(inferenceStats.faceted[key]);
plotThresholds.push(modelThresholds[modelInd].threshold)
}
- this.plotPr(
+ this.plotChart(
+ this.$$('#' + this.getRocChartId(i)),
+ plotStats,
+ plotThresholds,
+ regenInferenceStats,
+ true);
+ this.plotChart(
this.$$('#' + this.getPrChartId(i)),
plotStats,
plotThresholds,
- regenInferenceStats);
+ regenInferenceStats,
+ false);
}
const plotStats = [];
const plotThresholds = [];
@@ -3275,9 +3364,12 @@
Show similarity to selected datapoint
plotStats.push(this.inferenceStats_[modelInd].thresholds);
plotThresholds.push(this.overallThresholds[modelInd].threshold);
}
- this.plotPr(
+ this.plotChart(
+ this.$$('#rocchart'), plotStats,
+ plotThresholds, regenInferenceStats, true);
+ this.plotChart(
this.$$('#prchart'), plotStats,
- plotThresholds, regenInferenceStats);
+ plotThresholds, regenInferenceStats, false);
}
this.updateCorrectness_();
},
@@ -3377,16 +3469,20 @@
Show similarity to selected datapoint
},
/**
- * Plots a PR purve given data to plot.
+ * Plots a PR or ROC purve given data to plot.
* thresholdstats and thresholds are arrays, indexed by model number
*/
- plotPr: function(chart, thresholdStats, thresholds,
- regenInferenceStats) {
+ plotChart: function(chart, thresholdStats, thresholds,
+ regenInferenceStats, isRoc) {
if (!thresholdStats || !thresholdStats[0] || !chart) {
return;
}
const visibleCharts = [];
const seriesColors = [];
+ const xAxis = isRoc ? 'FPR' : 'TPR';
+ const yAxis = isRoc ? 'TPR' : 'PPV';
+ const xAxisLabel = isRoc ? 'FPR' : 'Recall';
+ const yAxisLabel = isRoc ? 'TPR' : 'Precision';
for (let modelInd = 0; modelInd < thresholdStats.length; modelInd++) {
let currentThresholdData = null;
const data = thresholdStats[modelInd].map((thresh, i) => {
@@ -3394,14 +3490,14 @@
Show similarity to selected datapoint
// current threshold
if (i - thresholds[modelInd] * 100 < 0.5) {
currentThresholdData = {
- 'step': thresh['FPR'],
- 'scalar': thresh['TPR'],
+ 'step': thresh[xAxis],
+ 'scalar': thresh[yAxis],
'threshold': i / 100
};
}
return {
- 'step': thresh['FPR'],
- 'scalar': thresh['TPR'],
+ 'step': thresh[xAxis],
+ 'scalar': thresh[yAxis],
'threshold': i / 100
};
}).reverse();
@@ -3429,15 +3525,17 @@
Show similarity to selected datapoint
},
},
{
- title: 'TPR',
+ title: yAxisLabel,
evaluate: function (d) {
- return percentageFormatter(d.datum.scalar);
+ return isRoc ? percentageFormatter(d.datum.scalar) :
+ valueFormatter(d.datum.scalar);
},
},
{
- title: 'FPR',
+ title: xAxisLabel,
evaluate: function (d) {
- return percentageFormatter(d.datum.step);
+ return isRoc ? percentageFormatter(d.datum.step) :
+ valueFormatter(d.datum.step);
},
},
];
@@ -3462,7 +3560,7 @@
Show similarity to selected datapoint
},
/**
- * Calculates TPR and FPR given binary confusion matrix counts.
+ * Calculates TPR, FPR, and PPV given binary confusion matrix counts.
*/
calcThresholdStats: function(stats) {
for (let i = 0; i < stats.length; i++) {
@@ -3478,6 +3576,12 @@
Show similarity to selected datapoint
} else {
stats[i]['FPR'] = 0;
}
+ if (stats[i]['TP'] + stats[i]['FP'] > 0) {
+ stats[i]['PPV'] =
+ stats[i]['TP'] / (stats[i]['TP'] + stats[i]['FP']);
+ } else {
+ stats[i]['PPV'] = 0;
+ }
}
},
@@ -3779,6 +3883,35 @@
Show similarity to selected datapoint
val / this.getTotalEntriesInConfCounts(confCounts) * 100);
},
+ /**
+ * Returns the F1 score for a given confusion matrix
+ * from the threshold selected, model index, and facet to view.
+ */
+ getF1ModelIndex: function(inferenceStats, thresholds, modelInd, item) {
+ // TODO(jameswex): This unnecessarily recalculates confusion matrix.
+ // Can speed this up.
+ const formatter = d3.format(",.2f");
+ const confCounts = this.getConfusionCountsModelIndex(
+ inferenceStats, thresholds, modelInd, item);
+ if (Object.keys(confCounts).length == 0) {
+ return 0;
+ }
+ const truePositives = confCounts['Yes']['Yes'];
+ const falsePositives = confCounts['No']['Yes'];
+ const falseNegatives = confCounts['Yes']['No'];
+ if (truePositives == 0) {
+ if (falsePositives != 0 || falseNegatives != 0) {
+ return formatter(0);
+ } else {
+ return formatter(1);
+ }
+ }
+ const precision = truePositives / (truePositives + falsePositives);
+ const recall = truePositives / (truePositives + falseNegatives);
+ return formatter(
+ 2 * (precision * recall) / (precision + recall));
+ },
+
/**
* Returns the number of examples in a given facet.
*/
@@ -3971,6 +4104,10 @@
Show similarity to selected datapoint
}, 0), null);
},
+ getRocChartId: function(index) {
+ return 'rocchart' + index;
+ },
+
getPrChartId: function(index) {
return 'prchart' + index;
},