Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set custom distance function for counterfactuals #2607

Merged
merged 44 commits into from
Sep 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
ae2daff
midway 1
tolga-b Aug 23, 2019
d6a6920
midway 2
tolga-b Aug 27, 2019
0359655
midway 3
tolga-b Aug 28, 2019
fc6709b
image colab works
tolga-b Aug 29, 2019
31e05ae
colab distance 1
tolga-b Aug 29, 2019
79dd89b
jupyter works
tolga-b Aug 29, 2019
00cade8
ran prettier
tolga-b Aug 29, 2019
ed3a4c6
review comments
tolga-b Sep 3, 2019
299dd47
minor fix
tolga-b Sep 3, 2019
76ddff8
fix lint
tolga-b Sep 3, 2019
0ee315a
fix comment
tolga-b Sep 4, 2019
a5126e8
Merge branch 'distance' of https://github.com/tolga-b/tensorboard int…
tolga-b Sep 4, 2019
9749fd6
fix dom if
tolga-b Sep 4, 2019
1343f30
image demo 1
tolga-b Sep 5, 2019
cc2ab68
image demo 2
tolga-b Sep 5, 2019
3fb28c1
image demo 3
tolga-b Sep 5, 2019
47b5960
image demo 3
tolga-b Sep 5, 2019
243e78e
undo image demo
tolga-b Sep 5, 2019
4eb6e09
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 5, 2019
9e98468
Merge branch 'distance' of https://github.com/tolga-b/tensorboard int…
tolga-b Sep 5, 2019
674afd4
add distance
tolga-b Sep 10, 2019
4faa9e6
smile demo works
tolga-b Sep 11, 2019
a024f6d
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 11, 2019
4a9e9e8
add distance embeddings
tolga-b Sep 11, 2019
16039ef
add delete and duplicate example for distance
tolga-b Sep 11, 2019
06a8451
add distance works
tolga-b Sep 11, 2019
eadecfa
changed distance pane
tolga-b Sep 12, 2019
3bccfc8
update readme
tolga-b Sep 12, 2019
d2414d9
update demos
tolga-b Sep 12, 2019
0c4a9c9
review comments fix image demo slice
tolga-b Sep 12, 2019
2469392
memory leak fix
tolga-b Sep 12, 2019
859cfa6
memory leak fix 2
tolga-b Sep 12, 2019
87eb749
ran prettier
tolga-b Sep 12, 2019
d38f7b5
Merge branch 'master' of https://github.com/tensorflow/tensorboard in…
tolga-b Sep 12, 2019
50839a3
smile demo upload image fix
tolga-b Sep 12, 2019
7c82372
review comments
tolga-b Sep 17, 2019
9aa1a72
merge master
tolga-b Sep 17, 2019
8b02c04
review comments
tolga-b Sep 18, 2019
de0ca1b
prettier
tolga-b Sep 18, 2019
69f5d6d
merge master
tolga-b Sep 18, 2019
b87964c
fix bzl
tolga-b Sep 18, 2019
5becf6a
fix bzl 2
tolga-b Sep 18, 2019
a80af0f
fix bzl 3
tolga-b Sep 18, 2019
13a8444
review comments
tolga-b Sep 18, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
/** @type {!Object} */ var define;
/** @type {!Object} */ var global;
/** @type {!Object} */ var tf;
/** @type {!Object} */ var mobilenet;
/** @type {!Function|undefined} */ var ga;
/** @type {!Function|undefined} */ var KeyframeEffect;
/** @type {!Object} */ var tensor_widget;
Expand Down
4 changes: 4 additions & 0 deletions tensorboard/plugins/interactive_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ Here is a basic rundown of what it can do:
* For categorical features, the distance is 0 if the values are the same,
otherwise the distance is the probability that any two examples have
the same value for that feature across all examples.
* In notebook mode, the tool also allows you to set a custom distance function
using set_custom_distance_fn in WitConfigBuilder, where that function is
used to compute closest counterfactuals instead. As in the case with
custom_predict_fn, the custom distance function can be any python function.

* Edit a selected example in the browser and re-run inference and visualize the
difference in the inference results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tf_web_library(
"image_index.html",
"tf-interactive-inference-image-demo.html",
"@org_tensorflow_tfjs//:tf.min.js",
"@org_tensorflow_tfjs_mobilenet//:mobilenet.js",
] + glob(["data/**"]),
path = "/tf-interactive-inference-dashboard",
deps = [
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
];
this.categories['over_50k'] = ['1', '0'];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/age_uci/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
'Holand-Netherlands',
];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
<link rel="import" href="../tf-imports/polymer.html" />
<link rel="import" href="tf-interactive-inference-dashboard.html" />
<script src="tf.min.js"></script>
<script src="mobilenet.js"></script>

<dom-module id="tf-interactive-inference-image-demo">
<template>
<style>
Expand Down Expand Up @@ -53,16 +55,22 @@
labelVocab: {type: Array, value: ['Not smiling', 'Smiling']},
numLoadedImages: Number,
images: Array,
distanceEmbs: Array,
distanceEmbsTensor: Object,
distanceModel: Object,
},
ready: async function() {
this.$.dash.modelName = 'demo';
this.$.dash.inferenceAddress = 'demo';
this.$.dash.updateNumberOfModels();
this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/images/model.json', {
credentials: 'include',
})
);
// Load mobilenet for custom distance computation.
this.distanceModel = await mobilenet.load();
const DISTANCE_EMBS_PATH = 'data/images/distance_embeddings.json';
const DATA_PATH = 'data/images/smile_examples.json';
const RESULTS_PATH = 'data/images/smile_examples_inference.json';
const testData = d3.json(DATA_PATH).then((data) => {
Expand All @@ -75,20 +83,36 @@
this.data.map((item) => JSON.stringify(item)),
true
);
this.$.dash.addEventListener('update-example', (e) => {
this.$.dash.addEventListener('update-example', async (e) => {
this.data[e.detail.index] = JSON.parse(e.detail.example);
// Update distance embedding of example
const ex = this.data[e.detail.index];
const exTensor = await this.convertExToTensor(ex);
const predTensor = await this.distanceModel.infer(exTensor, true);
const predNorm = predTensor.div(tf.norm(predTensor));
// Normalize embedding
const pred = await predNorm.array();
this.distanceEmbs[e.detail.index] = pred[0];
this.updateDistanceEmbsTensor();
exTensor.dispose();
predNorm.dispose();
predTensor.dispose();
this.indicesToInfer[e.detail.index] = true;
this.updateSprite();
});
this.$.dash.addEventListener('duplicate-example', (e) => {
this.data.push(
tolga-b marked this conversation as resolved.
Show resolved Hide resolved
JSON.parse(JSON.stringify(this.data[e.detail.index]))
);
this.distanceEmbs.push(this.distanceEmbs[e.detail.index]);
this.updateDistanceEmbsTensor();
this.indicesToInfer[this.data.length - 1] = true;
this.updateSprite();
});
this.$.dash.addEventListener('delete-example', (e) => {
this.data.splice(e.detail.index, 1);
this.distanceEmbs.splice(e.detail.index, 1);
this.updateDistanceEmbsTensor();
const newIndicesToInfer = {};
const oldIndicesToInfer = Object.keys(this.indicesToInfer);
for (let i = 0; i < oldIndicesToInfer.length; i++) {
Expand All @@ -110,7 +134,7 @@
{classificationResult: {classifications: []}},
];
if (indices.length == 250) {
const testData = d3.json(RESULTS_PATH).then((inferences) => {
d3.json(RESULTS_PATH).then((inferences) => {
setTimeout(() => {
// for compatibility with inferences.json
inferences.results = [inferences.results];
Expand All @@ -119,6 +143,11 @@
this.$.dash.selectedLabelFeature = 'Smiling';
}, 1000);
});
// Load precomputed embeddings from json for fast startup
d3.json(DISTANCE_EMBS_PATH).then((distanceEmbs) => {
this.distanceEmbs = distanceEmbs;
this.updateDistanceEmbsTensor();
});
} else {
const tensorArr = [];
for (let i = 0; i < indices.length; i++) {
Expand Down Expand Up @@ -154,6 +183,27 @@
this.$.dash.addEventListener('get-eligible-features', (e) => {
this.$.dash.partialDepPlotEligibleFeatures = [];
});
this.$.dash.addEventListener('compute-custom-distance', async (e) => {
// Compute cosine similarity between selected example and all other
// examples. Return negative cosine similarity as the distance.
const ind = tf.tensor1d([+e.detail.index], 'int32');
const exEmb = this.distanceEmbsTensor.gather(ind);
const distancesTensor = tf.neg(
tf.dot(this.distanceEmbsTensor, tf.transpose(exEmb)).flatten()
);
const distances = await distancesTensor.data();
ind.dispose();
exEmb.dispose();
distancesTensor.dispose();
const callbackObj = {
exInd: e.detail.index,
distances: distances,
params: e.detail.params.callbackParams,
funId: e.detail.callback,
};
this.$.dash.invokeCustomDistanceCallback(callbackObj);
});
this.$.dash.customDistanceFunctionSet = true;
requestAnimationFrame(() => {
this.$.dash.inferClicked_();
});
Expand All @@ -174,7 +224,7 @@
canvas
.getContext('2d')
.drawImage(img, 0, 0, img.width, img.height);
let tensor = tf.fromPixels(canvas).toFloat();
let tensor = tf.browser.fromPixels(canvas).toFloat();
if (tensor.shape[0] != 78 || tensor.shape[1] != 64) {
tensor = tf.image.resizeBilinear(tensor, [78, 64]);
}
Expand Down Expand Up @@ -203,6 +253,21 @@
}
},

updateDistanceEmbsTensor: function() {
if (
this.distanceEmbsTensor &&
Object.keys(this.distanceEmbsTensor).length != 0
) {
this.distanceEmbsTensor.dispose();
}
const distanceEmbsTensor = tf.tensor2d(this.distanceEmbs);
this.distanceEmbsTensor = distanceEmbsTensor.reshape([
this.distanceEmbs.length,
1024,
]);
distanceEmbsTensor.dispose();
},

imageLoaded: function() {
const canvas = this.$.spritecanvas;
const THUMBNAIL_SIZE = 32;
Expand All @@ -211,7 +276,6 @@
return;
}
const thumbnailsPerSide = Math.ceil(Math.sqrt(this.images.length));
console.log(thumbnailsPerSide);
canvas.width = thumbnailsPerSide * THUMBNAIL_SIZE;
canvas.height = thumbnailsPerSide * THUMBNAIL_SIZE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
this.$.dash.inferenceAddress = 'demo';
this.$.dash.multiClass = true;
this.$.dash.updateNumberOfModels();
this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/iris/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@
'Holand-Netherlands',
];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model.json', {
credentials: 'include',
})
);
this.model2 = await tf.loadModel(
this.model2 = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model_comparison/model.json', {
credentials: 'include',
})
Expand Down
Loading