Skip to content

Commit

Permalink
Add inference error messages in WIT notebook mode (#2414)
Browse files Browse the repository at this point in the history
If a model fails to run inference in WitWidget, no error is currently displayed so the user has no idea why WIT isn't working correctly. This change adds in proper error handling.
  • Loading branch information
jameswex authored Jul 15, 2019
1 parent ee24f0a commit 839ceac
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tf_web_library(
"@org_polymer_paper_spinner",
"@org_polymer_paper_styles",
"@org_polymer_paper_tabs",
"@org_polymer_paper_toast",
"@org_polymer_paper_toggle_button",
"//tensorboard/components/tf_backend",
"//tensorboard/components/tf_dashboard_common",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<link rel="import" href="../paper-spinner/paper-spinner-lite.html">
<link rel="import" href="../paper-tabs/paper-tab.html">
<link rel="import" href="../paper-tabs/paper-tabs.html">
<link rel="import" href="../paper-toast/paper-toast.html">
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-backend/tf-backend.html">
<link rel="import" href="../tf-dashboard-common/dashboard-style.html">
Expand Down Expand Up @@ -1501,7 +1502,7 @@ <h2>Show similarity to selected datapoint</h2>
<div id="noexamples" class="noexamples info-text">
Datapoints and their inference results will be displayed here.
</div>
<paper-spinner-lite id="spinner" hidden active></paper-spinner-lite>
<paper-spinner-lite id="spinner" hidden="[[spinnerHidden_]]" active></paper-spinner-lite>
<div class="feature-container-holder" id="partialplotholder">
<div class="pd-plots-header">
<div class="flex">
Expand Down Expand Up @@ -2651,7 +2652,12 @@ <h2>Show similarity to selected datapoint</h2>
allConfMatrixLabels: {
type: Array,
value: () => ([]),
}
},
// Controls if the loading spinner is hidden from view.
spinnerHidden_: {
type: Boolean,
value: true,
},
},

observers: [
Expand Down Expand Up @@ -4307,7 +4313,7 @@ <h2>Show similarity to selected datapoint</h2>
},

newInferences_: function() {
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.updateInferences_(true);
requestAnimationFrame(() => this.updateInferenceStats_(true));
},
Expand Down Expand Up @@ -4438,6 +4444,9 @@ <h2>Show similarity to selected datapoint</h2>
: this.strWithModelName_(inferenceValueStr, 0);
if (this.isRegression_(this.modelType)) {
this.$.dive.horizontalPosition = this.strWithModelName_(inferenceValueStr, 0);
if (this.numModels > 1) {
this.$.dive.verticalPosition = this.strWithModelName_(inferenceValueStr, 1);
}
}
else if (this.isBinaryClassification_(this.modelType, this.multiClass)) {
if (this.numModels == 1) {
Expand All @@ -4461,7 +4470,8 @@ <h2>Show similarity to selected datapoint</h2>
// TODO(jwexler): Support attributions from multiple models.
// For now, we only display attribution from the first model, if WIT
// is loaded with two models.
const attribs = attributions.attributions[0][i];
const attribs = attributions.attributions[0] == null ? {} :
attributions.attributions[0][i];
const keys = Object.keys(attribs);
for (let j = 0; j < keys.length; j++) {
// If the attributions for a key is a 2D array then treat the first
Expand Down Expand Up @@ -4583,16 +4593,16 @@ <h2>Show similarity to selected datapoint</h2>
'use_predict': this.usePredictApi,
'predict_output_tensor': this.predictOutputTensor,
'predict_input_tensor': this.predictInputTensor};
this.$.spinner.hidden = false;
this.spinnerHidden_ = false;
if (!this.local) {
const url = this.makeUrl_('/data/plugin/whatif/infer',
inferParams);
const inferContents = result => {
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.labelVocab = /** @type {!Array} */ (JSON.parse(result.value.vocab));
this.inferences = /** @type {!Object} */ (JSON.parse(result.value.inferences));
};
this.makeAsyncRequest_(url, inferContents, null);
this.makeAsyncRequest_(url, inferContents, null, 'model inference');
}
this.fire('infer-examples', inferParams);
},
Expand Down Expand Up @@ -4676,7 +4686,8 @@ <h2>Show similarity to selected datapoint</h2>
var url = this.makeUrl_('/data/plugin/whatif/update_example', null);

this.makeAsyncRequest_(url, null, {'example': exampleJson,
'index': index});
'index': index},
'datapoint update');
}
},

Expand Down Expand Up @@ -4704,16 +4715,20 @@ <h2>Show similarity to selected datapoint</h2>
console.error(msg);
},

handleError: function(errorStr) {
this.showToast_(errorStr);
this.exampleStatusStr = errorStr;
this.spinnerHidden_ = true;
},

makeAsyncRequest_: function(
url, thenDoFn, postData, errorFn = () => {}) {
url, thenDoFn, postData, readableRequestName, errorFn = () => {}) {
const wrapperFn = this._canceller.cancellable(result => {
if (result.cancelled) {
return;
}
if (result.value && result.value.error){
// show toast with the error
this.showToast_(result.value.error);
this.$.spinner.hidden = true;
this.handleError(result.value.error);
if (errorFn != null) {
errorFn();
}
Expand All @@ -4723,9 +4738,8 @@ <h2>Show similarity to selected datapoint</h2>
});
this._requestManager.request(url, postData).then(wrapperFn)
.catch(reason => {
this.exampleStatusStr = 'Request failed';
this.showToast_('Request failed: ' + reason);
this.$.spinner.hidden = true;
this.handleError(
`Request for ${readableRequestName} failed: ${reason}`);
if (errorFn != null) {
errorFn();
}
Expand Down Expand Up @@ -4835,7 +4849,7 @@ <h2>Show similarity to selected datapoint</h2>
updateExampleContents: function(examples, hasSprite) {
this.exampleStatusStr = examples.length + ' datapoints loaded';
this.$.noexamples.style.display = 'none';
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.examplesAndInferences = examples.map(function(ex) {
const example = JSON.parse(ex);
return {example: example, changed: false, orig: JSON.parse(ex)};});
Expand Down Expand Up @@ -4908,8 +4922,8 @@ <h2>Show similarity to selected datapoint</h2>
result.value.examples, result.value.sprite);
};
this.exampleStatusStr = 'Loading datapoints...'
this.makeAsyncRequest_(url, updateExampleContents, null);
this.$.spinner.hidden = false;
this.makeAsyncRequest_(url, updateExampleContents, null, 'datapoint load');
this.spinnerHidden_ = false;
},

updateSprite: function() {
Expand Down Expand Up @@ -4963,7 +4977,7 @@ <h2>Show similarity to selected datapoint</h2>
};
const url = this.makeUrl_('/data/plugin/whatif/duplicate_example',
{'index': duplicatedIndex});
this.makeAsyncRequest_(url, refreshDiveAfterDuplicate, null);
this.makeAsyncRequest_(url, refreshDiveAfterDuplicate, null, 'datapoint duplication');
} else {
this.refreshDive_();
}
Expand Down Expand Up @@ -4998,7 +5012,7 @@ <h2>Show similarity to selected datapoint</h2>
};
const url = this.makeUrl_('/data/plugin/whatif/delete_example',
{'index': deletedIndex});
this.makeAsyncRequest_(url, refreshDiveAfterDelete, null);
this.makeAsyncRequest_(url, refreshDiveAfterDelete, null, 'datapoint delete');
} else {
this.refreshDive_();
}
Expand Down Expand Up @@ -5400,16 +5414,6 @@ <h2>Show similarity to selected datapoint</h2>
return {step: origValue, scalar: origInferenceScore};
},

showToast: function(msg) {
const toast = document.createElement('paper-toast');
document.body.appendChild(toast);
toast.text = msg;
toast.show();

// Also, log to console.
console.error(msg);
},

deletePdPlotSpinner: function(featureName) {
const container = this.featureContainerByName(featureName);
deleteElement(container.querySelector('paper-spinner-lite'));
Expand Down Expand Up @@ -5466,7 +5470,7 @@ <h2>Show similarity to selected datapoint</h2>
}
this.makeAsyncRequest_(
url, chartMakerCallback.bind(this), null,
chartErrorCallback.bind(this));
'plot creation', chartErrorCallback.bind(this));
} else {
this.fire('infer-mutants', urlParams);
}
Expand Down Expand Up @@ -5656,7 +5660,7 @@ <h2>Show similarity to selected datapoint</h2>
const setEligibleFields = result => {
this.set('partialDepPlotEligibleFeatures', result.value);
};
this.makeAsyncRequest_(url, setEligibleFields, null);
this.makeAsyncRequest_(url, setEligibleFields, null, 'plot setup');
} else {
this.fire('get-eligible-features');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ def infer_mutants(wit_id, details):
// Javascript callbacks called by python code to communicate with WIT
// Polymer element.
window.backendError = error => {{
wit.handleError(error.msg);
}};
window.inferenceCallback = inferences => {{
const parsedInferences = JSON.parse(inferences);
wit.labelVocab = parsedInferences.label_vocab;
wit.inferences = parsedInferences.inferences;
wit.labelVocab = inferences.label_vocab;
wit.inferences = inferences.inferences;
wit.attributions = {{indices: wit.inferences.indices,
attributions: parsedInferences.attributions}}
attributions: inferences.attributions}}
}};
window.spriteCallback = spriteUrl => {{
if (!wit.updateSprite) {{
Expand All @@ -107,20 +109,17 @@ def infer_mutants(wit_id, details):
wit.updateSprite();
}};
window.eligibleFeaturesCallback = features => {{
const parsedFeatures = JSON.parse(features);
wit.partialDepPlotEligibleFeatures = parsedFeatures;
wit.partialDepPlotEligibleFeatures = features;
}};
window.inferMutantsCallback = jsonMapping => {{
const chartInfo = JSON.parse(jsonMapping);
window.inferMutantsCallback = chartInfo => {{
wit.makeChartForFeature(chartInfo.chartType, mutantFeature,
chartInfo.data);
}};
window.configCallback = jsonConfig => {{
window.configCallback = config => {{
if (!wit.updateNumberOfModels) {{
requestAnimationFrame(() => window.configCallback(jsonConfig));
requestAnimationFrame(() => window.configCallback(config));
return;
}}
const config = JSON.parse(jsonConfig);
if ('inference_address' in config) {{
let addresses = config['inference_address'];
if ('inference_address_2' in config) {{
Expand Down Expand Up @@ -201,16 +200,16 @@ def __init__(self, config_builder, height=1000):
# Display WIT Polymer element.
display.display(display.HTML(self._get_element_html()))
display.display(display.HTML(
WIT_HTML.format(height=height, id=self.id)))
WIT_HTML.format(height=height, id=self.id)))

# Increment the static instance WitWidget index counter
WitWidget.index += 1

# Send the provided config and examples to JS
output.eval_js("""configCallback('{config}')""".format(
config=json.dumps(self.config)))
output.eval_js("""configCallback({config})""".format(
config=json.dumps(self.config)))
output.eval_js("""updateExamplesCallback({examples})""".format(
examples=json.dumps(self.examples)))
examples=json.dumps(self.examples)))
self._generate_sprite()
self._ctor_complete = True

Expand All @@ -228,19 +227,23 @@ def set_examples(self, examples):
# cell from the cell that displays WIT.
channel_name = 'updateExamples{}'.format(self.id)
output.eval_js("""(new BroadcastChannel('{channel_name}')).postMessage(
{examples})""".format(
examples=json.dumps(self.examples), channel_name=channel_name))
{examples})""".format(
examples=json.dumps(self.examples), channel_name=channel_name))
self._generate_sprite()

def infer(self):
inferences = base.WitWidgetBase.infer_impl(self)
output.eval_js("""inferenceCallback('{inferences}')""".format(
inferences=json.dumps(inferences)))
try:
inferences = base.WitWidgetBase.infer_impl(self)
output.eval_js("""inferenceCallback({inferences})""".format(
inferences=json.dumps(inferences)))
except Exception as e:
output.eval_js("""backendError({error})""".format(
error=json.dumps({'msg': str(e)})))

def delete_example(self, index):
self.examples.pop(index)
self.updated_example_indices = set([
i if i < index else i - 1 for i in self.updated_example_indices])
i if i < index else i - 1 for i in self.updated_example_indices])
self._generate_sprite()

def update_example(self, index, example):
Expand All @@ -255,13 +258,17 @@ def duplicate_example(self, index):

def get_eligible_features(self):
features_list = base.WitWidgetBase.get_eligible_features_impl(self)
output.eval_js("""eligibleFeaturesCallback('{features_list}')""".format(
features_list=json.dumps(features_list)))
output.eval_js("""eligibleFeaturesCallback({features_list})""".format(
features_list=json.dumps(features_list)))

def infer_mutants(self, info):
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
output.eval_js("""inferMutantsCallback('{json_mapping}')""".format(
json_mapping=json.dumps(json_mapping)))
try:
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
output.eval_js("""inferMutantsCallback({json_mapping})""".format(
json_mapping=json.dumps(json_mapping)))
except Exception as e:
output.eval_js("""backendError({error})""".format(
error=json.dumps({'msg': str(e)})))

def _generate_sprite(self):
sprite = base.WitWidgetBase.create_sprite(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var WITView = widgets.DOMWidgetView.extend({
this.eligibleFeaturesChanged, this);
this.model.on('change:mutant_charts', this.mutantChartsChanged, this);
this.model.on('change:sprite', this.spriteChanged, this);
this.model.on('change:error', this.backendError, this);
},

/**
Expand Down Expand Up @@ -229,6 +230,10 @@ var WITView = widgets.DOMWidgetView.extend({
this.view_.localAtlasUrl = spriteUrl;
this.view_.updateSprite();
},
backendError: function() {
const error = this.model.get('error');
this.view_.handleError(error['msg']);
},
});

module.exports = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class WitWidget(widgets.DOMWidget, base.WitWidgetBase):
mutant_charts = Dict([]).tag(sync=True)
mutant_charts_counter = Int(0)
sprite = Unicode('').tag(sync=True)
error = Dict(dict()).tag(sync=True)

def __init__(self, config_builder, height=1000):
"""Constructor for Jupyter notebook WitWidget.
Expand All @@ -57,6 +58,7 @@ def __init__(self, config_builder, height=1000):
"""
widgets.DOMWidget.__init__(self, layout=Layout(height='%ipx' % height))
base.WitWidgetBase.__init__(self, config_builder)
self.error_counter = 0

# Ensure the visualization takes all available width.
display(HTML("<style>.container { width:100% !important; }</style>"))
Expand All @@ -65,9 +67,19 @@ def set_examples(self, examples):
base.WitWidgetBase.set_examples(self, examples)
self._generate_sprite()

def _report_error(self, err):
self.error = {
'msg': str(err),
'counter': self.error_counter
}
self.error_counter += 1

@observe('infer')
def _infer(self, change):
self.inferences = base.WitWidgetBase.infer_impl(self)
try:
self.inferences = base.WitWidgetBase.infer_impl(self)
except Exception as e:
self._report_error(e)

# Observer callbacks for changes from javascript.
@observe('get_eligible_features')
Expand All @@ -78,10 +90,13 @@ def _get_eligible_features(self, change):
@observe('infer_mutants')
def _infer_mutants(self, change):
info = self.infer_mutants
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
json_mapping['counter'] = self.mutant_charts_counter
self.mutant_charts_counter += 1
self.mutant_charts = json_mapping
try:
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
json_mapping['counter'] = self.mutant_charts_counter
self.mutant_charts_counter += 1
self.mutant_charts = json_mapping
except Exception as e:
self._report_error(e)

@observe('update_example')
def _update_example(self, change):
Expand Down

0 comments on commit 839ceac

Please sign in to comment.