Skip to content

Commit

Permalink
Add support to transformed_data for reconstructed charts (with from_d…
Browse files Browse the repository at this point in the history
…ict/from_json) (#3102)

* Fix transformed_data for reconstructed charts (to_dict/to_json)

* Add tests
  • Loading branch information
binste authored Jul 10, 2023
1 parent e743ecd commit 84cdd39
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 14 deletions.
56 changes: 51 additions & 5 deletions altair/utils/_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,21 @@
HConcatChart,
VConcatChart,
ConcatChart,
TopLevelUnitSpec,
FacetedUnitSpec,
UnitSpec,
UnitSpecWithFrame,
NonNormalizedSpec,
TopLevelLayerSpec,
LayerSpec,
TopLevelConcatSpec,
ConcatSpecGenericSpec,
TopLevelHConcatSpec,
HConcatSpecGenericSpec,
TopLevelVConcatSpec,
VConcatSpecGenericSpec,
TopLevelFacetSpec,
FacetSpec,
data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables
Expand All @@ -17,6 +32,25 @@
FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]]


# Equivalence table for the transformed_data functionality: every key is a
# chart class, and the tuple stored under it lists all classes (the key itself
# included) that can be treated as equivalent to that key.
_chart_class_mapping = {
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    LayerChart: (
        LayerChart,
        TopLevelLayerSpec,
        LayerSpec,
    ),
    ConcatChart: (
        ConcatChart,
        TopLevelConcatSpec,
        ConcatSpecGenericSpec,
    ),
    HConcatChart: (
        HConcatChart,
        TopLevelHConcatSpec,
        HConcatSpecGenericSpec,
    ),
    VConcatChart: (
        VConcatChart,
        TopLevelVConcatSpec,
        VConcatSpecGenericSpec,
    ),
    FacetChart: (
        FacetChart,
        TopLevelFacetSpec,
        FacetSpec,
    ),
}


@overload
def transformed_data(
chart: Union[Chart, FacetChart],
Expand Down Expand Up @@ -118,6 +152,16 @@ def transformed_data(chart, row_limit=None, exclude=None):
return datasets


# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart`, as the function would also work for them.
# However, this has not been possible so far because mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020, which show that mypy's
# behavior for overloaded functions is not always consistent.
# The same error appeared when trying this with Protocols for the concat and layer charts.
# This function is only used internally, so we accept this inconsistency for now.
def name_views(
chart: Union[
Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart
Expand Down Expand Up @@ -148,7 +192,9 @@ def name_views(
List of the names of the charts and subcharts
"""
exclude = set(exclude) if exclude is not None else set()
if isinstance(chart, (Chart, FacetChart)):
if isinstance(chart, _chart_class_mapping[Chart]) or isinstance(
chart, _chart_class_mapping[FacetChart]
):
if chart.name not in exclude:
if chart.name in (None, Undefined):
# Add name since none is specified
Expand All @@ -157,13 +203,13 @@ def name_views(
else:
return []
else:
if isinstance(chart, LayerChart):
if isinstance(chart, _chart_class_mapping[LayerChart]):
subcharts = chart.layer
elif isinstance(chart, HConcatChart):
elif isinstance(chart, _chart_class_mapping[HConcatChart]):
subcharts = chart.hconcat
elif isinstance(chart, VConcatChart):
elif isinstance(chart, _chart_class_mapping[VConcatChart]):
subcharts = chart.vconcat
elif isinstance(chart, ConcatChart):
elif isinstance(chart, _chart_class_mapping[ConcatChart]):
subcharts = chart.concat
else:
raise ValueError(
Expand Down
40 changes: 31 additions & 9 deletions tests/test_transformed_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@
("window_rank.py", 12, ["team", "diff"]),
])
# fmt: on
def test_primitive_chart_examples(filename, rows, cols):
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_primitive_chart_examples(filename, rows, cols, to_reconstruct):
source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
chart = eval_block(source)
if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# than what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
df = chart.transformed_data()

assert len(df) == rows
assert set(cols).issubset(set(df.columns))

Expand Down Expand Up @@ -96,19 +103,29 @@ def test_primitive_chart_examples(filename, rows, cols):
("us_population_pyramid_over_time.py", [19, 38, 19], [["gender"], ["year"], ["gender"]]),
])
# fmt: on
def test_compound_chart_examples(filename, all_rows, all_cols):
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_compound_chart_examples(filename, all_rows, all_cols, to_reconstruct):
source = pkgutil.get_data(examples_methods_syntax.__name__, filename)
chart = eval_block(source)
print(chart)

if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# than what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
dfs = chart.transformed_data()
assert len(dfs) == len(all_rows)
for df, rows, cols in zip(dfs, all_rows, all_cols):
assert len(df) == rows
assert set(cols).issubset(set(df.columns))

if not to_reconstruct:
# Only run the assert statements if the chart is not reconstructed. The reason
# is that, for some charts, the original chart contained duplicated datasets
# which disappear when the chart is reconstructed.
assert len(dfs) == len(all_rows)
for df, rows, cols in zip(dfs, all_rows, all_cols):
assert len(df) == rows
assert set(cols).issubset(set(df.columns))


def test_transformed_data_exclude():
@pytest.mark.parametrize("to_reconstruct", [True, False])
def test_transformed_data_exclude(to_reconstruct):
source = data.wheat()
bar = alt.Chart(source).mark_bar().encode(x="year:O", y="wheat:Q")
rule = alt.Chart(source).mark_rule(color="red").encode(y="mean(wheat):Q")
Expand All @@ -119,6 +136,11 @@ def test_transformed_data_exclude():
)

chart = (bar + rule + some_annotation).properties(width=600)
if to_reconstruct:
# When reconstructing a Chart, Altair uses different classes
# than what might have been originally used. See
# https://github.com/hex-inc/vegafusion/issues/354 for more info.
chart = alt.Chart.from_dict(chart.to_dict())
datasets = chart.transformed_data(exclude=["some_annotation"])

assert len(datasets) == 2
Expand Down

0 comments on commit 84cdd39

Please sign in to comment.