Add a streaming_op() to the pr_curves plugin. #587
Changes from 1 commit
@@ -164,6 +164,98 @@ def op(
        description,
        collections)

def streaming_op(tag, labels, predictions, num_thresholds=200, weights=None,
                 metrics_collections=None, updates_collections=None,
                 display_name=None, description=None):
  """Computes a precision-recall curve summary across batches of data.

  This function is similar to op() above, but can be used to compute the PR
  curve across multiple batches of labels and predictions, in the same style
  as the metrics found in tf.metrics.

  This function creates multiple local variables for storing true positives,
  true negatives, etc. accumulated over each batch of data, and uses these
  local variables for computing the final PR curve summary. These variables
  can be updated with the returned update_op.

  Args:
    tag: A tag attached to the summary. Used by TensorBoard for organization.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    num_thresholds: The number of evenly spaced thresholds to generate for
      computing the PR curve.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `pr_curve` should
      be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `tag`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.

  Returns:
    pr_curve: A string `Tensor` containing a single value: the
      serialized PR curve Tensor summary. The summary contains a
      float32 `Tensor` of dimension (6, num_thresholds). The first
      dimension (of length 6) is of the order: true positives, false
      positives, true negatives, false negatives, precision, recall.
    update_op: An operation that updates the summary with the latest data.
  """
  thresholds = [i * 1.0 / float(num_thresholds - 1)
                for i in range(num_thresholds)]

[Review comment] Does …
[Reply] Yep. Done.

  with tf.name_scope(tag, values=[labels, predictions, weights]):
    tp, update_tp = tf.metrics.true_positives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    fp, update_fp = tf.metrics.false_positives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    tn, update_tn = tf.metrics.true_negatives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    fn, update_fn = tf.metrics.false_negatives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    def compute_summary(tp, fp, tn, fn, collections):
      precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
      recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

      return _create_tensor_summary(
          tag,
          tp,
          fp,
          tn,
          fn,
          precision,
          recall,
          num_thresholds,
          display_name,
          description,
          collections)

    pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections)
    update_op = compute_summary(update_tp, update_fp, update_tn, update_fn,
                                updates_collections)

    return pr_curve, update_op

def raw_data_op(
    tag,
    true_positive_counts,
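
For context while reviewing, here is a minimal usage sketch of the new streaming_op() (not part of this PR; it assumes TF 1.x graph mode, that this module is imported as `summary`, and a hypothetical `batches` iterable):

    import tensorflow as tf

    labels = tf.placeholder(tf.bool, shape=[None])
    predictions = tf.placeholder(tf.float32, shape=[None])

    pr_curve, update_op = summary.streaming_op(
        tag='pr_curve',
        labels=labels,
        predictions=predictions,
        num_thresholds=50)

    with tf.Session() as sess:
      # The tf.metrics counters are local variables, so initialize them first.
      sess.run(tf.local_variables_initializer())
      for batch_labels, batch_predictions in batches:  # `batches` is hypothetical
        sess.run(update_op, feed_dict={labels: batch_labels,
                                       predictions: batch_predictions})
      # pr_curve now reflects counts accumulated across all batches.
      serialized = sess.run(pr_curve)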
@@ -320,5 +320,81 @@ def testRawDataOp(self):
    ], tensor_events[0])

class StreamingOpTest(tf.test.TestCase):

  def setUp(self):
    super(StreamingOpTest, self).setUp()
    tf.reset_default_graph()
    np.random.seed(1)

  def pb_via_op(self, summary_op):
    # Evaluates a summary op and parses the serialized result into a
    # tf.Summary proto.
    actual_pbtxt = summary_op.eval()
    actual_proto = tf.Summary()
    actual_proto.ParseFromString(actual_pbtxt)
    return actual_proto

  def tensor_via_op(self, summary_op):
    # Currently identical to pb_via_op above.
    actual_pbtxt = summary_op.eval()
    actual_proto = tf.Summary()
    actual_proto.ParseFromString(actual_pbtxt)
    return actual_proto

  def testMatchesOp(self):
    predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)
    labels = tf.constant([False, True, True, False, True], dtype=tf.bool)

    pr_curve, update_op = summary.streaming_op(tag='pr_curve',
                                               predictions=predictions,
                                               labels=labels,
                                               num_thresholds=10)
    expected_pr_curve = summary.op(tag='pr_curve',
                                   predictions=predictions,
                                   labels=labels,
                                   num_thresholds=10)
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([update_op])

      proto = self.pb_via_op(pr_curve)
      expected_proto = self.pb_via_op(expected_pr_curve)

      # Need to detect and fix the automatic _1 appended to second namespace.
      self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves')
      self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves')
      expected_proto.value[0].tag = 'pr_curve/pr_curves'

      self.assertProtoEquals(expected_proto, proto)

  def testMatchesOpWithUpdates(self):
    predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)
    labels = tf.constant([False, True, True, False, True], dtype=tf.bool)
    pr_curve, update_op = summary.streaming_op(tag='pr_curve',
                                               predictions=predictions,
                                               labels=labels,
                                               num_thresholds=10)

    complete_predictions = tf.tile(predictions, [3])
    complete_labels = tf.tile(labels, [3])
    expected_pr_curve = summary.op(tag='pr_curve',
                                   predictions=complete_predictions,
                                   labels=complete_labels,
                                   num_thresholds=10)
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([update_op])
      sess.run([update_op])
      sess.run([update_op])

      proto = self.pb_via_op(pr_curve)
      expected_proto = self.pb_via_op(expected_pr_curve)

      # Need to detect and fix the automatic _1 appended to second namespace.
      self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves')
      self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves')
      expected_proto.value[0].tag = 'pr_curve/pr_curves'

      self.assertProtoEquals(expected_proto, proto)
[Review comment] Oh wow, I should use …


if __name__ == "__main__":
  tf.test.main()
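
On the `_1` tag fixup in both tests: the streaming op and the regular op are each created under tag='pr_curve', and in TF 1.x graph mode tf.name_scope uniquifies the second scope by appending `_1`, which is why the expected summary's tag differs before the test rewrites it. A minimal sketch of that behavior (not part of this PR; assumes TF 1.x):

    import tensorflow as tf

    with tf.name_scope('pr_curve'):
      a = tf.constant(0, name='c')
    with tf.name_scope('pr_curve'):
      b = tf.constant(0, name='c')

    print(a.name)  # 'pr_curve/c:0'
    print(b.name)  # 'pr_curve_1/c:0' -- the second scope name was uniquified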
[Review comment] nit: Could we reorganize the arguments to be one per row?
[Reply] Done.
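
Presumably the "Done" reorganizes the streaming_op signature along these lines (a sketch of the requested style, not the exact follow-up commit):

    def streaming_op(tag,
                     labels,
                     predictions,
                     num_thresholds=200,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     display_name=None,
                     description=None):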