From 81f1a08398920d75f5525b681ea3299ffc6c974b Mon Sep 17 00:00:00 2001 From: Jake Wagoner Date: Fri, 12 Apr 2024 10:46:17 -0600 Subject: [PATCH] Uniform and non-uniform scalar bars for deepssm heatmaps --- shapeworks_cloud/core/deepssm_tasks.py | 7 +- web/shapeworks/src/api/rest.ts | 12 +- .../src/components/Analysis/PCA.vue | 2 +- .../src/components/ShapeViewer/methods.js | 133 +++++++++++++----- .../src/components/ShapeViewer/scripting.js | 6 - .../src/components/ShapeViewer/viewer.vue | 32 ++--- web/shapeworks/src/reader/image.ts | 6 +- web/shapeworks/src/views/Main.vue | 24 ++-- 8 files changed, 131 insertions(+), 91 deletions(-) diff --git a/shapeworks_cloud/core/deepssm_tasks.py b/shapeworks_cloud/core/deepssm_tasks.py index 7296d4d8..ad03c084 100644 --- a/shapeworks_cloud/core/deepssm_tasks.py +++ b/shapeworks_cloud/core/deepssm_tasks.py @@ -426,10 +426,9 @@ def post_command_function(project, download_dir, result_data, project_filename): filename = file1.split('.')[0] # filename here represents the SUBJECT INDEX OF THE TEST SPLIT - subject_name = ( - result_data['testing']['test_split_subjects'][int(filename)] - .get_display_name() - ) + subject_name = result_data['testing']['test_split_subjects'][ + int(filename) + ].get_display_name() test_pair = models.DeepSSMTestingData.objects.create( project=project, diff --git a/web/shapeworks/src/api/rest.ts b/web/shapeworks/src/api/rest.ts index b7c650cf..f5c59b6e 100644 --- a/web/shapeworks/src/api/rest.ts +++ b/web/shapeworks/src/api/rest.ts @@ -1,6 +1,6 @@ import { AnalysisParams, DataObject, Dataset, LandmarkInfo, Constraints, Project, Subject } from "@/types"; import { apiClient } from "./auth"; -import { loadGroomedShapeForObject, loadParticlesForObject } from "@/store"; +import { deepSSMDataTab, loadGroomedShapeForObject, loadParticlesForObject } from "@/store"; export async function getDatasets(search: string | undefined): Promise{ @@ -101,10 +101,12 @@ export async function getOptimizedParticlesForDataObject( export async function getGroomedShapeForDataObject( type: string, id: number, projectId: number|undefined ) { - const plural = `${type}${type == 'mesh' ?'es' :'s'}` - return (await apiClient.get(`/groomed-${plural}`, { - params: {[type]: id, project: projectId} - })).data.results + if (type !== 'image') { + const plural = `${type}${type == 'mesh' ?'es' :'s'}` + return (await apiClient.get(`/groomed-${plural}`, { + params: {[type]: id, project: projectId} + })).data.results + } } export async function getReconstructedSamplesForProject( diff --git a/web/shapeworks/src/components/Analysis/PCA.vue b/web/shapeworks/src/components/Analysis/PCA.vue index d0fa3774..0fa62aec 100644 --- a/web/shapeworks/src/components/Analysis/PCA.vue +++ b/web/shapeworks/src/components/Analysis/PCA.vue @@ -149,7 +149,7 @@ import { groupBy } from '../../helper'; currentTasks.value[selectedProject.value.id] = {} } - const taskIds = await spawnJob("analyze", {"analysis": {range, steps: numSteps}}); // Record + const taskIds = await spawnJob("analyze", {"range": range, "steps": numSteps}); // Record if(!taskIds || taskIds.length === 0) { message.value = `Failed to submit analysis job.` diff --git a/web/shapeworks/src/components/ShapeViewer/methods.js b/web/shapeworks/src/components/ShapeViewer/methods.js index a8e2ec4f..e0068098 100644 --- a/web/shapeworks/src/components/ShapeViewer/methods.js +++ b/web/shapeworks/src/components/ShapeViewer/methods.js @@ -251,15 +251,12 @@ export default { if ([1, 2].includes(deepSSMDataTab.value)) { const data = shapeData.getPointData().getArrayByName('deepssm_error').getData() let normalizeRange; - console.log('uniform scale?', uniformScale.value) if (uniformScale.value) { normalizeRange = deepSSMErrorGlobalRange.value } else { normalizeRange = [Math.min(...data), Math.max(...data)] } - console.log('normalize range', normalizeRange) const normalizedData = data.map((v) => v / (normalizeRange[1] - normalizeRange[0])) - console.log('normalized data', normalizedData) const normalizedArray = vtkDataArray.newInstance({ name: 'deepssm_error_normalized', values: normalizedData, @@ -378,51 +375,107 @@ export default { if (showDifferenceFromMeanMode.value) { this.lookupTable.setMappingRange(0, 1) this.lookupTable.updateRange(); - this.prepareColorScale() + this.prepareColorScales() } this.render() }, - prepareColorScale() { - const canvas = this.$refs.colors - const labelDiv = this.$refs.colorLabels; - let dataRange = this.lookupTable.getMappingRange() - - if (uniformScale.value) { - dataRange = deepSSMErrorGlobalRange.value + prepareColorScales() { + const canvasDiv = this.$refs.colors + const titleDiv = this.$refs.colorsTitle + const labelsDiv = this.$refs.colorsLabels + canvasDiv.innerHTML = "" + titleDiv.innerHTML = "" + labelsDiv.innerHTML = "" + if (showDifferenceFromMeanMode.value || uniformScale.value) { + const canvas = document.createElement('canvas') + canvas.style.right = "10px" + canvas.style.height = "100%" + canvas.style.width = "20px" + canvasDiv.appendChild(canvas) + + const labels = document.createElement('div') + labels.style.right = "35px" + labels.style.height = "100%" + labelsDiv.appendChild(labels) + + let range = [-5, 5] + if (showDifferenceFromMeanMode.value) { + titleDiv.innerHTML = "Distance from particle on mean shape" + } else if (uniformScale.value) { + titleDiv.innerHTML = "DeepSSMError" + range = deepSSMErrorGlobalRange.value + } + this.prepareColorScale(canvas, labels, range) + } else { + const { width, height } = this.$refs.vtk.getBoundingClientRect() + Object.entries(this.data).forEach(([label, data], i) => { + const [x1, y1, x2, y2] = this.grid[i] + const canvas = document.createElement('canvas') + canvas.style.top = `calc(${(1 - y2) * 100}%)` + canvas.style.right = `calc(${(1 - x2) * 100}% + 10px)` + canvas.style.height = `${(y2 - y1) * 96}%` + canvas.style.width = "10px" + canvasDiv.appendChild(canvas) + + const labels = document.createElement('div') + labels.style.top = `calc(${(1 - y2) * 100}%)` + labels.style.right = `calc(${(1 - x2) * 100}% + 35px)` + labels.style.height = `${(y2 - y1) * 96}%` + labelsDiv.appendChild(labels) + + let range = [-5, 5] + data[0].shape.forEach((shape) => { + if (shape.getClassName() === 'vtkPolyData') { + const arr = shape.getPointData().getArrayByName('deepssm_error').getData() + if (arr) range = [Math.min(...arr), Math.max(...arr)] + } + }) + this.prepareColorScale(canvas, labels, range) + }) } - - if (canvas && labelDiv) { + }, + prepareColorScale(canvas, labels, range) { + if (canvas && labels) { + canvas.style.position = "absolute" + canvas.style.zIndex = "1" const { width, height } = canvas const context = canvas.getContext('2d', { willReadFrequently: true }); const pixelsArea = context.getImageData(0, 0, width, height); const colorsData = this.lookupTable.getUint8Table( - ...dataRange, + ...this.lookupTable.getMappingRange(), height * width, true ) - pixelsArea.data.set(colorsData) context.putImageData(pixelsArea, 0, 0) + + const labelProportions = [1, 0.75, 0.5, 0.25, 0]; + labels.innerHTML = '' + labels.style.position = "absolute" + labels.style.display = "flex" + labels.style.flexDirection = "column" + labels.style.justifyContent = "space-between" + labels.style.alignItems = "flex-end" + labels.style.textAlign = "right" + labelProportions.forEach((p) => { + const child = document.createElement('span'); + child.innerHTML = Math.round(p * (range[1] - range[0]) + range[0]); + labels.appendChild(child); + }) } - - const labels = [1, 0.75, 0.5, 0.25, 0]; - labelDiv.innerHTML = '' - labels.forEach((l) => { - const child = document.createElement('span'); - child.innerHTML = l * (dataRange[1] - dataRange[0]) + dataRange[0]; - labelDiv.appendChild(child); - }) }, prepareLabelCanvas() { + const labelCanvas = this.$refs.labels + const labelCanvasContext = labelCanvas.getContext('2d') const { clientWidth, clientHeight } = this.$refs.vtk; // increase the resolution of the canvas so text isn't blurry - this.labelCanvas.width = clientWidth; - this.labelCanvas.height = clientHeight; + labelCanvas.width = clientWidth; + labelCanvas.height = clientHeight; - this.labelCanvasContext.clearRect(0, 0, this.labelCanvas.width, this.labelCanvas.height) - this.labelCanvasContext.font = "16px Arial"; - this.labelCanvasContext.fillStyle = "white"; + labelCanvasContext.clearRect(0, 0, labelCanvas.width, labelCanvas.height) + labelCanvasContext.font = "16px Arial"; + labelCanvasContext.fillStyle = "white"; }, populateRenderer(renderer, shapes) { this.addShapes(renderer, shapes.map(({ shape }) => shape)); @@ -463,10 +516,12 @@ export default { const newRenderer = vtkRenderer.newInstance({ background: [0.115, 0.115, 0.115] }); const bounds = this.grid[i]; - this.labelCanvasContext.fillText( + const labelCanvas = this.$refs.labels + const labelCanvasContext = labelCanvas.getContext('2d') + labelCanvasContext.fillText( label, - this.labelCanvas.width * bounds[0], - this.labelCanvas.height * (1 - bounds[1]) - 20 + labelCanvas.width * bounds[0], + labelCanvas.height * (1 - bounds[1]) - 20 ); newRenderer.setViewport.apply(newRenderer, bounds); this.vtk.renderers[label] = newRenderer; @@ -489,16 +544,18 @@ export default { if (imageViewIntersectMode.value) this.resetIntersections() else if ([1, 2].includes(deepSSMDataTab.value)) { - this.prepareColorScale() + this.prepareColorScales() } - const targetRenderer = Object.values(this.vtk.renderers)[this.columns - 1] - this.vtk.orientationCube = this.newOrientationCube(this.vtk.interactor) - if (targetRenderer) { - this.vtk.orientationCube.setParentRenderer(targetRenderer) - this.vtk.orientationCube.setEnabled(true); - this.vtk.interactor.enable() + if (!this.showColorScale) { + const targetRenderer = Object.values(this.vtk.renderers)[this.columns - 1] + this.vtk.orientationCube = this.newOrientationCube(this.vtk.interactor) + if (targetRenderer) { + this.vtk.orientationCube.setParentRenderer(targetRenderer) + this.vtk.orientationCube.setEnabled(true); + } } + this.vtk.interactor.enable() this.render(); renderLoading.value = false; setTimeout( diff --git a/web/shapeworks/src/components/ShapeViewer/scripting.js b/web/shapeworks/src/components/ShapeViewer/scripting.js index 199bbcd6..86edc1e0 100644 --- a/web/shapeworks/src/components/ShapeViewer/scripting.js +++ b/web/shapeworks/src/components/ShapeViewer/scripting.js @@ -57,12 +57,6 @@ export default { } return grid; }, - labelCanvas() { - return this.$refs.labels; - }, - labelCanvasContext() { - return this.labelCanvas.getContext("2d"); - }, showDifferenceFromMeanMode() { return showDifferenceFromMeanMode.value }, diff --git a/web/shapeworks/src/components/ShapeViewer/viewer.vue b/web/shapeworks/src/components/ShapeViewer/viewer.vue index b2329cb2..c3ae2f70 100644 --- a/web/shapeworks/src/components/ShapeViewer/viewer.vue +++ b/web/shapeworks/src/components/ShapeViewer/viewer.vue @@ -5,12 +5,9 @@ class="render-area" > - -
- Distance from particle on mean shape - DeepSSM Error -
-
+
+
+
@@ -25,28 +22,21 @@ width: 100%; height: 100%; } -.color-scale-canvas { +.color-scales { position: absolute; - right: 10px; - width: 20px; + width: 100%; height: 100%; - z-index: 1; } -.color-scale-labels-canvas { - display: flex; - flex-direction: column; - justify-content: space-between; - align-items: flex-end; +.color-scale-title { position: absolute; - right: 35px; - width: 50px; + writing-mode: vertical-rl; + right: -10px; + text-align: center; height: 100%; } -.color-scale-title-text { +.color-scale-labels { position: absolute; - writing-mode: vertical-rl; - right: -10px; - text-align: center;; + width: 100%; height: 100%; } diff --git a/web/shapeworks/src/reader/image.ts b/web/shapeworks/src/reader/image.ts index b45b80fa..6574c073 100644 --- a/web/shapeworks/src/reader/image.ts +++ b/web/shapeworks/src/reader/image.ts @@ -26,7 +26,7 @@ async function readDeepSSMScalars(url: string | undefined) { const data = reader.getOutputData(); const content = (await axios.get(url)).data; - let values: number[] = [] + const values: number[] = [] content.split('\r\n\r').forEach((section) => { if (section.includes('deepssm_error')) { const [header, data] = section.split('deepssm_error') @@ -38,18 +38,14 @@ async function readDeepSSMScalars(url: string | undefined) { } } }) - console.log(values) const dataRange = [ Math.min(...values), Math.max(...values), ] - console.log('range', dataRange) - console.log('global 1', deepSSMErrorGlobalRange.value) deepSSMErrorGlobalRange.value = [ Math.min(deepSSMErrorGlobalRange.value[0], dataRange[0]), Math.max(deepSSMErrorGlobalRange.value[1], dataRange[1]), ] - console.log(deepSSMErrorGlobalRange.value) const arr = vtkDataArray.newInstance({ name: 'deepssm_error', diff --git a/web/shapeworks/src/views/Main.vue b/web/shapeworks/src/views/Main.vue index 5c104fd4..6b2a864e 100644 --- a/web/shapeworks/src/views/Main.vue +++ b/web/shapeworks/src/views/Main.vue @@ -442,17 +442,19 @@ export default { ) } if(layersShown.value.includes("Groomed")){ - const shapeURL = groomedShapesForOriginalDataObjects.value[ - dataObject.type - ][dataObject.id].file - shapePromises.push( - imageReader( - shapeURL, - shortFileName(shapeURL), - "Groomed", - { domain: dataObject.anatomy_type.replace('anatomy_', '') } - ) - ) + if (groomedShapesForOriginalDataObjects.value[dataObject.type]) { + const shapeURL = groomedShapesForOriginalDataObjects.value[ + dataObject.type + ][dataObject.id].file + shapePromises.push( + imageReader( + shapeURL, + shortFileName(shapeURL), + "Groomed", + { domain: dataObject.anatomy_type.replace('anatomy_', '') } + ) + ) + } } if(layersShown.value.includes("Reconstructed")){ const targetReconstruction = reconstructionsForOriginalDataObjects.value.find(