Skip to content

Commit

Permalink
Integrate VegaFusion into JupyterChart (#3281)
Browse files Browse the repository at this point in the history
* Enhance JupyterChart to handle timezone and None charts

* Update JupyterChart to support VegaFusion ChartState

* Bump minimum vegafusion version to 1.5.0 for ChartState support

* Add max_wait option to JupyterChart

* Improve type hints

* Help mypy

* bump vegafusion in pyproject.toml

* Fix JupyterChart tests

* Use float if initial param value is int

* Add debug property and use this to enable printing VegaFusion messages
(these will end up in the JupyterLab console)

* mypy fixes

* Update Large Dataset documentation with JupyterChart usage

* Use built-in structuredClone to avoid deepClone dependency

* Rename imports to match vl-convert's bundling convention (for future offline support)
  • Loading branch information
jonmmease authored Dec 26, 2023
1 parent 50f4748 commit ebf9da5
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 52 deletions.
147 changes: 140 additions & 7 deletions altair/jupyter/js/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import embed from "https://esm.sh/vega-embed@6?deps=vega@5&deps=vega-lite@5.16.3";
import debounce from "https://esm.sh/lodash-es@4.17.21/debounce";
import vegaEmbed from "https://esm.sh/vega-embed@6?deps=vega@5&deps=vega-lite@5.16.3";
import lodashDebounce from "https://esm.sh/lodash-es@4.17.21/debounce";

export async function render({ model, el }) {
let finalize;
Expand All @@ -19,10 +19,21 @@ export async function render({ model, el }) {
finalize();
}

let spec = model.get("spec");
model.set("local_tz", Intl.DateTimeFormat().resolvedOptions().timeZone);

let spec = structuredClone(model.get("spec"));
if (spec == null) {
// Remove any existing chart and return
while (el.firstChild) {
el.removeChild(el.lastChild);
}
model.save_changes();
return;
}

let api;
try {
api = await embed(el, spec);
api = await vegaEmbed(el, spec);
} catch (error) {
showError(error)
return;
Expand All @@ -32,7 +43,10 @@ export async function render({ model, el }) {

// Debounce config
const wait = model.get("debounce_wait") ?? 10;
const maxWait = wait;
const debounceOpts = {leading: false, trailing: true};
if (model.get("max_wait") ?? true) {
debounceOpts["maxWait"] = wait;
}

const initialSelections = {};
for (const selectionName of Object.keys(model.get("_vl_selections"))) {
Expand All @@ -45,7 +59,7 @@ export async function render({ model, el }) {
model.set("_vl_selections", newSelections);
model.save_changes();
};
api.view.addSignalListener(selectionName, debounce(selectionHandler, wait, {maxWait}));
api.view.addSignalListener(selectionName, lodashDebounce(selectionHandler, wait, debounceOpts));

initialSelections[selectionName] = {
value: cleanJson(api.view.signal(selectionName) ?? {}),
Expand All @@ -62,7 +76,7 @@ export async function render({ model, el }) {
model.set("_params", newParams);
model.save_changes();
};
api.view.addSignalListener(paramName, debounce(paramHandler, wait, {maxWait}));
api.view.addSignalListener(paramName, lodashDebounce(paramHandler, wait, debounceOpts));

initialParams[paramName] = api.view.signal(paramName) ?? null
}
Expand All @@ -76,13 +90,132 @@ export async function render({ model, el }) {
}
await api.view.runAsync();
});

// Add signal/data listeners
for (const watch of model.get("_js_watch_plan") ?? []) {
if (watch.namespace === "data") {
const dataHandler = (_, value) => {
model.set("_js_to_py_updates", [{
namespace: "data",
name: watch.name,
scope: watch.scope,
value: cleanJson(value)
}]);
model.save_changes();
};
addDataListener(api.view, watch.name, watch.scope, lodashDebounce(dataHandler, wait, debounceOpts))

} else if (watch.namespace === "signal") {
const signalHandler = (_, value) => {
model.set("_js_to_py_updates", [{
namespace: "signal",
name: watch.name,
scope: watch.scope,
value: cleanJson(value)
}]);
model.save_changes();
};

addSignalListener(api.view, watch.name, watch.scope, lodashDebounce(signalHandler, wait, debounceOpts))
}
}

// Add signal/data updaters
model.on('change:_py_to_js_updates', async (updates) => {
for (const update of updates.changed._py_to_js_updates ?? []) {
if (update.namespace === "signal") {
setSignalValue(api.view, update.name, update.scope, update.value);
} else if (update.namespace === "data") {
setDataValue(api.view, update.name, update.scope, update.value);
}
}
await api.view.runAsync();
});
}

model.on('change:spec', reembed);
model.on('change:debounce_wait', reembed);
model.on('change:max_wait', reembed);
await reembed();
}

function cleanJson(data) {
return JSON.parse(JSON.stringify(data))
}

function getNestedRuntime(view, scope) {
var runtime = view._runtime;
for (const index of scope) {
runtime = runtime.subcontext[index];
}
return runtime
}

function lookupSignalOp(view, name, scope) {
let parent_runtime = getNestedRuntime(view, scope);
return parent_runtime.signals[name] ?? null;
}

function dataRef(view, name, scope) {
let parent_runtime = getNestedRuntime(view, scope);
return parent_runtime.data[name];
}

export function setSignalValue(view, name, scope, value) {
let signal_op = lookupSignalOp(view, name, scope);
view.update(signal_op, value);
}

export function setDataValue(view, name, scope, value) {
let dataset = dataRef(view, name, scope);
let changeset = view.changeset().remove(() => true).insert(value)
dataset.modified = true;
view.pulse(dataset.input, changeset);
}

export function addSignalListener(view, name, scope, handler) {
let signal_op = lookupSignalOp(view, name, scope);
return addOperatorListener(
view,
name,
signal_op,
handler,
);
}

export function addDataListener(view, name, scope, handler) {
let dataset = dataRef(view, name, scope).values;
return addOperatorListener(
view,
name,
dataset,
handler,
);
}

// Private helpers from Vega for dealing with nested signals/data
function findOperatorHandler(op, handler) {
const h = (op._targets || [])
.filter(op => op._update && op._update.handler === handler);
return h.length ? h[0] : null;
}

function addOperatorListener(view, name, op, handler) {
let h = findOperatorHandler(op, handler);
if (!h) {
h = trap(view, () => handler(name, op.value));
h.handler = handler;
view.on(op, null, h);
}
return view;
}

function trap(view, fn) {
return !fn ? null : function() {
try {
fn.apply(this, arguments);
} catch (error) {
view.error(error);
}
};
}
100 changes: 88 additions & 12 deletions altair/jupyter/jupyter_chart.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import anywidget
import traitlets
import pathlib
from typing import Any, Set

import altair as alt
from altair.utils._vegafusion_data import using_vegafusion
from altair.utils._vegafusion_data import (
using_vegafusion,
compile_to_vegafusion_chart_state,
)
from altair import TopLevelSpec
from altair.utils.selection import IndexSelection, PointSelection, IntervalSelection

Expand All @@ -20,9 +24,7 @@ def __init__(self, trait_values):
super().__init__()

for key, value in trait_values.items():
if isinstance(value, int):
traitlet_type = traitlets.Int()
elif isinstance(value, float):
if isinstance(value, (int, float)):
traitlet_type = traitlets.Float()
elif isinstance(value, str):
traitlet_type = traitlets.Unicode()
Expand Down Expand Up @@ -101,9 +103,12 @@ class JupyterChart(anywidget.AnyWidget):
"""

# Public traitlets
chart = traitlets.Instance(TopLevelSpec)
spec = traitlets.Dict().tag(sync=True)
chart = traitlets.Instance(TopLevelSpec, allow_none=True)
spec = traitlets.Dict(allow_none=True).tag(sync=True)
debounce_wait = traitlets.Float(default_value=10).tag(sync=True)
max_wait = traitlets.Bool(default_value=True).tag(sync=True)
local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
debug = traitlets.Bool(default_value=False)

# Internal selection traitlets
_selection_types = traitlets.Dict()
Expand All @@ -112,7 +117,20 @@ class JupyterChart(anywidget.AnyWidget):
# Internal param traitlets
_params = traitlets.Dict().tag(sync=True)

def __init__(self, chart: TopLevelSpec, debounce_wait: int = 10, **kwargs: Any):
# Internal comm traitlets for VegaFusion support
_chart_state = traitlets.Any(allow_none=True)
_js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True)
_js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True)
_py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True)

def __init__(
self,
chart: TopLevelSpec,
debounce_wait: int = 10,
max_wait: bool = True,
debug: bool = False,
**kwargs: Any,
):
"""
Jupyter Widget for displaying and updating Altair Charts, and
retrieving selection and parameter values
Expand All @@ -122,11 +140,24 @@ def __init__(self, chart: TopLevelSpec, debounce_wait: int = 10, **kwargs: Any):
chart: Chart
Altair Chart instance
debounce_wait: int
Debouncing wait time in milliseconds
Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel
after debounce_wait milliseconds of no chart interactions.
max_wait: bool
If True (default), updates will be sent from the client to the kernel every debounce_wait
milliseconds even if there are ongoing chart interactions. If False, updates will not be
sent until chart interactions have completed.
debug: bool
If True, debug messages will be printed
"""
self.params = Params({})
self.selections = Selections({})
super().__init__(chart=chart, debounce_wait=debounce_wait, **kwargs)
super().__init__(
chart=chart,
debounce_wait=debounce_wait,
max_wait=max_wait,
debug=debug,
**kwargs,
)

@traitlets.observe("chart")
def _on_change_chart(self, change):
Expand All @@ -135,14 +166,22 @@ def _on_change_chart(self, change):
state when the wrapped Chart instance changes
"""
new_chart = change.new

params = getattr(new_chart, "params", [])
selection_watches = []
selection_types = {}
initial_params = {}
initial_vl_selections = {}
empty_selections = {}

if new_chart is None:
with self.hold_sync():
self.spec = None
self._selection_types = selection_types
self._vl_selections = initial_vl_selections
self._params = initial_params
return

params = getattr(new_chart, "params", [])

if params is not alt.Undefined:
for param in new_chart.params:
if isinstance(param.name, alt.ParameterName):
Expand Down Expand Up @@ -205,13 +244,50 @@ def on_param_traitlet_changed(param_change):
# Update properties all together
with self.hold_sync():
if using_vegafusion():
self.spec = new_chart.to_dict(format="vega")
if self.local_tz is None:
self.spec = None

def on_local_tz_change(change):
self._init_with_vegafusion(change["new"])

self.observe(on_local_tz_change, ["local_tz"])
else:
self._init_with_vegafusion(self.local_tz)
else:
self.spec = new_chart.to_dict()
self._selection_types = selection_types
self._vl_selections = initial_vl_selections
self._params = initial_params

def _init_with_vegafusion(self, local_tz: str):
if self.chart is not None:
vegalite_spec = self.chart.to_dict(context={"pre_transform": False})
with self.hold_sync():
self._chart_state = compile_to_vegafusion_chart_state(
vegalite_spec, local_tz
)
self._js_watch_plan = self._chart_state.get_watch_plan()[
"client_to_server"
]
self.spec = self._chart_state.get_transformed_spec()

# Callback to update chart state and send updates back to client
def on_js_to_py_updates(change):
if self.debug:
updates_str = json.dumps(change["new"], indent=2)
print(
f"JavaScript to Python VegaFusion updates:\n {updates_str}"
)
updates = self._chart_state.update(change["new"])
if self.debug:
updates_str = json.dumps(updates, indent=2)
print(
f"Python to JavaScript VegaFusion updates:\n {updates_str}"
)
self._py_to_js_updates = updates

self.observe(on_js_to_py_updates, ["_js_to_py_updates"])

@traitlets.observe("_params")
def _on_change_params(self, change):
for param_name, value in change.new.items():
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/_importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def import_vegafusion() -> ModuleType:
min_version = "1.4.0"
min_version = "1.5.0"
try:
version = importlib_version("vegafusion")
if Version(version) < Version(min_version):
Expand Down
Loading

0 comments on commit ebf9da5

Please sign in to comment.