Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set custom distance function for counterfactuals #2607

Merged
merged 44 commits into from
Sep 19, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
ae2daff
midway 1
tolga-b Aug 23, 2019
d6a6920
midway 2
tolga-b Aug 27, 2019
0359655
midway 3
tolga-b Aug 28, 2019
fc6709b
image colab works
tolga-b Aug 29, 2019
31e05ae
colab distance 1
tolga-b Aug 29, 2019
79dd89b
jupyter works
tolga-b Aug 29, 2019
00cade8
ran prettier
tolga-b Aug 29, 2019
ed3a4c6
review comments
tolga-b Sep 3, 2019
299dd47
minor fix
tolga-b Sep 3, 2019
76ddff8
fix lint
tolga-b Sep 3, 2019
0ee315a
fix comment
tolga-b Sep 4, 2019
a5126e8
Merge branch 'distance' of https://github.com/tolga-b/tensorboard int…
tolga-b Sep 4, 2019
9749fd6
fix dom if
tolga-b Sep 4, 2019
1343f30
image demo 1
tolga-b Sep 5, 2019
cc2ab68
image demo 2
tolga-b Sep 5, 2019
3fb28c1
image demo 3
tolga-b Sep 5, 2019
47b5960
image demo 3
tolga-b Sep 5, 2019
243e78e
undo image demo
tolga-b Sep 5, 2019
4eb6e09
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 5, 2019
9e98468
Merge branch 'distance' of https://github.com/tolga-b/tensorboard int…
tolga-b Sep 5, 2019
674afd4
add distance
tolga-b Sep 10, 2019
4faa9e6
smile demo works
tolga-b Sep 11, 2019
a024f6d
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 11, 2019
4a9e9e8
add distance embeddings
tolga-b Sep 11, 2019
16039ef
add delete and duplicate example for distance
tolga-b Sep 11, 2019
06a8451
add distance works
tolga-b Sep 11, 2019
eadecfa
changed distance pane
tolga-b Sep 12, 2019
3bccfc8
update readme
tolga-b Sep 12, 2019
d2414d9
update demos
tolga-b Sep 12, 2019
0c4a9c9
review comments fix image demo slice
tolga-b Sep 12, 2019
2469392
memory leak fix
tolga-b Sep 12, 2019
859cfa6
memory leak fix 2
tolga-b Sep 12, 2019
87eb749
ran prettier
tolga-b Sep 12, 2019
d38f7b5
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 12, 2019
50839a3
smile demo upload image fix
tolga-b Sep 12, 2019
7c82372
review comments
tolga-b Sep 17, 2019
9aa1a72
merge master
tolga-b Sep 17, 2019
8b02c04
review comments
tolga-b Sep 18, 2019
de0ca1b
prettier
tolga-b Sep 18, 2019
69f5d6d
merge master
tolga-b Sep 18, 2019
b87964c
fix bzl
tolga-b Sep 18, 2019
5becf6a
fix bzl 2
tolga-b Sep 18, 2019
a80af0f
fix bzl 3
tolga-b Sep 18, 2019
13a8444
review comments
tolga-b Sep 18, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1418,16 +1418,27 @@ <h2>Show similarity to selected datapoint</h2>
Show nearest counterfactual datapoint
</paper-toggle-button>
</div>
<paper-radio-group
selected="{{nearestCounterfactualDist}}"
<template
is="dom-if"
if="[[customDistanceFunctionSet]]"
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
>
<paper-radio-button name="L1"
>L1</paper-radio-button
<paper-radio-group
selected="{{nearestCounterfactualDist}}"
>
<paper-radio-button name="L2"
>L2</paper-radio-button
>
</paper-radio-group>
<paper-radio-button name="L1"
>L1</paper-radio-button
>
<paper-radio-button name="L2"
>L2</paper-radio-button
>
</paper-radio-group>
</template>
<template
is="dom-if"
if="[[!customDistanceFunctionSet]]"
>
Using custom distance function.
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
</template>
<paper-dropdown-menu
label="Model:"
no-label-float
Expand Down Expand Up @@ -3360,6 +3371,11 @@ <h2>Show similarity to selected datapoint</h2>
value: '',
observer: 'breakdownFeatureSelected_',
},
// True if an example has been updated.
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
customDistanceFunctionSet: {
type: Boolean,
value: false,
},
// Feature for true label.
selectedLabelFeature: {
type: String,
Expand Down Expand Up @@ -3784,8 +3800,52 @@ <h2>Show similarity to selected datapoint</h2>
}
},

computeClosestCounterfactual: function(exInd, distances) {
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
// Distances are indexed by example ids
const modelInferenceValueStr = this.strWithModelName_(
inferenceValueStr,
this.nearestCounterfactualModelIndex
);
let closestDist = Number.POSITIVE_INFINITY;
let closest = -1;
for (let i = 0; i < this.visdata.length; i++) {
// Skip examples with the same inference class as the selected
// examples.
if (
this.visdata[exInd][modelInferenceValueStr] ==
this.visdata[i][modelInferenceValueStr]
) {
continue;
}
let dist = distances[i];
if (dist < closestDist) {
closestDist = dist;
closest = i;
}
}
if (closest != -1) {
// Display the counterfactual in dive and example viewer.
this.comparedIndices = [closest];
this.counterfactualExampleAndInference = this.examplesAndInferences[
closest
];
this.compareTitle = 'Counterfactual value(s)';
}
},

findClosestCounterfactual_: function() {
const selected = this.selected[0];
// Custom distance function can only be used when local.
// If using custom distance function, request distances and return.
if (this.local && this.customDistanceFunctionSet) {
this.requestDistanceWithCallback(
selected,
'computeClosestCounterfactual',
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
{callbackParams: {}, distanceParams: {}}
);
return;
}

const modelInferenceValueStr = this.strWithModelName_(
inferenceValueStr,
this.nearestCounterfactualModelIndex
Expand Down Expand Up @@ -3821,6 +3881,17 @@ <h2>Show similarity to selected datapoint</h2>
}
},

// Call backend for distance computation, backend calls callback function
// with computed distances and parameters
requestDistanceWithCallback: function(exInd, callbackFun, params) {
const urlParams = {
index: exInd,
callback: callbackFun,
params: params,
};
this.fire('compute-custom-distance', urlParams);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it weird to see notebook specific code in the main dashboard code. What is the expected behavior outside of the notebook?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requestDistanceWithCallback would not be invoked in non-local instances (local=demos and notebook). In case we are in non-local mode, WIT defaults to it's previous behavior of computing counterfactuals completely on the js side with L1 and L2 distance between examples. This is slightly similar in terms of behavior to custom_predict_fn where it is only supported in notebook mode.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added TODO note here so that if there is a support in TensorBoard mode to provide custom distance functions for counterfactuals then we should update this function to reflect that.

},

/**
* Gets distance between two examples using L1 or L2 distance.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ def __init__(self, config_builder):
self.compare_custom_predict_fn = (
config.get('compare_custom_predict_fn')
if 'compare_custom_predict_fn' in config else None)
self.custom_distance_fn = (
config.get('custom_distance_fn')
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
if 'custom_distance_fn' in config else None)
self.adjust_prediction_fn = (
config.get('adjust_prediction')
if 'adjust_prediction' in config else None)
Expand All @@ -76,6 +79,9 @@ def __init__(self, config_builder):
del copied_config['custom_predict_fn']
if 'compare_custom_predict_fn' in copied_config:
del copied_config['compare_custom_predict_fn']
if 'custom_distance_fn' in copied_config:
del copied_config['custom_distance_fn']
copied_config['uses_custom_distance_fn'] = True
if 'adjust_prediction' in copied_config:
del copied_config['adjust_prediction']
if 'compare_adjust_prediction' in copied_config:
Expand Down Expand Up @@ -111,6 +117,12 @@ def set_examples(self, examples):
self.examples = [json_format.MessageToJson(ex) for ex in examples]
self.updated_example_indices = set(range(len(examples)))

def compute_custom_distance_impl(self, index, params=None):
exs_for_distance = [
self.json_to_proto(example) for example in self.examples]
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
selected_ex = exs_for_distance[index]
return self.custom_distance_fn(selected_ex, exs_for_distance, params)

def json_to_proto(self, json):
ex = (tf.train.SequenceExample()
if self.config.get('are_sequence_examples')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def infer_mutants(wit_id, details):
output.register_callback('notebook.InferMutants', infer_mutants)


def compute_custom_distance(wit_id, index, callback_name, params):
WitWidget.widgets[wit_id].compute_custom_distance(index, callback_name,
params)
output.register_callback('notebook.ComputeCustomDistance',
compute_custom_distance)


# HTML/javascript for the WIT frontend.
WIT_HTML = """
<tf-interactive-inference-dashboard id="wit" local>
Expand All @@ -66,6 +73,12 @@ def infer_mutants(wit_id, details):
google.colab.kernel.invokeFunction(
'notebook.InferExamples', [id], {{}});
}});
wit.addEventListener("compute-custom-distance", e => {{
google.colab.kernel.invokeFunction(
'notebook.ComputeCustomDistance',
[id, e.detail.index, e.detail.callback, e.detail.params],
{{}});
}});
wit.addEventListener("delete-example", e => {{
google.colab.kernel.invokeFunction(
'notebook.DeleteExample', [id, e.detail.index], {{}});
Expand All @@ -76,7 +89,9 @@ def infer_mutants(wit_id, details):
}});
wit.addEventListener("update-example", e => {{
google.colab.kernel.invokeFunction(
'notebook.UpdateExample', [id, e.detail.index, e.detail.example], {{}});
'notebook.UpdateExample',
[id, e.detail.index, e.detail.example],
{{}});
}});
wit.addEventListener('get-eligible-features', e => {{
google.colab.kernel.invokeFunction(
Expand All @@ -97,8 +112,15 @@ def infer_mutants(wit_id, details):
wit.labelVocab = inferences.label_vocab;
wit.inferences = inferences.inferences;
wit.attributions = {{indices: wit.inferences.indices,
attributions: inferences.attributions}}
attributions: inferences.attributions}};
}};

window.distanceCallback = callbackDict => {{
wit[callbackDict.callback_fn](callbackDict.exInd,
callbackDict.distances,
callbackDict.params);
}};

window.spriteCallback = spriteUrl => {{
if (!wit.updateSprite) {{
requestAnimationFrame(() => window.spriteCallback(spriteUrl));
Expand Down Expand Up @@ -150,6 +172,11 @@ def infer_mutants(wit_id, details):
if ('target_feature' in config) {{
wit.selectedLabelFeature = config['target_feature'];
}}
if ('uses_custom_distance_fn' in config) {{
wit.customDistanceFunctionSet = 1;
}} else {{
wit.customDistanceFunctionSet = 0;
}}
}};
window.updateExamplesCallback = examples => {{
if (!wit.updateExampleContents) {{
Expand Down Expand Up @@ -256,6 +283,22 @@ def duplicate_example(self, index):
self.updated_example_indices.add(len(self.examples) - 1)
self._generate_sprite()

def compute_custom_distance(self, index, callback_fn, params):
try:
distances = base.WitWidgetBase.compute_custom_distance_impl(
self, index, params['distanceParams'])
callback_dict = {
'distances': distances,
'exInd': index,
'callback_fn': callback_fn,
'params': params['callbackParams']
}
output.eval_js("""distanceCallback({callback_dict})""".format(
callback_dict=json.dumps(callback_dict)))
except Exception as e:
output.eval_js(
"""backendError({error})""".format(error=json.dumps({'msg': str(e)})))

def get_eligible_features(self):
features_list = base.WitWidgetBase.get_eligible_features_impl(self)
output.eval_js("""eligibleFeaturesCallback({features_list})""".format(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ var WITView = widgets.DOMWidgetView.extend({
this.model.on('change:mutant_charts', this.mutantChartsChanged, this);
this.model.on('change:sprite', this.spriteChanged, this);
this.model.on('change:error', this.backendError, this);
this.model.on(
'change:custom_distance_dict',
this.customDistanceComputed,
this
);
},

/**
Expand Down Expand Up @@ -118,14 +123,19 @@ var WITView = widgets.DOMWidgetView.extend({
this.model.set('get_eligible_features', i);
this.touch();
});

this.inferMutantsCounter = 0;
this.view_.addEventListener('infer-mutants', (e) => {
e.detail['infer_mutants_counter'] = this.inferMutantsCounter++;
this.model.set('infer_mutants', e.detail);
this.mutantFeature = e.detail.feature_name;
this.touch();
});
this.computeDistanceCounter = 0;
this.view_.addEventListener('compute-custom-distance', (e) => {
e.detail['compute_distance_counter'] = this.computeDistanceCounter++;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels quite odd and not clean. Would it be possible to use data binding instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The counter is there as it is in the case of infer-mutants above. We use data binding for what is passed ('compute_custom_distance' is set to e.detail). However, we only pass the example id back to the python side instead of the full example since examples are already in sync between the backend and frontend. In this case, if someone requests the same example again after updating some features, the change on bound dict may not trigger since it is the same dict. We force trigger it by incrementing the counter.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry but I must confess I am very lost. Even if this logic is correct, I do think this has readability problem.

AFAICT, the event is fired from JavaScript side. When firing the event, the object does not contain the property compute_distance_counter. 1. how is it being used by owner of the event object? 2. where does the Python come into play?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dashboard launches this event from js side by passing some parameters (example index, callback ind etc.). "compute_custom_distance" is sync variable between js and python sides. This synchronization is managed by jupyter widget framework. When we call this.model.set('compute_custom_distance', e.detail), ideally jupyter should recognize the change in "compute_custom_distance" and trigger distance computation. James and I found out previously (for infer-mutants) that if someone edits an example then fires this event again from the dashboard, since e.detail does not contain the contents of the example but only it's index, jupyter thinks nothing has changed and model.set does not trigger distance computation. James solved it by adding a counter so that we make sure the serialized dictionary always has one field that changes between calls and jupyter triggers the sync.
Since this is a jupyter specific issue, we did not include this in the event but append it in wit.js related to jupyter.

this.model.set('compute_custom_distance', e.detail);
this.touch();
});
this.setupComplete = true;
},

Expand Down Expand Up @@ -228,6 +238,11 @@ var WITView = widgets.DOMWidgetView.extend({
if ('target_feature' in config) {
this.view_.selectedLabelFeature = config['target_feature'];
}
if ('uses_custom_distance_fn' in config) {
this.view_.customDistanceFunctionSet = 1;
} else {
this.view_.customDistanceFunctionSet = 0;
}
},
spriteChanged: function() {
if (!this.setupComplete) {
Expand All @@ -246,6 +261,21 @@ var WITView = widgets.DOMWidgetView.extend({
const error = this.model.get('error');
this.view_.handleError(error['msg']);
},
customDistanceComputed: function() {
if (!this.setupComplete) {
if (this.isViewReady()) {
this.setupView();
}
requestAnimationFrame(() => this.customDistanceComputed());
return;
}
const custom_distance_dict = this.model.get('custom_distance_dict');
this.view_[custom_distance_dict.callback_fn](
custom_distance_dict.exInd,
custom_distance_dict.distances,
custom_distance_dict.params
);
},
});

module.exports = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class WitWidget(widgets.DOMWidget, base.WitWidgetBase):
mutant_charts_counter = Int(0)
sprite = Unicode('').tag(sync=True)
error = Dict(dict()).tag(sync=True)
compute_custom_distance = Dict(dict()).tag(sync=True)
custom_distance_dict = Dict(dict()).tag(sync=True)

def __init__(self, config_builder, height=1000):
"""Constructor for Jupyter notebook WitWidget.
Expand Down Expand Up @@ -119,6 +121,22 @@ def _delete_example(self, change):
i if i < index else i - 1 for i in self.updated_example_indices])
self._generate_sprite()

@observe('compute_custom_distance')
def _compute_custom_distance(self, change):
info = self.compute_custom_distance
index = info['index']
params = info['params']
callback_fn = info['callback']
try:
distances = base.WitWidgetBase.compute_custom_distance_impl(self, index,
params['distanceParams'])
self.custom_distance_dict = {'distances':distances,
'exInd':index,
'callback_fn':callback_fn,
'params':params['callbackParams']}
except Exception as e:
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
self._report_error(e)

def _generate_sprite(self):
sprite = base.WitWidgetBase.create_sprite(self)
if sprite is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,27 @@ def set_compare_custom_predict_fn(self, predict_fn):
self.set_compare_model_name('2')
return self

def set_custom_distance_fn(self, distance_fn):
"""Sets a custom function for distance computation.

WIT can directly use a custom function for all distance computations within
the tool. In this case, the provided function should accept a query example
proto and a list of example protos to compute the distance against and
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
return a 1D list of numbers containing the distances.

Args:
distance_fn: The python function which will be used for distance
computation.

Returns:
self, in order to enabled method chaining.
"""
if distance_fn is None:
self.delete('custom_distance_fn')
else:
self.store('custom_distance_fn', distance_fn)
return self

tolga-b marked this conversation as resolved.
Show resolved Hide resolved
def _convert_json_to_tf_examples(self, examples):
self._set_uses_json_input(True)
tf_examples = []
Expand Down