Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WIT ability to consume arbitrary prediction-time information #2660

Merged
merged 10 commits into from
Sep 20, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3266,6 +3266,10 @@ <h2>Show similarity to selected datapoint</h2>
observer: 'newInferences_',
value: () => ({}),
},
extraOutputs: {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments here like you did with attributions below.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to define the type using closure type syntax? It does not have to be correct but it is for readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added @type {indices: Array, extra: Array<{Object}>} but am not too familiar with closure.
If I want to specify that the Object above is a dict with arbitrary keys where each value in the dict is an array of numbers or strings, can I do that as well?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, TypeScript types are all expressable in Closure and vice a versa. I believe you can do !Object<string, number>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thx done

type: Object,
observer: 'newExtraOutputs_',
},
// Attributions from inference. A dict with two fields: 'indices' and
// 'attributions'. Indices contains a list of example indices that
// these new attributions apply to. Attributions contains a list of
Expand Down Expand Up @@ -3778,12 +3782,16 @@ <h2>Show similarity to selected datapoint</h2>
} else {
this.comparedIndices = [];
this.counterfactualExampleAndInference = null;
const temp = this.selectedExampleAndInference;
this.selectedExampleAndInference = null;
this.selectedExampleAndInference = temp;
this.refreshSelectedDatapoint_();
}
},

refreshSelectedDatapoint_: function() {
const temp = this.selectedExampleAndInference;
this.selectedExampleAndInference = null;
this.selectedExampleAndInference = temp
},

findClosestCounterfactual_: function() {
const selected = this.selected[0];
const modelInferenceValueStr = this.strWithModelName_(
Expand Down Expand Up @@ -5856,6 +5864,71 @@ <h2>Show similarity to selected datapoint</h2>
this.updatedExample = false;
},

newExtraOutputs_: function(extraOutputs) {
// Set attributions from the extra outputs, if available.
const attributions = [];
for (let i = 0; i < extraOutputs.extra.length; i++) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i -> modelNum would be better here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if ('attributions' in extraOutputs.extra[i]) {
attributions.push(extraOutputs.extra[i].attributions);
}
}
if (attributions.length > 0) {
this.attributions = {
'indices': extraOutputs.indices,
'attributions': attributions,
};
}

// Add extra output information to datapoints
for (let i = 0; i < extraOutputs.indices.length; i++) {
const idx = extraOutputs.indices[i];
const datapoint = Object.assign({}, this.visdata[idx]);
for (
let modelNum = 0;
modelNum < extraOutputs.extra.length;
modelNum++
) {
const keys = Object.keys(extraOutputs.extra[modelNum]);
for (let j = 0; j < keys.length; j++) {
const key = keys[j];
// Skip attributions as they are handled separately above.
if (key == 'attributions') {
continue;
}
let val = extraOutputs.extra[modelNum][key][i];

// Update the datapoint with the extra info for use in
// Facets Dive.
datapoint[datapointKey] = val;

// Convert the extra output into an array if necessary, for
// insertion into tf.Example as a value list, for update of
// examplesAndInferences for the example viewer.
if (!Array.isArray(val)) {
val = [val];
}
const isString = val.length > 0 &&
(typeof val[0] == 'string' || val[0] instanceof String);
const datapointKey = this.strWithModelName_(key, modelNum);
const exampleJsonString = JSON.stringify(
this.examplesAndInferences[idx].example
);
const copiedExample = JSON.parse(exampleJsonString);
copiedExample.features.feature[datapointKey] = isString ?
{bytesList: {value: val}} : {floatList: {value: val}};
this.examplesAndInferences[idx].example = copiedExample;
}
}
this.set(`visdata.${idx}`, datapoint);
}
this.refreshDive_();

// Update selected datapoint so that if a datapoint is being viewed,
// the display is updated with the appropriate extra output.
this.computeSelectedExampleAndInference();
this.refreshSelectedDatapoint_();
},

newAttributions_: function(attributions) {
if (Object.keys(attributions).length == 0) {
return;
Expand Down
24 changes: 13 additions & 11 deletions tensorboard/plugins/interactive_inference/utils/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,12 +613,12 @@ def get_example_features(example):

def run_inference_for_inference_results(examples, serving_bundle):
"""Calls servo and wraps the inference results."""
(inference_result_proto, attributions) = run_inference(
(inference_result_proto, extra_results) = run_inference(
examples, serving_bundle)
inferences = wrap_inference_results(inference_result_proto)
infer_json = json_format.MessageToJson(
inferences, including_default_value_fields=True)
return json.loads(infer_json), attributions
return json.loads(infer_json), extra_results

def get_eligible_features(examples, num_mutants):
"""Returns a list of JSON objects for each feature in the examples.
Expand Down Expand Up @@ -740,8 +740,8 @@ def run_inference(examples, serving_bundle):

Returns:
A tuple with the first entry being the ClassificationResponse or
RegressionResponse proto and the second entry being a list of the
attributions for each example, or None if no attributions exist.
RegressionResponse proto and the second entry being a dictionary of extra
data for each example, such as attributions, or None if no data exists.
"""
batch_size = 64
if serving_bundle.estimator and serving_bundle.feature_spec:
Expand All @@ -767,14 +767,16 @@ def run_inference(examples, serving_bundle):
# If custom_predict_fn is provided, pass examples directly for local
# inference.
values = serving_bundle.custom_predict_fn(examples)
attributions = None
extra_results = None
# If the custom prediction function returned a dict, then parse out the
# prediction scores and the attributions. If it is just a list, then the
# results are the prediction results without attributions.
# prediction scores. If it is just a list, then the results are the
# prediction results without attributions or other data.
if isinstance(values, dict):
attributions = values['attributions']
values = values['predictions']
return (common_utils.convert_prediction_values(values, serving_bundle),
attributions)
preds = values.pop('predictions')
extra_results = values
else:
preds = values
return (common_utils.convert_prediction_values(preds, serving_bundle),
extra_results)
else:
return (platform_utils.call_servo(examples, serving_bundle), None)
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def infer_impl(self):
examples_to_infer = [
self.json_to_proto(self.examples[index]) for index in indices_to_infer]
infer_objs = []
attribution_objs = []
extra_output_objs = []
serving_bundle = inference_utils.ServingBundle(
self.config.get('inference_address'),
self.config.get('model_name'),
Expand All @@ -137,11 +137,11 @@ def infer_impl(self):
self.estimator_and_spec.get('estimator'),
self.estimator_and_spec.get('feature_spec'),
self.custom_predict_fn)
(predictions, attributions) = (
(predictions, extra_output) = (
inference_utils.run_inference_for_inference_results(
examples_to_infer, serving_bundle))
infer_objs.append(predictions)
attribution_objs.append(attributions)
extra_output_objs.append(extra_output)
if ('inference_address_2' in self.config or
self.compare_estimator_and_spec.get('estimator') or
self.compare_custom_predict_fn):
Expand All @@ -157,16 +157,16 @@ def infer_impl(self):
self.compare_estimator_and_spec.get('estimator'),
self.compare_estimator_and_spec.get('feature_spec'),
self.compare_custom_predict_fn)
(predictions, attributions) = (
(predictions, extra_output) = (
inference_utils.run_inference_for_inference_results(
examples_to_infer, serving_bundle))
infer_objs.append(predictions)
attribution_objs.append(attributions)
extra_output_objs.append(extra_output)
self.updated_example_indices = set()
return {
'inferences': {'indices': indices_to_infer, 'results': infer_objs},
'label_vocab': self.config.get('label_vocab'),
'attributions': attribution_objs}
'extra_outputs': extra_output_objs}

def infer_mutants_impl(self, info):
"""Performs mutant inference on specified examples."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def infer_mutants(wit_id, details):
window.inferenceCallback = inferences => {{
wit.labelVocab = inferences.label_vocab;
wit.inferences = inferences.inferences;
wit.attributions = {{indices: wit.inferences.indices,
attributions: inferences.attributions}}
wit.extraOutputs = {{indices: wit.inferences.indices,
extra: inferences.extra_outputs}}
}};
window.spriteCallback = spriteUrl => {{
if (!wit.updateSprite) {{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,9 @@ var WITView = widgets.DOMWidgetView.extend({
const inferences = this.model.get('inferences');
this.view_.labelVocab = inferences['label_vocab'];
this.view_.inferences = inferences['inferences'];
this.view_.attributions = {
this.view_.extraOutputs = {
indices: this.view_.inferences.indices,
attributions: inferences['attributions'],
extra: inferences['extra_outputs'],
};
},
eligibleFeaturesChanged: function() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,11 @@ def set_custom_predict_fn(self, predict_fn):
- For regression: A 1D list of numbers, with a regression score for each
example being predicted.

Optionally, if attributions can be returned by the model with each
prediction, then this method can return a dict with the key 'predictions'
containing the predictions result list described above, and with the key
'attributions' containing a list of attributions for each example that was
predicted.
Optionally, if attributions or other prediction-time information
can be returned by the model with each prediction, then this method
can return a dict with the key 'predictions' containing the predictions
result list described above, and with the key 'attributions' containing
a list of attributions for each example that was predicted.

For each example, the attributions list should contain a dict mapping
input feature names to attribution values for that feature on that example.
Expand All @@ -432,6 +432,12 @@ def set_custom_predict_fn(self, predict_fn):
a list of attribution values for the corresponding feature values in
the first list.

This dict can contain any other keys, with their values being a list of
prediction-time strings or numbers for each example being predicted. These
values will be displayed in WIT as extra information for each example,
usable in the same ways by WIT as normal input features (such as for
creating plots and slicing performance data).

Args:
predict_fn: The custom python function which will be used for model
inference.
Expand Down Expand Up @@ -464,11 +470,11 @@ def set_compare_custom_predict_fn(self, predict_fn):
- For regression: A 1D list of numbers, with a regression score for each
example being predicted.

Optionally, if attributions can be returned by the model with each
prediction, then this method can return a dict with the key 'predictions'
containing the predictions result list described above, and with the key
'attributions' containing a list of attributions for each example that was
predicted.
Optionally, if attributions or other prediction-time information
can be returned by the model with each prediction, then this method
can return a dict with the key 'predictions' containing the predictions
result list described above, and with the key 'attributions' containing
a list of attributions for each example that was predicted.

For each example, the attributions list should contain a dict mapping
input feature names to attribution values for that feature on that example.
Expand All @@ -482,6 +488,12 @@ def set_compare_custom_predict_fn(self, predict_fn):
a list of attribution values for the corresponding feature values in
the first list.

This dict can contain any other keys, with their values being a list of
prediction-time strings or numbers for each example being predicted. These
values will be displayed in WIT as extra information for each example,
usable in the same ways by WIT as normal input features (such as for
creating plots and slicing performance data).

Args:
predict_fn: The custom python function which will be used for model
inference.
Expand Down