Add a streaming_op() to the pr_curves plugin. #587

Merged (3 commits) on Oct 3, 2017
98 changes: 98 additions & 0 deletions tensorboard/plugins/pr_curve/summary.py
@@ -164,6 +164,104 @@ def op(
      description,
      collections)


def streaming_op(tag,
                 labels,
                 predictions,
                 num_thresholds=200,
                 weights=None,
                 metrics_collections=None,
                 updates_collections=None,
                 display_name=None,
                 description=None):
  """Computes a precision-recall curve summary across batches of data.

  This function is similar to op() above, but can be used to compute the PR
  curve across multiple batches of labels and predictions, in the same style
  as the metrics found in tf.metrics.

  This function creates multiple local variables for storing true positives,
  false positives, true negatives, and false negatives accumulated over each
  batch of data, and uses these local variables to compute the final PR curve
  summary. These variables can be updated with the returned update_op.

  Args:
    tag: A tag attached to the summary. Used by TensorBoard for organization.
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    num_thresholds: The number of evenly spaced thresholds to generate for
      computing the PR curve.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
      must be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the `pr_curve`
      summary `Tensor` should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `tag`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.

  Returns:
    pr_curve: A string `Tensor` containing a single value: the
      serialized PR curve Tensor summary. The summary contains a
      float32 `Tensor` of dimension (6, num_thresholds). The first
      dimension (of length 6) is of the order: true positives, false
      positives, true negatives, false negatives, precision, recall.
    update_op: An operation that updates the summary with the latest data.
  """
  thresholds = [i / float(num_thresholds - 1)
                for i in range(num_thresholds)]

  with tf.name_scope(tag, values=[labels, predictions, weights]):
    tp, update_tp = tf.metrics.true_positives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    fp, update_fp = tf.metrics.false_positives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    tn, update_tn = tf.metrics.true_negatives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)
    fn, update_fn = tf.metrics.false_negatives_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=thresholds,
        weights=weights)

    def compute_summary(tp, fp, tn, fn, collections):
      precision = tp / tf.maximum(_MINIMUM_COUNT, tp + fp)
      recall = tp / tf.maximum(_MINIMUM_COUNT, tp + fn)

      return _create_tensor_summary(
          tag,
          tp,
          fp,
          tn,
          fn,
          precision,
          recall,
          num_thresholds,
          display_name,
          description,
          collections)

    pr_curve = compute_summary(tp, fp, tn, fn, metrics_collections)
    update_op = compute_summary(update_tp, update_fp, update_tn, update_fn,
                                updates_collections)

    return pr_curve, update_op


def raw_data_op(
tag,
true_positive_counts,
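The streaming op follows the tf.metrics pattern: run update_op once per batch to grow the TP/FP/TN/FN accumulators, then evaluate pr_curve to get the summary for everything seen so far. Below is a minimal sketch of how this might be wired into a TF 1.x evaluation loop; the placeholders, the `batches` iterable, and the log directory are hypothetical and not part of this change.

import tensorflow as tf
from tensorboard.plugins.pr_curve import summary as pr_summary

# Hypothetical placeholders fed one batch at a time.
labels = tf.placeholder(tf.bool, shape=[None])
predictions = tf.placeholder(tf.float32, shape=[None])

pr_curve, update_op = pr_summary.streaming_op(
    tag='pr_curve',
    labels=labels,
    predictions=predictions,
    num_thresholds=200)

writer = tf.summary.FileWriter('/tmp/pr_curve_demo')  # hypothetical log dir
with tf.Session() as sess:
  # The TP/FP/TN/FN accumulators are local variables and must be initialized.
  sess.run(tf.local_variables_initializer())
  for batch_labels, batch_predictions in batches:  # `batches` is hypothetical
    sess.run(update_op, feed_dict={labels: batch_labels,
                                   predictions: batch_predictions})
  # pr_curve evaluates to a serialized Summary proto string, so it can be
  # written directly once all batches have been accumulated.
  writer.add_summary(sess.run(pr_curve), global_step=0)
writer.close()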
76 changes: 76 additions & 0 deletions tensorboard/plugins/pr_curve/summary_test.py
@@ -320,5 +320,81 @@ def testRawDataOp(self):
    ], tensor_events[0])


class StreamingOpTest(tf.test.TestCase):

  def setUp(self):
    super(StreamingOpTest, self).setUp()
    tf.reset_default_graph()
    np.random.seed(1)

  def pb_via_op(self, summary_op):
    actual_pbtxt = summary_op.eval()
    actual_proto = tf.Summary()
    actual_proto.ParseFromString(actual_pbtxt)
    return actual_proto

  def tensor_via_op(self, summary_op):
    actual_pbtxt = summary_op.eval()
    actual_proto = tf.Summary()
    actual_proto.ParseFromString(actual_pbtxt)
    return actual_proto

  def testMatchesOp(self):
    predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)
    labels = tf.constant([False, True, True, False, True], dtype=tf.bool)

    pr_curve, update_op = summary.streaming_op(tag='pr_curve',
                                               predictions=predictions,
                                               labels=labels,
                                               num_thresholds=10)
    expected_pr_curve = summary.op(tag='pr_curve',
                                   predictions=predictions,
                                   labels=labels,
                                   num_thresholds=10)
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([update_op])

      proto = self.pb_via_op(pr_curve)
      expected_proto = self.pb_via_op(expected_pr_curve)

      # Need to detect and fix the automatic _1 appended to second namespace.
      self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves')
      self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves')
      expected_proto.value[0].tag = 'pr_curve/pr_curves'

      self.assertProtoEquals(expected_proto, proto)

  def testMatchesOpWithUpdates(self):
    predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)
    labels = tf.constant([False, True, True, False, True], dtype=tf.bool)
    pr_curve, update_op = summary.streaming_op(tag='pr_curve',
                                               predictions=predictions,
                                               labels=labels,
                                               num_thresholds=10)

    complete_predictions = tf.tile(predictions, [3])
    complete_labels = tf.tile(labels, [3])
    expected_pr_curve = summary.op(tag='pr_curve',
                                   predictions=complete_predictions,
                                   labels=complete_labels,
                                   num_thresholds=10)
    with self.test_session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run([update_op])
      sess.run([update_op])
      sess.run([update_op])

      proto = self.pb_via_op(pr_curve)
      expected_proto = self.pb_via_op(expected_pr_curve)

      # Need to detect and fix the automatic _1 appended to second namespace.
      self.assertEqual(proto.value[0].tag, 'pr_curve/pr_curves')
      self.assertEqual(expected_proto.value[0].tag, 'pr_curve_1/pr_curves')
      expected_proto.value[0].tag = 'pr_curve/pr_curves'

      self.assertProtoEquals(expected_proto, proto)
Member comment: Oh wow, I should use assertProtoEquals more often.



if __name__ == "__main__":
  tf.test.main()
1 change: 1 addition & 0 deletions tensorboard/summary.py
@@ -39,6 +39,7 @@
image_pb = _image_summary.pb

pr_curve = _pr_curve_summary.op
pr_curve_streaming_op = _pr_curve_summary.streaming_op
pr_curve_raw_data = _pr_curve_summary.raw_data_op

scalar = _scalar_summary.op
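With this re-export, callers can reach the streaming variant through the public tensorboard.summary module rather than importing the plugin package directly. A brief sketch, using the same toy tensors as the tests above:

import tensorflow as tf
import tensorboard.summary as tb_summary

labels = tf.constant([False, True, True, False, True], dtype=tf.bool)
predictions = tf.constant([0.2, 0.4, 0.5, 0.6, 0.8], dtype=tf.float32)

# Same signature as tensorboard.plugins.pr_curve.summary.streaming_op.
pr_curve, update_op = tb_summary.pr_curve_streaming_op(
    tag='pr_curve',
    labels=labels,
    predictions=predictions,
    num_thresholds=10)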