Skip to content

Commit

Permalink
Add ability to set custom distance function for counterfactuals (#2607)
Browse files Browse the repository at this point in the history
WIT did not support counterfactual computation for image and text features as it is not straightforward to define the similarity of two images or two sentences. This PR adds the ability to pass a custom python function for similarity computation such that WIT can query that function to compute counterfactuals.
  • Loading branch information
tolga-b authored and jameswex committed Sep 19, 2019
1 parent e93f472 commit 15456be
Show file tree
Hide file tree
Showing 16 changed files with 407 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
/** @type {!Object} */ var define;
/** @type {!Object} */ var global;
/** @type {!Object} */ var tf;
/** @type {!Object} */ var mobilenet;
/** @type {!Function|undefined} */ var ga;
/** @type {!Function|undefined} */ var KeyframeEffect;
/** @type {!Object} */ var tensor_widget;
Expand Down
4 changes: 4 additions & 0 deletions tensorboard/plugins/interactive_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ Here is a basic rundown of what it can do:
* For categorical features, the distance is 0 if the values are the same,
otherwise the distance is the probability that any two examples have
the same value for that feature across all examples.
* In notebook mode, the tool also allows you to set a custom distance function
using set_custom_distance_fn in WitConfigBuilder, where that function is
used to compute closest counterfactuals instead. As in the case with
custom_predict_fn, the custom distance function can be any python function.

* Edit a selected example in the browser and re-run inference and visualize the
difference in the inference results.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ tf_web_library(
"image_index.html",
"tf-interactive-inference-image-demo.html",
"@org_tensorflow_tfjs//:tf.min.js",
"@org_tensorflow_tfjs_mobilenet//:mobilenet.js",
] + glob(["data/**"]),
path = "/tf-interactive-inference-dashboard",
deps = [
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
];
this.categories['over_50k'] = ['1', '0'];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/age_uci/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@
'Holand-Netherlands',
];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
<link rel="import" href="../tf-imports/polymer.html" />
<link rel="import" href="tf-interactive-inference-dashboard.html" />
<script src="tf.min.js"></script>
<script src="mobilenet.js"></script>

<dom-module id="tf-interactive-inference-image-demo">
<template>
<style>
Expand Down Expand Up @@ -53,16 +55,22 @@
labelVocab: {type: Array, value: ['Not smiling', 'Smiling']},
numLoadedImages: Number,
images: Array,
distanceEmbs: Array,
distanceEmbsTensor: Object,
distanceModel: Object,
},
ready: async function() {
this.$.dash.modelName = 'demo';
this.$.dash.inferenceAddress = 'demo';
this.$.dash.updateNumberOfModels();
this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/images/model.json', {
credentials: 'include',
})
);
// Load mobilenet for custom distance computation.
this.distanceModel = await mobilenet.load();
const DISTANCE_EMBS_PATH = 'data/images/distance_embeddings.json';
const DATA_PATH = 'data/images/smile_examples.json';
const RESULTS_PATH = 'data/images/smile_examples_inference.json';
const testData = d3.json(DATA_PATH).then((data) => {
Expand All @@ -75,20 +83,36 @@
this.data.map((item) => JSON.stringify(item)),
true
);
this.$.dash.addEventListener('update-example', (e) => {
this.$.dash.addEventListener('update-example', async (e) => {
this.data[e.detail.index] = JSON.parse(e.detail.example);
// Update distance embedding of example
const ex = this.data[e.detail.index];
const exTensor = await this.convertExToTensor(ex);
const predTensor = await this.distanceModel.infer(exTensor, true);
const predNorm = predTensor.div(tf.norm(predTensor));
// Normalize embedding
const pred = await predNorm.array();
this.distanceEmbs[e.detail.index] = pred[0];
this.updateDistanceEmbsTensor();
exTensor.dispose();
predNorm.dispose();
predTensor.dispose();
this.indicesToInfer[e.detail.index] = true;
this.updateSprite();
});
this.$.dash.addEventListener('duplicate-example', (e) => {
this.data.push(
JSON.parse(JSON.stringify(this.data[e.detail.index]))
);
this.distanceEmbs.push(this.distanceEmbs[e.detail.index]);
this.updateDistanceEmbsTensor();
this.indicesToInfer[this.data.length - 1] = true;
this.updateSprite();
});
this.$.dash.addEventListener('delete-example', (e) => {
this.data.splice(e.detail.index, 1);
this.distanceEmbs.splice(e.detail.index, 1);
this.updateDistanceEmbsTensor();
const newIndicesToInfer = {};
const oldIndicesToInfer = Object.keys(this.indicesToInfer);
for (let i = 0; i < oldIndicesToInfer.length; i++) {
Expand All @@ -110,7 +134,7 @@
{classificationResult: {classifications: []}},
];
if (indices.length == 250) {
const testData = d3.json(RESULTS_PATH).then((inferences) => {
d3.json(RESULTS_PATH).then((inferences) => {
setTimeout(() => {
// for compatibility with inferences.json
inferences.results = [inferences.results];
Expand All @@ -119,6 +143,11 @@
this.$.dash.selectedLabelFeature = 'Smiling';
}, 1000);
});
// Load precomputed embeddings from json for fast startup
d3.json(DISTANCE_EMBS_PATH).then((distanceEmbs) => {
this.distanceEmbs = distanceEmbs;
this.updateDistanceEmbsTensor();
});
} else {
const tensorArr = [];
for (let i = 0; i < indices.length; i++) {
Expand Down Expand Up @@ -154,6 +183,27 @@
this.$.dash.addEventListener('get-eligible-features', (e) => {
this.$.dash.partialDepPlotEligibleFeatures = [];
});
this.$.dash.addEventListener('compute-custom-distance', async (e) => {
// Compute cosine similarity between selected example and all other
// examples. Return negative cosine similarity as the distance.
const ind = tf.tensor1d([+e.detail.index], 'int32');
const exEmb = this.distanceEmbsTensor.gather(ind);
const distancesTensor = tf.neg(
tf.dot(this.distanceEmbsTensor, tf.transpose(exEmb)).flatten()
);
const distances = await distancesTensor.data();
ind.dispose();
exEmb.dispose();
distancesTensor.dispose();
const callbackObj = {
exInd: e.detail.index,
distances: distances,
params: e.detail.params.callbackParams,
funId: e.detail.callback,
};
this.$.dash.invokeCustomDistanceCallback(callbackObj);
});
this.$.dash.customDistanceFunctionSet = true;
requestAnimationFrame(() => {
this.$.dash.inferClicked_();
});
Expand All @@ -174,7 +224,7 @@
canvas
.getContext('2d')
.drawImage(img, 0, 0, img.width, img.height);
let tensor = tf.fromPixels(canvas).toFloat();
let tensor = tf.browser.fromPixels(canvas).toFloat();
if (tensor.shape[0] != 78 || tensor.shape[1] != 64) {
tensor = tf.image.resizeBilinear(tensor, [78, 64]);
}
Expand Down Expand Up @@ -203,6 +253,21 @@
}
},

updateDistanceEmbsTensor: function() {
if (
this.distanceEmbsTensor &&
Object.keys(this.distanceEmbsTensor).length != 0
) {
this.distanceEmbsTensor.dispose();
}
const distanceEmbsTensor = tf.tensor2d(this.distanceEmbs);
this.distanceEmbsTensor = distanceEmbsTensor.reshape([
this.distanceEmbs.length,
1024,
]);
distanceEmbsTensor.dispose();
},

imageLoaded: function() {
const canvas = this.$.spritecanvas;
const THUMBNAIL_SIZE = 32;
Expand All @@ -211,7 +276,6 @@
return;
}
const thumbnailsPerSide = Math.ceil(Math.sqrt(this.images.length));
console.log(thumbnailsPerSide);
canvas.width = thumbnailsPerSide * THUMBNAIL_SIZE;
canvas.height = thumbnailsPerSide * THUMBNAIL_SIZE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
this.$.dash.inferenceAddress = 'demo';
this.$.dash.multiClass = true;
this.$.dash.updateNumberOfModels();
this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/iris/model.json', {
credentials: 'include',
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@
'Holand-Netherlands',
];

this.model = await tf.loadModel(
this.model = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model.json', {
credentials: 'include',
})
);
this.model2 = await tf.loadModel(
this.model2 = await tf.loadLayersModel(
tf.io.browserHTTPRequest('data/uci/model_comparison/model.json', {
credentials: 'include',
})
Expand Down
Loading

0 comments on commit 15456be

Please sign in to comment.