Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add inference error messages in WIT notebook mode #2414

Merged
merged 5 commits into from
Jul 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ tf_web_library(
"@org_polymer_paper_spinner",
"@org_polymer_paper_styles",
"@org_polymer_paper_tabs",
"@org_polymer_paper_toast",
"@org_polymer_paper_toggle_button",
"//tensorboard/components/tf_backend",
"//tensorboard/components/tf_dashboard_common",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<link rel="import" href="../paper-spinner/paper-spinner-lite.html">
<link rel="import" href="../paper-tabs/paper-tab.html">
<link rel="import" href="../paper-tabs/paper-tabs.html">
<link rel="import" href="../paper-toast/paper-toast.html">
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-backend/tf-backend.html">
<link rel="import" href="../tf-dashboard-common/dashboard-style.html">
Expand Down Expand Up @@ -1501,7 +1502,7 @@ <h2>Show similarity to selected datapoint</h2>
<div id="noexamples" class="noexamples info-text">
Datapoints and their inference results will be displayed here.
</div>
<paper-spinner-lite id="spinner" hidden active></paper-spinner-lite>
<paper-spinner-lite id="spinner" hidden="[[spinnerHidden_]]" active></paper-spinner-lite>
<div class="feature-container-holder" id="partialplotholder">
<div class="pd-plots-header">
<div class="flex">
Expand Down Expand Up @@ -2651,7 +2652,12 @@ <h2>Show similarity to selected datapoint</h2>
allConfMatrixLabels: {
type: Array,
value: () => ([]),
}
},
// Controls if the loading spinner is hidden from view.
spinnerHidden_: {
type: Boolean,
value: true,
},
},

observers: [
Expand Down Expand Up @@ -4307,7 +4313,7 @@ <h2>Show similarity to selected datapoint</h2>
},

newInferences_: function() {
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.updateInferences_(true);
requestAnimationFrame(() => this.updateInferenceStats_(true));
},
Expand Down Expand Up @@ -4438,6 +4444,9 @@ <h2>Show similarity to selected datapoint</h2>
: this.strWithModelName_(inferenceValueStr, 0);
if (this.isRegression_(this.modelType)) {
this.$.dive.horizontalPosition = this.strWithModelName_(inferenceValueStr, 0);
if (this.numModels > 1) {
this.$.dive.verticalPosition = this.strWithModelName_(inferenceValueStr, 1);
}
}
else if (this.isBinaryClassification_(this.modelType, this.multiClass)) {
if (this.numModels == 1) {
Expand All @@ -4461,7 +4470,8 @@ <h2>Show similarity to selected datapoint</h2>
// TODO(jwexler): Support attributions from multiple models.
// For now, we only display attribution from the first model, if WIT
// is loaded with two models.
const attribs = attributions.attributions[0][i];
const attribs = attributions.attributions[0] == null ? {} :
attributions.attributions[0][i];
const keys = Object.keys(attribs);
for (let j = 0; j < keys.length; j++) {
// If the attributions for a key is a 2D array then treat the first
Expand Down Expand Up @@ -4583,16 +4593,16 @@ <h2>Show similarity to selected datapoint</h2>
'use_predict': this.usePredictApi,
'predict_output_tensor': this.predictOutputTensor,
'predict_input_tensor': this.predictInputTensor};
this.$.spinner.hidden = false;
this.spinnerHidden_ = false;
if (!this.local) {
const url = this.makeUrl_('/data/plugin/whatif/infer',
inferParams);
const inferContents = result => {
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.labelVocab = /** @type {!Array} */ (JSON.parse(result.value.vocab));
this.inferences = /** @type {!Object} */ (JSON.parse(result.value.inferences));
};
this.makeAsyncRequest_(url, inferContents, null);
this.makeAsyncRequest_(url, inferContents, null, 'model inference');
}
this.fire('infer-examples', inferParams);
},
Expand Down Expand Up @@ -4676,7 +4686,8 @@ <h2>Show similarity to selected datapoint</h2>
var url = this.makeUrl_('/data/plugin/whatif/update_example', null);

this.makeAsyncRequest_(url, null, {'example': exampleJson,
'index': index});
'index': index},
'datapoint update');
}
},

Expand Down Expand Up @@ -4704,16 +4715,20 @@ <h2>Show similarity to selected datapoint</h2>
console.error(msg);
},

handleError: function(errorStr) {
this.showToast_(errorStr);
this.exampleStatusStr = errorStr;
this.spinnerHidden_ = true;
},

makeAsyncRequest_: function(
url, thenDoFn, postData, errorFn = () => {}) {
url, thenDoFn, postData, readableRequestName, errorFn = () => {}) {
const wrapperFn = this._canceller.cancellable(result => {
if (result.cancelled) {
return;
}
if (result.value && result.value.error){
// show toast with the error
this.showToast_(result.value.error);
this.$.spinner.hidden = true;
this.handleError(result.value.error);
if (errorFn != null) {
errorFn();
}
Expand All @@ -4723,9 +4738,8 @@ <h2>Show similarity to selected datapoint</h2>
});
this._requestManager.request(url, postData).then(wrapperFn)
.catch(reason => {
this.exampleStatusStr = 'Request failed';
this.showToast_('Request failed: ' + reason);
this.$.spinner.hidden = true;
this.handleError(
`Request for ${readableRequestName} failed: ${reason}`);
if (errorFn != null) {
errorFn();
}
Expand Down Expand Up @@ -4835,7 +4849,7 @@ <h2>Show similarity to selected datapoint</h2>
updateExampleContents: function(examples, hasSprite) {
this.exampleStatusStr = examples.length + ' datapoints loaded';
this.$.noexamples.style.display = 'none';
this.$.spinner.hidden = true;
this.spinnerHidden_ = true;
this.examplesAndInferences = examples.map(function(ex) {
const example = JSON.parse(ex);
return {example: example, changed: false, orig: JSON.parse(ex)};});
Expand Down Expand Up @@ -4908,8 +4922,8 @@ <h2>Show similarity to selected datapoint</h2>
result.value.examples, result.value.sprite);
};
this.exampleStatusStr = 'Loading datapoints...'
this.makeAsyncRequest_(url, updateExampleContents, null);
this.$.spinner.hidden = false;
this.makeAsyncRequest_(url, updateExampleContents, null, 'datapoint load');
this.spinnerHidden_ = false;
},

updateSprite: function() {
Expand Down Expand Up @@ -4963,7 +4977,7 @@ <h2>Show similarity to selected datapoint</h2>
};
const url = this.makeUrl_('/data/plugin/whatif/duplicate_example',
{'index': duplicatedIndex});
this.makeAsyncRequest_(url, refreshDiveAfterDuplicate, null);
this.makeAsyncRequest_(url, refreshDiveAfterDuplicate, null, 'datapoint duplication');
} else {
this.refreshDive_();
}
Expand Down Expand Up @@ -4998,7 +5012,7 @@ <h2>Show similarity to selected datapoint</h2>
};
const url = this.makeUrl_('/data/plugin/whatif/delete_example',
{'index': deletedIndex});
this.makeAsyncRequest_(url, refreshDiveAfterDelete, null);
this.makeAsyncRequest_(url, refreshDiveAfterDelete, null, 'datapoint delete');
} else {
this.refreshDive_();
}
Expand Down Expand Up @@ -5400,16 +5414,6 @@ <h2>Show similarity to selected datapoint</h2>
return {step: origValue, scalar: origInferenceScore};
},

showToast: function(msg) {
const toast = document.createElement('paper-toast');
document.body.appendChild(toast);
toast.text = msg;
toast.show();

// Also, log to console.
console.error(msg);
},

deletePdPlotSpinner: function(featureName) {
const container = this.featureContainerByName(featureName);
deleteElement(container.querySelector('paper-spinner-lite'));
Expand Down Expand Up @@ -5466,7 +5470,7 @@ <h2>Show similarity to selected datapoint</h2>
}
this.makeAsyncRequest_(
url, chartMakerCallback.bind(this), null,
chartErrorCallback.bind(this));
'plot creation', chartErrorCallback.bind(this));
} else {
this.fire('infer-mutants', urlParams);
}
Expand Down Expand Up @@ -5656,7 +5660,7 @@ <h2>Show similarity to selected datapoint</h2>
const setEligibleFields = result => {
this.set('partialDepPlotEligibleFeatures', result.value);
};
this.makeAsyncRequest_(url, setEligibleFields, null);
this.makeAsyncRequest_(url, setEligibleFields, null, 'plot setup');
} else {
this.fire('get-eligible-features');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,14 @@ def infer_mutants(wit_id, details):

// Javascript callbacks called by python code to communicate with WIT
// Polymer element.
window.backendError = error => {{
wit.handleError(error.msg);
}};
window.inferenceCallback = inferences => {{
const parsedInferences = JSON.parse(inferences);
wit.labelVocab = parsedInferences.label_vocab;
wit.inferences = parsedInferences.inferences;
wit.labelVocab = inferences.label_vocab;
wit.inferences = inferences.inferences;
wit.attributions = {{indices: wit.inferences.indices,
attributions: parsedInferences.attributions}}
attributions: inferences.attributions}}
}};
window.spriteCallback = spriteUrl => {{
if (!wit.updateSprite) {{
Expand All @@ -107,20 +109,17 @@ def infer_mutants(wit_id, details):
wit.updateSprite();
}};
window.eligibleFeaturesCallback = features => {{
const parsedFeatures = JSON.parse(features);
wit.partialDepPlotEligibleFeatures = parsedFeatures;
wit.partialDepPlotEligibleFeatures = features;
}};
window.inferMutantsCallback = jsonMapping => {{
const chartInfo = JSON.parse(jsonMapping);
window.inferMutantsCallback = chartInfo => {{
wit.makeChartForFeature(chartInfo.chartType, mutantFeature,
chartInfo.data);
}};
window.configCallback = jsonConfig => {{
window.configCallback = config => {{
if (!wit.updateNumberOfModels) {{
requestAnimationFrame(() => window.configCallback(jsonConfig));
requestAnimationFrame(() => window.configCallback(config));
return;
}}
const config = JSON.parse(jsonConfig);
if ('inference_address' in config) {{
let addresses = config['inference_address'];
if ('inference_address_2' in config) {{
Expand Down Expand Up @@ -201,16 +200,16 @@ def __init__(self, config_builder, height=1000):
# Display WIT Polymer element.
display.display(display.HTML(self._get_element_html()))
display.display(display.HTML(
WIT_HTML.format(height=height, id=self.id)))
WIT_HTML.format(height=height, id=self.id)))

# Increment the static instance WitWidget index counter
WitWidget.index += 1

# Send the provided config and examples to JS
output.eval_js("""configCallback('{config}')""".format(
config=json.dumps(self.config)))
output.eval_js("""configCallback({config})""".format(
config=json.dumps(self.config)))
output.eval_js("""updateExamplesCallback({examples})""".format(
examples=json.dumps(self.examples)))
examples=json.dumps(self.examples)))
self._generate_sprite()
self._ctor_complete = True

Expand All @@ -228,19 +227,23 @@ def set_examples(self, examples):
# cell from the cell that displays WIT.
channel_name = 'updateExamples{}'.format(self.id)
output.eval_js("""(new BroadcastChannel('{channel_name}')).postMessage(
{examples})""".format(
examples=json.dumps(self.examples), channel_name=channel_name))
{examples})""".format(
examples=json.dumps(self.examples), channel_name=channel_name))
self._generate_sprite()

def infer(self):
inferences = base.WitWidgetBase.infer_impl(self)
output.eval_js("""inferenceCallback('{inferences}')""".format(
inferences=json.dumps(inferences)))
try:
inferences = base.WitWidgetBase.infer_impl(self)
output.eval_js("""inferenceCallback({inferences})""".format(
inferences=json.dumps(inferences)))
except Exception as e:
output.eval_js("""backendError({error})""".format(
error=json.dumps({'msg': str(e)})))

def delete_example(self, index):
self.examples.pop(index)
self.updated_example_indices = set([
i if i < index else i - 1 for i in self.updated_example_indices])
i if i < index else i - 1 for i in self.updated_example_indices])
self._generate_sprite()

def update_example(self, index, example):
Expand All @@ -255,13 +258,17 @@ def duplicate_example(self, index):

def get_eligible_features(self):
features_list = base.WitWidgetBase.get_eligible_features_impl(self)
output.eval_js("""eligibleFeaturesCallback('{features_list}')""".format(
features_list=json.dumps(features_list)))
output.eval_js("""eligibleFeaturesCallback({features_list})""".format(
features_list=json.dumps(features_list)))

def infer_mutants(self, info):
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
output.eval_js("""inferMutantsCallback('{json_mapping}')""".format(
json_mapping=json.dumps(json_mapping)))
try:
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
output.eval_js("""inferMutantsCallback({json_mapping})""".format(
json_mapping=json.dumps(json_mapping)))
except Exception as e:
output.eval_js("""backendError({error})""".format(
error=json.dumps({'msg': str(e)})))

def _generate_sprite(self):
sprite = base.WitWidgetBase.create_sprite(self)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var WITView = widgets.DOMWidgetView.extend({
this.eligibleFeaturesChanged, this);
this.model.on('change:mutant_charts', this.mutantChartsChanged, this);
this.model.on('change:sprite', this.spriteChanged, this);
this.model.on('change:error', this.backendError, this);
},

/**
Expand Down Expand Up @@ -229,6 +230,10 @@ var WITView = widgets.DOMWidgetView.extend({
this.view_.localAtlasUrl = spriteUrl;
this.view_.updateSprite();
},
backendError: function() {
const error = this.model.get('error');
this.view_.handleError(error['msg']);
},
});

module.exports = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class WitWidget(widgets.DOMWidget, base.WitWidgetBase):
mutant_charts = Dict([]).tag(sync=True)
mutant_charts_counter = Int(0)
sprite = Unicode('').tag(sync=True)
error = Dict(dict()).tag(sync=True)

def __init__(self, config_builder, height=1000):
"""Constructor for Jupyter notebook WitWidget.
Expand All @@ -57,6 +58,7 @@ def __init__(self, config_builder, height=1000):
"""
widgets.DOMWidget.__init__(self, layout=Layout(height='%ipx' % height))
base.WitWidgetBase.__init__(self, config_builder)
self.error_counter = 0

# Ensure the visualization takes all available width.
display(HTML("<style>.container { width:100% !important; }</style>"))
Expand All @@ -65,9 +67,19 @@ def set_examples(self, examples):
base.WitWidgetBase.set_examples(self, examples)
self._generate_sprite()

def _report_error(self, err):
self.error = {
'msg': str(err),
'counter': self.error_counter
}
self.error_counter += 1

@observe('infer')
def _infer(self, change):
self.inferences = base.WitWidgetBase.infer_impl(self)
try:
self.inferences = base.WitWidgetBase.infer_impl(self)
except Exception as e:
self._report_error(e)

# Observer callbacks for changes from javascript.
@observe('get_eligible_features')
Expand All @@ -78,10 +90,13 @@ def _get_eligible_features(self, change):
@observe('infer_mutants')
def _infer_mutants(self, change):
info = self.infer_mutants
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
json_mapping['counter'] = self.mutant_charts_counter
self.mutant_charts_counter += 1
self.mutant_charts = json_mapping
try:
json_mapping = base.WitWidgetBase.infer_mutants_impl(self, info)
json_mapping['counter'] = self.mutant_charts_counter
self.mutant_charts_counter += 1
self.mutant_charts = json_mapping
except Exception as e:
self._report_error(e)

@observe('update_example')
def _update_example(self, change):
Expand Down