Skip to content

Commit

Permalink
Merge pull request #969 from silx-kit/rgb-shader
Browse files Browse the repository at this point in the history
Refactor `RgbVis` with shader
  • Loading branch information
axelboc authored Feb 16, 2022
2 parents 6973b2c + 06eab05 commit e96284c
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 77 deletions.
5 changes: 2 additions & 3 deletions apps/storybook/src/HeatmapMesh.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
VisCanvas,
} from '@h5web/lib';
import type { HeatmapMeshProps, Domain } from '@h5web/lib';
import { ScaleType, toTypedNdArray } from '@h5web/shared';
import { getDims, ScaleType, toTypedNdArray } from '@h5web/shared';
import type { Meta, Story } from '@storybook/react/types-6-0';
import ndarray from 'ndarray';
import {
Expand All @@ -31,8 +31,7 @@ const uint16DataArray = ndarray(Uint16Array.from(uint16Values), [2, 2]);
const uint16Domain: Domain = [10, 40];

const Template: Story<HeatmapMeshProps> = (args) => {
const { shape } = args.values;
const [rows, cols] = shape;
const { rows, cols } = getDims(args.values);

return (
<VisCanvas
Expand Down
4 changes: 2 additions & 2 deletions packages/lib/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ export { INTERPOLATORS } from './vis/heatmap/interpolators';
export { ScaleType } from '@h5web/shared';
export { CurveType } from './vis/line/models';

export type { Domain } from '@h5web/shared';
export type { Domain, Dims } from '@h5web/shared';

export type {
DomainErrors,
Expand All @@ -90,7 +90,7 @@ export type {
AxisParams,
} from './vis/models';

export type { Dims, D3Interpolator, ColorMap } from './vis/heatmap/models';
export type { D3Interpolator, ColorMap } from './vis/heatmap/models';

// Mock data and utilities
export {
Expand Down
11 changes: 2 additions & 9 deletions packages/lib/src/vis/heatmap/HeatmapMesh.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { DataTexture, RGBFormat, UnsignedByteType } from 'three';
import { useAxisSystemContext } from '../..';
import type { VisScaleType } from '../models';
import VisMesh from '../shared/VisMesh';
import { DEFAULT_DOMAIN, getUniforms } from '../utils';
import { DEFAULT_DOMAIN, getUniforms, VERTEX_SHADER } from '../utils';
import type { ColorMap, TextureSafeTypedArray } from './models';
import { getDataTexture, getInterpolator, scaleDomain } from './utils';

Expand Down Expand Up @@ -93,14 +93,7 @@ function HeatmapMesh(props: Props) {
alphaMin: alphaDomain[0],
oneOverAlphaRange: 1 / (alphaDomain[1] - alphaDomain[0]),
}),
vertexShader: `
varying vec2 coords;
void main() {
coords = uv;
gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
}
`,
vertexShader: VERTEX_SHADER,
fragmentShader: `
uniform sampler2D data;
uniform sampler2D colorMap;
Expand Down
17 changes: 12 additions & 5 deletions packages/lib/src/vis/heatmap/HeatmapVis.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import type { Domain, NumArray, NumericType } from '@h5web/shared';
import { assertDefined, formatTooltipVal, ScaleType } from '@h5web/shared';
import {
assertDefined,
formatTooltipVal,
getDims,
ScaleType,
} from '@h5web/shared';
import type { NdArray } from 'ndarray';
import type { ReactElement, ReactNode } from 'react';

Expand All @@ -13,9 +18,8 @@ import { DEFAULT_DOMAIN, formatNumType } from '../utils';
import ColorBar from './ColorBar';
import HeatmapMesh from './HeatmapMesh';
import styles from './HeatmapVis.module.css';
import { useAxisValues } from './hooks';
import { useAxisValues, useTextureSafeNdArray } from './hooks';
import type { ColorMap, Layout, TooltipData } from './models';
import { getDims, toTextureSafeNdArray } from './utils';

interface Props {
dataArray: NdArray<NumArray>;
Expand Down Expand Up @@ -69,6 +73,9 @@ function HeatmapVis(props: Props) {
const abscissaToIndex = useValueToIndexScale(abscissas);
const ordinateToIndex = useValueToIndexScale(ordinates);

const safeDataArray = useTextureSafeNdArray(dataArray);
const safeAlphaArray = useTextureSafeNdArray(alpha?.array);

return (
<figure className={styles.root} aria-label={title} data-keep-canvas-colors>
<VisCanvas
Expand Down Expand Up @@ -117,12 +124,12 @@ function HeatmapVis(props: Props) {
}}
/>
<HeatmapMesh
values={toTextureSafeNdArray(dataArray)}
values={safeDataArray}
domain={domain}
colorMap={colorMap}
invertColorMap={invertColorMap}
scaleType={scaleType}
alphaValues={alpha && toTextureSafeNdArray(alpha.array)}
alphaValues={safeAlphaArray}
alphaDomain={alpha?.domain}
/>
{children}
Expand Down
27 changes: 26 additions & 1 deletion packages/lib/src/vis/heatmap/hooks.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
import type { NumArray } from '@h5web/shared';
import type { NdArray } from 'ndarray';
import { useMemo } from 'react';
import { createMemo } from 'react-use';

import { getVisDomain, getSafeDomain, getAxisValues } from './utils';
import type { TextureSafeTypedArray } from './models';
import {
getVisDomain,
getSafeDomain,
getAxisValues,
toTextureSafeNdArray,
} from './utils';

export const useVisDomain = createMemo(getVisDomain);
export const useSafeDomain = createMemo(getSafeDomain);
export const useAxisValues = createMemo(getAxisValues);

export function useTextureSafeNdArray(
ndArr: NdArray<NumArray>
): NdArray<TextureSafeTypedArray>;

export function useTextureSafeNdArray(
ndArr: NdArray<NumArray> | undefined
): NdArray<TextureSafeTypedArray> | undefined;

export function useTextureSafeNdArray(
ndArr: NdArray<NumArray> | undefined
): NdArray<TextureSafeTypedArray> | undefined {
return useMemo(() => {
return ndArr && toTextureSafeNdArray(ndArr);
}, [ndArr]);
}
5 changes: 0 additions & 5 deletions packages/lib/src/vis/heatmap/models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
import type { INTERPOLATORS } from './interpolators';

export interface Dims {
rows: number;
cols: number;
}

export type D3Interpolator = (t: number) => string;

export type ColorMap = keyof typeof INTERPOLATORS;
Expand Down
32 changes: 13 additions & 19 deletions packages/lib/src/vis/heatmap/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import type { Domain, NumArray } from '@h5web/shared';
import { isTypedArray, ScaleType, toTypedNdArray } from '@h5web/shared';
import {
getDims,
isTypedArray,
ScaleType,
toTypedNdArray,
} from '@h5web/shared';
import { range } from 'lodash';
import type { NdArray } from 'ndarray';
import {
Expand All @@ -17,12 +22,7 @@ import type { CustomDomain, DomainErrors } from '../models';
import { DomainError } from '../models';
import { H5WEB_SCALES } from '../scales';
import { INTERPOLATORS } from './interpolators';
import type {
ColorMap,
D3Interpolator,
Dims,
TextureSafeTypedArray,
} from './models';
import type { ColorMap, D3Interpolator, TextureSafeTypedArray } from './models';

const GRADIENT_PRECISION = 1 / 20;
export const GRADIENT_RANGE = range(
Expand Down Expand Up @@ -86,11 +86,6 @@ export function getSafeDomain(
];
}

export function getDims(dataArray: NdArray): Dims {
const [rows, cols] = dataArray.shape;
return { rows, cols };
}

function getColorStops(
interpolator: D3Interpolator,
minMaxOnly: boolean
Expand Down Expand Up @@ -152,13 +147,13 @@ export function scaleDomain(
}

export function toTextureSafeNdArray(
arr: NdArray<NumArray>
ndArr: NdArray<NumArray>
): NdArray<TextureSafeTypedArray> {
if (arr.dtype === 'float32' || arr.dtype.startsWith('uint8')) {
return arr as NdArray<TextureSafeTypedArray>;
if (ndArr.dtype === 'float32' || ndArr.dtype.startsWith('uint8')) {
return ndArr as NdArray<TextureSafeTypedArray>;
}

return toTypedNdArray(arr, Float32Array);
return toTypedNdArray(ndArr, Float32Array);
}

/*
Expand All @@ -174,11 +169,10 @@ export function getDataTexture(
values: NdArray<TextureSafeTypedArray | Uint16Array>,
magFilter = NearestFilter
): DataTexture {
const { data, shape } = values;
const [rows, cols] = shape;
const { rows, cols } = getDims(values);

return new DataTexture(
data,
values.data,
cols,
rows,
RedFormat,
Expand Down
50 changes: 50 additions & 0 deletions packages/lib/src/vis/rgb/RgbMesh.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import type { NdArray } from 'ndarray';
import { useMemo } from 'react';

import VisMesh from '../shared/VisMesh';
import { getUniforms, VERTEX_SHADER } from '../utils';
import { getDataTexture3D } from './utils';

interface Props {
values: NdArray<Uint8Array | Uint8ClampedArray | Float32Array>;
bgr?: boolean;
}

function RgbMesh(props: Props) {
const { values, bgr = false } = props;

const dataTexture = useMemo(() => getDataTexture3D(values), [values]);

const shader = {
uniforms: getUniforms({ data: dataTexture, bgr }),
vertexShader: VERTEX_SHADER,
fragmentShader: `
uniform highp sampler3D data;
uniform bool bgr;
varying vec2 coords;
void main() {
float yFlipped = 1. - coords.y;
float red = texture(data, vec3(0., coords.x, yFlipped)).r;
float green = texture(data, vec3(0.5, coords.x, yFlipped)).r;
float blue = texture(data, vec3(1., coords.x, yFlipped)).r;
if (bgr) {
gl_FragColor = vec4(blue, green, red, 1.);
} else {
gl_FragColor = vec4(red, green, blue, 1.);
}
}
`,
};

return (
<VisMesh>
<shaderMaterial args={[shader]} />
</VisMesh>
);
}

export default RgbMesh;
29 changes: 5 additions & 24 deletions packages/lib/src/vis/rgb/RgbVis.tsx
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import type { NumArray } from '@h5web/shared';
import { getDims } from '@h5web/shared';
import type { NdArray } from 'ndarray';
import type { ReactNode } from 'react';
import { useMemo } from 'react';
import { DataTexture, FloatType, RGBFormat, UnsignedByteType } from 'three';

import styles from '../heatmap/HeatmapVis.module.css';
import type { Layout } from '../heatmap/models';
import { getDims } from '../heatmap/utils';
import PanMesh from '../shared/PanMesh';
import VisCanvas from '../shared/VisCanvas';
import VisMesh from '../shared/VisMesh';
import ZoomMesh from '../shared/ZoomMesh';
import RgbMesh from './RgbMesh';
import { ImageType } from './models';
import { flipLastDimension, toRgbSafeNdArray } from './utils';
import { toRgbSafeNdArray } from './utils';

interface Props {
dataArray: NdArray<NumArray>;
Expand All @@ -34,23 +33,7 @@ function RgbVis(props: Props) {
} = props;

const { rows, cols } = getDims(dataArray);

const texture = useMemo(() => {
const typedDataArray = toRgbSafeNdArray(dataArray);

const flippedDataArray =
imageType === ImageType.BGR
? flipLastDimension(typedDataArray)
: typedDataArray;

return new DataTexture(
flippedDataArray.data,
cols,
rows,
RGBFormat,
flippedDataArray.dtype === 'float32' ? FloatType : UnsignedByteType
);
}, [dataArray, imageType, cols, rows]);
const safeDataArray = useMemo(() => toRgbSafeNdArray(dataArray), [dataArray]);

return (
<figure className={styles.root} aria-label={title} data-keep-canvas-colors>
Expand All @@ -72,9 +55,7 @@ function RgbVis(props: Props) {
>
<PanMesh />
<ZoomMesh />
<VisMesh scale={[1, -1, 1]}>
<meshBasicMaterial map={texture} />
</VisMesh>
<RgbMesh values={safeDataArray} bgr={imageType === ImageType.BGR} />
{children}
</VisCanvas>
</figure>
Expand Down
21 changes: 12 additions & 9 deletions packages/lib/src/vis/rgb/utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import type { NumArray } from '@h5web/shared';
import { createArrayFromView, toTypedNdArray } from '@h5web/shared';
import type { NdArray, TypedArray } from 'ndarray';
import { getDims, toTypedNdArray } from '@h5web/shared';
import type { NdArray } from 'ndarray';
import ndarray from 'ndarray';
import { DataTexture3D, FloatType, RedFormat, UnsignedByteType } from 'three';

/*
* - `Float32Array | Float64Array` ndarrays must contain normalized color values in the range [0, 1].
Expand Down Expand Up @@ -33,12 +34,14 @@ export function toRgbSafeNdArray(
return toTypedNdArray(ndArr, Uint8Array);
}

export function flipLastDimension<T extends TypedArray>(
dataArray: NdArray<T>
): NdArray<T> {
const { shape } = dataArray;
const steps = shape.map((_, index) => (index === shape.length - 1 ? -1 : 1));
export function getDataTexture3D(
values: NdArray<Uint8Array | Uint8ClampedArray | Float32Array>
): DataTexture3D {
const { rows, cols } = getDims(values);

const flippedView = dataArray.step(...steps);
return createArrayFromView(flippedView);
const texture = new DataTexture3D(values.data, 3, cols, rows);
texture.format = RedFormat;
texture.type = values.dtype === 'float32' ? FloatType : UnsignedByteType;

return texture;
}
9 changes: 9 additions & 0 deletions packages/lib/src/vis/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,15 @@ const TYPE_STRINGS: Record<NumericType['class'], string> = {
[DTypeClass.Float]: 'float',
};

export const VERTEX_SHADER = `
varying vec2 coords;
void main() {
coords = uv;
gl_Position = projectionMatrix * modelViewMatrix * vec4(position, 1.0);
}
`;

export function formatNumType(numType: NumericType): string {
return `${TYPE_STRINGS[numType.class]}${numType.size}`;
}
Expand Down
5 changes: 5 additions & 0 deletions packages/shared/src/models-vis.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,8 @@ export interface Bounds {
positiveMin: number;
strictPositiveMin: number;
}

export interface Dims {
rows: number;
cols: number;
}
Loading

0 comments on commit e96284c

Please sign in to comment.