From 78f058faf16c36e007350f5a1f179cae40acb0ee Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sun, 19 May 2019 19:27:42 -0400 Subject: [PATCH 1/5] Port to plotly.py v4_subplots and renderers --- plotly_express/_core.py | 462 ++++++++++++++++++++++++++++++---------- 1 file changed, 347 insertions(+), 115 deletions(-) diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 17a5948..a68cc18 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -1,9 +1,15 @@ +from _plotly_future_ import renderer_defaults, v4_subplots import plotly.graph_objs as go from plotly.offline import init_notebook_mode, iplot import plotly.io as pio from collections import namedtuple, OrderedDict from .colors import qualitative, sequential import math, pandas +from plotly.subplots import ( + make_subplots, + _set_trace_grid_reference, + _subplot_type_for_trace_type, +) class PxDefaults(object): @@ -73,9 +79,11 @@ def get_trendline_results(fig): Mapping = namedtuple( "Mapping", - ["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"], + ["show_in_trace_name", "grouper", "val_map", "sequence", + "updater", "variable", "facet"], ) -TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch"]) +TraceSpec = namedtuple("TraceSpec", + ["constructor", "attrs", "trace_patch", "marginal"]) def get_label(args, column): @@ -108,6 +116,7 @@ def make_mapping(args, variable): sequence=[""], variable=variable, updater=(lambda trace, v: v), + facet=None, ) if variable == "facet_row" or variable == "facet_col": letter = "x" if variable == "facet_col" else "y" @@ -116,8 +125,9 @@ def make_mapping(args, variable): variable=letter, grouper=args[variable], val_map={}, - sequence=[letter + str(i) for i in range(1, 1000)], - updater=lambda trace, v: trace.update({letter + "axis": v}), + sequence=[i for i in range(1, 1000)], + updater=(lambda trace, v: v), + facet="row" if variable == "facet_row" else "col", ) (parent, variable) = variable.split(".") vprefix = variable @@ -134,6 +144,7 @@ def make_mapping(args, variable): val_map=args[vprefix + "_map"].copy(), sequence=args[vprefix + "_sequence"], updater=lambda trace, v: trace.update({parent: {variable: v}}), + facet=None, ) @@ -275,7 +286,7 @@ def make_trace_kwargs( return result, fit_results -def configure_axes(args, constructor, fig, axes, orders): +def configure_axes(args, constructor, fig, orders): configurators = { go.Scatter: configure_cartesian_axes, go.Scattergl: configure_cartesian_axes, @@ -295,126 +306,139 @@ def configure_axes(args, constructor, fig, axes, orders): go.Choropleth: configure_geo, } if constructor in configurators: - fig.update(layout=configurators[constructor](args, fig, axes, orders)) + configurators[constructor](args, fig, orders) -def set_cartesian_axis_opts(args, layout, letter, axis, orders): +def set_cartesian_axis_opts(args, axis, letter, orders): log_key = "log_" + letter range_key = "range_" + letter if log_key in args and args[log_key]: - layout[axis]["type"] = "log" + axis["type"] = "log" if range_key in args and args[range_key]: - layout[axis]["range"] = [math.log(r, 10) for r in args[range_key]] + axis["range"] = [math.log(r, 10) for r in args[range_key]] elif range_key in args and args[range_key]: - layout[axis]["range"] = args[range_key] + axis["range"] = args[range_key] if args[letter] in orders: - layout[axis]["categoryorder"] = "array" - layout[axis]["categoryarray"] = ( + axis["categoryorder"] = "array" + axis["categoryarray"] = ( orders[args[letter]] - if axis.startswith("x") + if isinstance(axis, go.layout.XAxis) else list(reversed(orders[args[letter]])) ) -def configure_cartesian_marginal_axes(args, orders): - layout = dict() +def configure_cartesian_marginal_axes(args, fig, orders): + if "histogram" in [args["marginal_x"], args["marginal_y"]]: - layout["barmode"] = "overlay" - for letter in ["x", "y"]: - layout[letter + "axis1"] = dict( - title=get_decorated_label(args, args[letter], letter) + fig.layout["barmode"] = "overlay" + + nrows = len(fig._grid_ref) + row_step = 2 if args["marginal_x"] else 1 + + ncols = len(fig._grid_ref[0]) + col_step = 2 if args['marginal_y'] else 1 + + # Set y-axis titles and axis options in the left-most column + for yaxis in fig.select_yaxes(col=1): + set_cartesian_axis_opts(args, yaxis, 'y', orders) + + # Set x-axis titles and axis options in the bottom-most row + for xaxis in fig.select_xaxes(row=1): + set_cartesian_axis_opts(args, xaxis, 'x', orders) + + # Configure axis ticks on marginal subplots + for row in range(2, nrows+1, row_step): + fig.update_yaxes( + showticklabels=False, + showgrid=args["marginal_x"] == 'histogram', + row=row ) - set_cartesian_axis_opts(args, layout, letter, letter + "axis1", orders) - for letter in ["x", "y"]: - otherletter = "x" if letter == "y" else "y" - if args["marginal_" + letter]: - if args["marginal_" + letter] == "histogram" or ( - "color" in args and args["color"] - ): - main_size = 0.74 - else: - main_size = 0.84 - layout[otherletter + "axis1"]["domain"] = [0, main_size] - layout[otherletter + "axis1"]["showgrid"] = True - layout[otherletter + "axis2"] = { - "domain": [main_size + 0.005, 1], - "showticklabels": False, - } - return layout + fig.update_xaxes( + showgrid=True, + row=row + ) + + for col in range(2, ncols+1, col_step): + fig.update_xaxes( + showticklabels=False, + showgrid=args["marginal_y"] == 'histogram', + col=col + ) + fig.update_yaxes( + showgrid=True, + col=col + ) + + # Add axis titles to non-marginal subplots + y_title = get_decorated_label(args, args['y'], 'y') + for row in range(1, nrows + 1, row_step): + fig.update_yaxes(title_text=y_title, row=row, col=1) + + x_title = get_decorated_label(args, args['x'], 'x') + for col in range(1, ncols + 1, col_step): + fig.update_xaxes(title_text=x_title, row=1, col=col) + + # Configure axis type across all x-axes + if 'log_x' in args and args['log_x']: + fig.update_xaxes(type='log') + + # Configure axis type across all y-axes + if 'log_y' in args and args['log_y']: + fig.update_yaxes(type='log') + + # Configure matching and axis type for marginal y-axes + matches_y = 'y' + str(ncols + 1) + if args["marginal_x"]: + for row in range(2, nrows + 1, 2): + fig.update_yaxes(matches=matches_y, type=None, row=row) + if args["marginal_y"]: + for col in range(2, ncols + 1, 2): + fig.update_xaxes(matches='x2', type=None, col=col) -def configure_cartesian_axes(args, fig, axes, orders): + +def configure_cartesian_axes(args, fig, orders): if ("marginal_x" in args and args["marginal_x"]) or ( "marginal_y" in args and args["marginal_y"] ): - return configure_cartesian_marginal_axes(args, orders) - - gap = 0.1 - layout = { - "annotations": [], - "grid": { - "xaxes": [], - "yaxes": [], - "xgap": gap, - "ygap": gap, - "xside": "bottom", - "yside": "left", - }, - } + configure_cartesian_marginal_axes(args, fig, orders) + return - for letter, direction, row in (("x", "facet_col", False), ("y", "facet_row", True)): - for letter_number in [t[letter + "axis"] for t in fig.data]: - if letter_number not in layout["grid"][letter + "axes"]: - layout["grid"][letter + "axes"].append(letter_number) - axis = letter_number.replace(letter, letter + "axis") + # Set y-axis titles and axis options in the left-most column + y_title = get_decorated_label(args, args['y'], 'y') + for yaxis in fig.select_yaxes(col=1): + yaxis.update(title_text=y_title) + set_cartesian_axis_opts(args, yaxis, 'y', orders) - layout[axis] = dict( - title=get_decorated_label(args, args[letter], letter) - ) - if len(letter_number) == 1: - set_cartesian_axis_opts(args, layout, letter, axis, orders) - else: - layout[axis]["matches"] = letter - log_key = "log_" + letter - if log_key in args and args[log_key]: - layout[axis]["type"] = "log" - - if args[direction]: - step = 1.0 / (len(layout["grid"][letter + "axes"]) - gap) - for key, value in axes[letter].items(): - i = int(value[1:]) - if row: - i = len(layout["grid"][letter + "axes"]) - i - else: - i -= 1 - layout["annotations"].append( - { - "xref": "paper", - "yref": "paper", - "showarrow": False, - "xanchor": "center", - "yanchor": "middle", - "text": args[direction] + "=" + str(key), - "x": 1.01 if row else step * (i + (0.5 - gap / 2)), - "y": step * (i + (0.5 - gap / 2)) if row else 1.02, - "textangle": 90 if row else 0, - } - ) - return layout + # Set x-axis titles and axis options in the bottom-most row + x_title = get_decorated_label(args, args['x'], 'x') + for xaxis in fig.select_xaxes(row=1): + xaxis.update(title_text=x_title) + set_cartesian_axis_opts(args, xaxis, 'x', orders) + + # Configure axis type across all x-axes + if 'log_x' in args and args['log_x']: + fig.update_xaxes(type='log') + # Configure axis type across all y-axes + if 'log_y' in args and args['log_y']: + fig.update_yaxes(type='log') -def configure_ternary_axes(args, fig, axes, orders): - return dict( + return fig.layout + + +def configure_ternary_axes(args, fig, orders): + fig.update(layout=dict( ternary=dict( aaxis=dict(title=get_label(args, args["a"])), baxis=dict(title=get_label(args, args["b"])), caxis=dict(title=get_label(args, args["c"])), ) - ) + )) -def configure_polar_axes(args, fig, axes, orders): +def configure_polar_axes(args, fig, orders): layout = dict( polar=dict( angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]), @@ -435,10 +459,10 @@ def configure_polar_axes(args, fig, axes, orders): else: if args["range_r"]: radialaxis["range"] = args["range_r"] - return layout + fig.update(layout=layout) -def configure_3d_axes(args, fig, axes, orders): +def configure_3d_axes(args, fig, orders): layout = dict( scene=dict( xaxis=dict(title=get_label(args, args["x"])), @@ -459,11 +483,11 @@ def configure_3d_axes(args, fig, axes, orders): if args[letter] in orders: axis["categoryorder"] = "array" axis["categoryarray"] = orders[args[letter]] - return layout + fig.update(layout=layout) -def configure_mapbox(args, fig, axes, orders): - return dict( +def configure_mapbox(args, fig, orders): + fig.update(layout=dict( mapbox=dict( accesstoken=MAPBOX_TOKEN, center=dict( @@ -472,17 +496,17 @@ def configure_mapbox(args, fig, axes, orders): ), zoom=args["zoom"], ) - ) + )) -def configure_geo(args, fig, axes, orders): - return dict( +def configure_geo(args, fig, orders): + fig.update(layout=dict( geo=dict( center=args["center"], scope=args["scope"], projection=dict(type=args["projection"]), ) - ) + )) def configure_animation_controls(args, constructor, fig): @@ -544,7 +568,10 @@ def frame_args(duration): def make_trace_spec(args, constructor, attrs, trace_patch): - result = [TraceSpec(constructor, attrs, trace_patch)] + # Create base trace specification + result = [TraceSpec(constructor, attrs, trace_patch, None)] + + # Add marginal trace specifications for letter in ["x", "y"]: if "marginal_" + letter in args and args["marginal_" + letter]: trace_spec = None @@ -556,19 +583,22 @@ def make_trace_spec(args, constructor, attrs, trace_patch): trace_spec = TraceSpec( constructor=go.Histogram, attrs=[letter], - trace_patch=dict(opacity=0.5, **axis_map), + trace_patch=dict(opacity=0.5), + marginal=letter ) elif args["marginal_" + letter] == "violin": trace_spec = TraceSpec( constructor=go.Violin, attrs=[letter, "hover_name", "hover_data"], - trace_patch=dict(scalegroup=letter, **axis_map), + trace_patch=dict(scalegroup=letter), + marginal=letter ) elif args["marginal_" + letter] == "box": trace_spec = TraceSpec( constructor=go.Box, attrs=[letter, "hover_name", "hover_data"], - trace_patch=dict(notched=True, **axis_map), + trace_patch=dict(notched=True), + marginal=letter ) elif args["marginal_" + letter] == "rug": symbols = {"x": "line-ns-open", "y": "line-ew-open"} @@ -582,8 +612,8 @@ def make_trace_spec(args, constructor, attrs, trace_patch): jitter=0, hoveron="points", marker={"symbol": symbols[letter]}, - **axis_map ), + marginal=letter ) if "color" in attrs: if "marker" not in trace_spec.trace_patch: @@ -591,9 +621,14 @@ def make_trace_spec(args, constructor, attrs, trace_patch): first_default_color = args["color_discrete_sequence"][0] trace_spec.trace_patch["marker"]["color"] = first_default_color result.append(trace_spec) + + # Add trendline trace specifications if "trendline" in args and args["trendline"]: trace_spec = TraceSpec( - constructor=go.Scatter, attrs=["trendline"], trace_patch=dict(mode="lines") + constructor=go.Scatter, + attrs=["trendline"], + trace_patch=dict(mode="lines"), + marginal=None ) if args["trendline_color_override"]: trace_spec.trace_patch["line"] = dict( @@ -654,6 +689,7 @@ def apply_default_cascade(args): def infer_config(args, constructor, trace_patch): + # Declare all supported attributes, across all plot types attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size"] + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"] @@ -663,6 +699,8 @@ def infer_config(args, constructor, trace_patch): array_attrables = ["dimensions", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + # Validate that the strings provided as attribute values reference columns + # in the provided data_frame df_columns = args["data_frame"].columns for attr in attrables + group_attrables + ["color"]: @@ -686,11 +724,12 @@ def infer_config(args, constructor, trace_patch): attrs = [k for k in attrables if k in args] grouped_attrs = [] + # Compute sizeref sizeref = 0 if "size" in args and args["size"]: sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 - color_range = None + # Compute color attributes and grouping attributes if "color" in args: if "color_continuous_scale" in args: if "color_discrete_sequence" not in args: @@ -708,6 +747,9 @@ def infer_config(args, constructor, trace_patch): else: grouped_attrs.append("marker.color") + # Compute color_range + color_range = None + if "color" in args: if "color" in attrs and args["color"]: cmin = args["data_frame"][args["color"]].min() cmax = args["data_frame"][args["color"]].max() @@ -718,12 +760,15 @@ def infer_config(args, constructor, trace_patch): else: color_range = [cmin, cmax] + # Compute line_dash grouping attribute if "line_dash" in args: grouped_attrs.append("line.dash") + # Compute symbol grouping attribute if "symbol" in args: grouped_attrs.append("marker.symbol") + # Compute final trace patch trace_patch = trace_patch.copy() if "opacity" in args: if args["opacity"] is None: @@ -741,17 +786,22 @@ def infer_config(args, constructor, trace_patch): if "line_shape" in args: trace_patch["line"] = dict(shape=args["line_shape"]) + # Compute marginal attribute if "marginal" in args: position = "marginal_x" if args["orientation"] == "v" else "marginal_y" other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" args[position] = args["marginal"] args[other_position] = None + # Compute applicable grouping attributes for k in group_attrables: if k in args: grouped_attrs.append(k) + # Create grouped mappings grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] + + # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) return trace_specs, grouped_mappings, sizeref, color_range @@ -798,9 +848,15 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): orders, sorted_group_names = get_orderings(args, grouper, grouped) + has_marginal_x = bool(args.get('marginal_x', False)) + has_marginal_y = bool(args.get('marginal_y', False)) + + subplot_type = _subplot_type_for_trace_type(constructor().type) + trace_names_by_frame = {} frames = OrderedDict() trendline_rows = [] + nrows = ncols = 1 for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() @@ -848,6 +904,11 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): ]: trace.update(hoverlabel=dict(namelength=0)) trace_names.add(trace_name) + + # Init subplot row/col + trace._subplot_row = 1 + trace._subplot_col = 1 + for i, m in enumerate(grouped_mappings): val = group_name[i] if val not in m.val_map: @@ -869,6 +930,42 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): trace.update(marker=dict(color=m.val_map[val])) else: raise + + # Find row for trace, handling facet_row and marginal_x + if m.facet == 'row': + if has_marginal_x: + row = (m.val_map[val] - 1) * 2 + 2 + else: + row = m.val_map[val] + + trace._subplot_row_val = val + else: + row = 1 + + if trace_spec.marginal == 'x': + row -= 1 + + nrows = max(nrows, row) + if row > 1: + trace._subplot_row = row + + # Find col for trace, handling facet_col and marginal_y + if m.facet == 'col': + if has_marginal_y: + col = (m.val_map[val] - 1) * 2 + 1 + else: + col = m.val_map[val] + + trace._subplot_col_val = val + else: + col = 1 + + if trace_spec.marginal == 'y': + col += 1 + + ncols = max(ncols, col) + if col > 1: + trace._subplot_col = col if ( trace_specs[0].constructor == go.Histogram2dContour and trace_spec.constructor == go.Box @@ -906,13 +1003,148 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): layout_patch["margin"] = {"t": 60} if "size" in args and args["size"]: layout_patch["legend"]["itemsizing"] = "constant" - fig = ExpressFigure( - data=frame_list[0]["data"] if len(frame_list) > 0 else [], - layout=layout_patch, - frames=frame_list if len(frames) > 1 else [], + + fig = init_figure( + args, + subplot_type, + frame_list, + ncols, + nrows, + has_marginal_x, + has_marginal_y, ) + + # Position traces in subplots + for frame in frame_list: + for trace in frame['data']: + if isinstance(trace, go.Splom): + # Special case that is not compatible with make_subplots + continue + + _set_trace_grid_reference( + trace, + fig.layout, + fig._grid_ref, + nrows - trace._subplot_row + 1, + trace._subplot_col, + ) + + # Add traces, layout and frames to figure + fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) + fig.layout.update(layout_patch) + fig.frames = frame_list if len(frames) > 1 else [] + fig._px_trendlines = pandas.DataFrame(trendline_rows) - axes = {m.variable: m.val_map for m in grouped_mappings} - configure_axes(args, constructor, fig, axes, orders) + + configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) return fig + + +def init_figure( + args, + subplot_type, + frame_list, + ncols, + nrows, + has_marginal_x, + has_marginal_y, +): + # Build subplot specs + specs = [[{}] * ncols for _ in range(nrows)] + column_titles = [None] * ncols + row_titles = [None] * nrows + for frame in frame_list: + for trace in frame['data']: + row0 = nrows - trace._subplot_row + col0 = trace._subplot_col - 1 + + if isinstance(trace, go.Splom): + # Splom not compatible with make_subplots, treat as domain + specs[row0][col0] = {'type': 'domain'} + else: + specs[row0][col0] = {'type': trace.type} + if (args.get('facet_row', None) and + hasattr(trace, '_subplot_row_val')): + if row0 % 2 == 0 or not has_marginal_x: + row_titles[row0] = ( + args['facet_row'] + + '=' + + str(trace._subplot_row_val) + ) + + if (args.get('facet_col', None) and + hasattr(trace, '_subplot_col_val')): + if col0 % 2 == 0 or not has_marginal_y: + column_titles[col0] = ( + args['facet_col'] + + '=' + + str(trace._subplot_col_val) + ) + + # Default row/column widths uniform + column_widths = [1.0] * ncols + row_heights = [1.0] * nrows + + # Build column_widths/row_heights + if subplot_type == 'xy': + if has_marginal_x: + if args["marginal_x"] == "histogram" or ( + "color" in args and args["color"] + ): + main_size = 0.74 + else: + main_size = 0.84 + + row_heights = [main_size, 1 - main_size] * (nrows // 2) + vertical_spacing = 0.01 + + # Add padding to the top of marginal + for r in range(1, nrows, 2): + for c in range(ncols): + if specs[r][c] is not None: + specs[r][c]['t'] = vertical_spacing * 2 + else: + vertical_spacing = 0.03 + + if has_marginal_y: + if args["marginal_y"] == "histogram" or ( + "color" in args and args["color"] + ): + main_size = 0.74 + else: + main_size = 0.84 + + column_widths = [main_size, 1 - main_size] * (ncols // 2) + horizontal_spacing = 0.005 + + # Add padding to the top of marginal + for r in range(nrows): + for c in range(1, ncols, 2): + if specs[r][c] is not None: + specs[r][c]['r'] = horizontal_spacing * 2 + else: + horizontal_spacing = 0.02 + else: + # Other subplot types: + # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None + # + # We can customize subplot spacing per type once we enable faceting + # for all plot types + vertical_spacing = 0.1 + horizontal_spacing = 0.1 + + # Create figure with subplots + fig = make_subplots(rows=nrows, cols=ncols, specs=specs, + shared_xaxes='all', shared_yaxes='all', + row_titles=row_titles, column_titles=column_titles, + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing, + row_heights=row_heights, column_widths=column_widths, + start_cell='bottom-left') + + # Remove explicit font size of row/col titles so template can take over + for annot in fig.layout.annotations: + annot.update(font=None) + + return fig From 82e191020fde93b29089cc19c0ba24efecb9dd1f Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Mon, 20 May 2019 18:18:12 -0400 Subject: [PATCH 2/5] Remove ExpressFigure --- plotly_express/__init__.py | 1 - plotly_express/_core.py | 19 ------------------- plotly_express/colors/_swatches.py | 2 +- 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/plotly_express/__init__.py b/plotly_express/__init__.py index cdee9b9..eee5ae2 100644 --- a/plotly_express/__init__.py +++ b/plotly_express/__init__.py @@ -32,7 +32,6 @@ ) from ._core import ( # noqa: F401 - ExpressFigure, set_mapbox_access_token, defaults, get_trendline_results, diff --git a/plotly_express/_core.py b/plotly_express/_core.py index a68cc18..c86344a 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -43,25 +43,6 @@ def set_mapbox_access_token(token): MAPBOX_TOKEN = token -class ExpressFigure(go.Figure): - offline_initialized = False - """ - Boolean that starts out `False` and is set to `True` the first time the - `_ipython_display_()` method is called (by a Jupyter environment), to indicate that - subsequent calls to that method that `plotly.offline.init_notebook_mode()` has been - called once and should not be called again. - """ - - def __init__(self, *args, **kwargs): - super(ExpressFigure, self).__init__(*args, **kwargs) - - def _ipython_display_(self): - if not ExpressFigure.offline_initialized: - init_notebook_mode() - ExpressFigure.offline_initialized = True - iplot(self, show_link=False, auto_play=False) - - def get_trendline_results(fig): """ Extracts fit statistics for trendlines (when applied to figures generated with diff --git a/plotly_express/colors/_swatches.py b/plotly_express/colors/_swatches.py index d5a4b71..a2417df 100644 --- a/plotly_express/colors/_swatches.py +++ b/plotly_express/colors/_swatches.py @@ -13,7 +13,7 @@ def _swatches(module_names, module_contents): if not (k.startswith("_") or k == "swatches") ] - return _core.ExpressFigure( + return go.Figure( data=[ go.Bar( orientation="h", From f1e4b89af6d8660b2d0a08cc2d42bee451bab296 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Mon, 20 May 2019 18:24:03 -0400 Subject: [PATCH 3/5] Fix marginal grid configuration when only one of x or y marginal is present --- plotly_express/_core.py | 42 +++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/plotly_express/_core.py b/plotly_express/_core.py index c86344a..757022c 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -329,27 +329,29 @@ def configure_cartesian_marginal_axes(args, fig, orders): set_cartesian_axis_opts(args, xaxis, 'x', orders) # Configure axis ticks on marginal subplots - for row in range(2, nrows+1, row_step): - fig.update_yaxes( - showticklabels=False, - showgrid=args["marginal_x"] == 'histogram', - row=row - ) - fig.update_xaxes( - showgrid=True, - row=row - ) + if args['marginal_x']: + for row in range(2, nrows+1, row_step): + fig.update_yaxes( + showticklabels=False, + showgrid=args["marginal_x"] == 'histogram', + row=row + ) + fig.update_xaxes( + showgrid=True, + row=row + ) - for col in range(2, ncols+1, col_step): - fig.update_xaxes( - showticklabels=False, - showgrid=args["marginal_y"] == 'histogram', - col=col - ) - fig.update_yaxes( - showgrid=True, - col=col - ) + if args['marginal_y']: + for col in range(2, ncols+1, col_step): + fig.update_xaxes( + showticklabels=False, + showgrid=args["marginal_y"] == 'histogram', + col=col + ) + fig.update_yaxes( + showgrid=True, + col=col + ) # Add axis titles to non-marginal subplots y_title = get_decorated_label(args, args['y'], 'y') From a313ce5b409cb1e9420e4fd738a503ce4511c510 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Tue, 21 May 2019 07:41:19 -0400 Subject: [PATCH 4/5] Black format files --- plotly_express/_core.py | 221 +++++++++++++-------------- plotly_express/colors/colorbrewer.py | 1 - 2 files changed, 106 insertions(+), 116 deletions(-) diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 757022c..021978d 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -60,11 +60,17 @@ def get_trendline_results(fig): Mapping = namedtuple( "Mapping", - ["show_in_trace_name", "grouper", "val_map", "sequence", - "updater", "variable", "facet"], + [ + "show_in_trace_name", + "grouper", + "val_map", + "sequence", + "updater", + "variable", + "facet", + ], ) -TraceSpec = namedtuple("TraceSpec", - ["constructor", "attrs", "trace_patch", "marginal"]) +TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) def get_label(args, column): @@ -318,67 +324,61 @@ def configure_cartesian_marginal_axes(args, fig, orders): row_step = 2 if args["marginal_x"] else 1 ncols = len(fig._grid_ref[0]) - col_step = 2 if args['marginal_y'] else 1 + col_step = 2 if args["marginal_y"] else 1 # Set y-axis titles and axis options in the left-most column for yaxis in fig.select_yaxes(col=1): - set_cartesian_axis_opts(args, yaxis, 'y', orders) + set_cartesian_axis_opts(args, yaxis, "y", orders) # Set x-axis titles and axis options in the bottom-most row for xaxis in fig.select_xaxes(row=1): - set_cartesian_axis_opts(args, xaxis, 'x', orders) + set_cartesian_axis_opts(args, xaxis, "x", orders) # Configure axis ticks on marginal subplots - if args['marginal_x']: - for row in range(2, nrows+1, row_step): + if args["marginal_x"]: + for row in range(2, nrows + 1, row_step): fig.update_yaxes( showticklabels=False, - showgrid=args["marginal_x"] == 'histogram', - row=row - ) - fig.update_xaxes( - showgrid=True, - row=row + showgrid=args["marginal_x"] == "histogram", + row=row, ) + fig.update_xaxes(showgrid=True, row=row) - if args['marginal_y']: - for col in range(2, ncols+1, col_step): + if args["marginal_y"]: + for col in range(2, ncols + 1, col_step): fig.update_xaxes( showticklabels=False, - showgrid=args["marginal_y"] == 'histogram', - col=col - ) - fig.update_yaxes( - showgrid=True, - col=col + showgrid=args["marginal_y"] == "histogram", + col=col, ) + fig.update_yaxes(showgrid=True, col=col) # Add axis titles to non-marginal subplots - y_title = get_decorated_label(args, args['y'], 'y') + y_title = get_decorated_label(args, args["y"], "y") for row in range(1, nrows + 1, row_step): fig.update_yaxes(title_text=y_title, row=row, col=1) - x_title = get_decorated_label(args, args['x'], 'x') + x_title = get_decorated_label(args, args["x"], "x") for col in range(1, ncols + 1, col_step): fig.update_xaxes(title_text=x_title, row=1, col=col) # Configure axis type across all x-axes - if 'log_x' in args and args['log_x']: - fig.update_xaxes(type='log') + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") # Configure axis type across all y-axes - if 'log_y' in args and args['log_y']: - fig.update_yaxes(type='log') + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") # Configure matching and axis type for marginal y-axes - matches_y = 'y' + str(ncols + 1) + matches_y = "y" + str(ncols + 1) if args["marginal_x"]: for row in range(2, nrows + 1, 2): fig.update_yaxes(matches=matches_y, type=None, row=row) if args["marginal_y"]: for col in range(2, ncols + 1, 2): - fig.update_xaxes(matches='x2', type=None, col=col) + fig.update_xaxes(matches="x2", type=None, col=col) def configure_cartesian_axes(args, fig, orders): @@ -389,36 +389,38 @@ def configure_cartesian_axes(args, fig, orders): return # Set y-axis titles and axis options in the left-most column - y_title = get_decorated_label(args, args['y'], 'y') + y_title = get_decorated_label(args, args["y"], "y") for yaxis in fig.select_yaxes(col=1): yaxis.update(title_text=y_title) - set_cartesian_axis_opts(args, yaxis, 'y', orders) + set_cartesian_axis_opts(args, yaxis, "y", orders) # Set x-axis titles and axis options in the bottom-most row - x_title = get_decorated_label(args, args['x'], 'x') + x_title = get_decorated_label(args, args["x"], "x") for xaxis in fig.select_xaxes(row=1): xaxis.update(title_text=x_title) - set_cartesian_axis_opts(args, xaxis, 'x', orders) + set_cartesian_axis_opts(args, xaxis, "x", orders) # Configure axis type across all x-axes - if 'log_x' in args and args['log_x']: - fig.update_xaxes(type='log') + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") # Configure axis type across all y-axes - if 'log_y' in args and args['log_y']: - fig.update_yaxes(type='log') + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") return fig.layout def configure_ternary_axes(args, fig, orders): - fig.update(layout=dict( - ternary=dict( - aaxis=dict(title=get_label(args, args["a"])), - baxis=dict(title=get_label(args, args["b"])), - caxis=dict(title=get_label(args, args["c"])), + fig.update( + layout=dict( + ternary=dict( + aaxis=dict(title=get_label(args, args["a"])), + baxis=dict(title=get_label(args, args["b"])), + caxis=dict(title=get_label(args, args["c"])), + ) ) - )) + ) def configure_polar_axes(args, fig, orders): @@ -470,26 +472,30 @@ def configure_3d_axes(args, fig, orders): def configure_mapbox(args, fig, orders): - fig.update(layout=dict( - mapbox=dict( - accesstoken=MAPBOX_TOKEN, - center=dict( - lat=args["data_frame"][args["lat"]].mean(), - lon=args["data_frame"][args["lon"]].mean(), - ), - zoom=args["zoom"], + fig.update( + layout=dict( + mapbox=dict( + accesstoken=MAPBOX_TOKEN, + center=dict( + lat=args["data_frame"][args["lat"]].mean(), + lon=args["data_frame"][args["lon"]].mean(), + ), + zoom=args["zoom"], + ) ) - )) + ) def configure_geo(args, fig, orders): - fig.update(layout=dict( - geo=dict( - center=args["center"], - scope=args["scope"], - projection=dict(type=args["projection"]), + fig.update( + layout=dict( + geo=dict( + center=args["center"], + scope=args["scope"], + projection=dict(type=args["projection"]), + ) ) - )) + ) def configure_animation_controls(args, constructor, fig): @@ -567,21 +573,21 @@ def make_trace_spec(args, constructor, attrs, trace_patch): constructor=go.Histogram, attrs=[letter], trace_patch=dict(opacity=0.5), - marginal=letter + marginal=letter, ) elif args["marginal_" + letter] == "violin": trace_spec = TraceSpec( constructor=go.Violin, attrs=[letter, "hover_name", "hover_data"], trace_patch=dict(scalegroup=letter), - marginal=letter + marginal=letter, ) elif args["marginal_" + letter] == "box": trace_spec = TraceSpec( constructor=go.Box, attrs=[letter, "hover_name", "hover_data"], trace_patch=dict(notched=True), - marginal=letter + marginal=letter, ) elif args["marginal_" + letter] == "rug": symbols = {"x": "line-ns-open", "y": "line-ew-open"} @@ -596,7 +602,7 @@ def make_trace_spec(args, constructor, attrs, trace_patch): hoveron="points", marker={"symbol": symbols[letter]}, ), - marginal=letter + marginal=letter, ) if "color" in attrs: if "marker" not in trace_spec.trace_patch: @@ -611,7 +617,7 @@ def make_trace_spec(args, constructor, attrs, trace_patch): constructor=go.Scatter, attrs=["trendline"], trace_patch=dict(mode="lines"), - marginal=None + marginal=None, ) if args["trendline_color_override"]: trace_spec.trace_patch["line"] = dict( @@ -831,8 +837,8 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): orders, sorted_group_names = get_orderings(args, grouper, grouped) - has_marginal_x = bool(args.get('marginal_x', False)) - has_marginal_y = bool(args.get('marginal_y', False)) + has_marginal_x = bool(args.get("marginal_x", False)) + has_marginal_y = bool(args.get("marginal_y", False)) subplot_type = _subplot_type_for_trace_type(constructor().type) @@ -915,7 +921,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): raise # Find row for trace, handling facet_row and marginal_x - if m.facet == 'row': + if m.facet == "row": if has_marginal_x: row = (m.val_map[val] - 1) * 2 + 2 else: @@ -925,7 +931,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): else: row = 1 - if trace_spec.marginal == 'x': + if trace_spec.marginal == "x": row -= 1 nrows = max(nrows, row) @@ -933,7 +939,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): trace._subplot_row = row # Find col for trace, handling facet_col and marginal_y - if m.facet == 'col': + if m.facet == "col": if has_marginal_y: col = (m.val_map[val] - 1) * 2 + 1 else: @@ -943,7 +949,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): else: col = 1 - if trace_spec.marginal == 'y': + if trace_spec.marginal == "y": col += 1 ncols = max(ncols, col) @@ -988,18 +994,12 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): layout_patch["legend"]["itemsizing"] = "constant" fig = init_figure( - args, - subplot_type, - frame_list, - ncols, - nrows, - has_marginal_x, - has_marginal_y, + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y ) # Position traces in subplots for frame in frame_list: - for trace in frame['data']: + for trace in frame["data"]: if isinstance(trace, go.Splom): # Special case that is not compatible with make_subplots continue @@ -1025,44 +1025,32 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): def init_figure( - args, - subplot_type, - frame_list, - ncols, - nrows, - has_marginal_x, - has_marginal_y, + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y ): # Build subplot specs specs = [[{}] * ncols for _ in range(nrows)] column_titles = [None] * ncols row_titles = [None] * nrows for frame in frame_list: - for trace in frame['data']: + for trace in frame["data"]: row0 = nrows - trace._subplot_row col0 = trace._subplot_col - 1 if isinstance(trace, go.Splom): # Splom not compatible with make_subplots, treat as domain - specs[row0][col0] = {'type': 'domain'} + specs[row0][col0] = {"type": "domain"} else: - specs[row0][col0] = {'type': trace.type} - if (args.get('facet_row', None) and - hasattr(trace, '_subplot_row_val')): + specs[row0][col0] = {"type": trace.type} + if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"): if row0 % 2 == 0 or not has_marginal_x: row_titles[row0] = ( - args['facet_row'] - + '=' - + str(trace._subplot_row_val) + args["facet_row"] + "=" + str(trace._subplot_row_val) ) - if (args.get('facet_col', None) and - hasattr(trace, '_subplot_col_val')): + if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"): if col0 % 2 == 0 or not has_marginal_y: column_titles[col0] = ( - args['facet_col'] - + '=' - + str(trace._subplot_col_val) + args["facet_col"] + "=" + str(trace._subplot_col_val) ) # Default row/column widths uniform @@ -1070,11 +1058,9 @@ def init_figure( row_heights = [1.0] * nrows # Build column_widths/row_heights - if subplot_type == 'xy': + if subplot_type == "xy": if has_marginal_x: - if args["marginal_x"] == "histogram" or ( - "color" in args and args["color"] - ): + if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: main_size = 0.84 @@ -1086,14 +1072,12 @@ def init_figure( for r in range(1, nrows, 2): for c in range(ncols): if specs[r][c] is not None: - specs[r][c]['t'] = vertical_spacing * 2 + specs[r][c]["t"] = vertical_spacing * 2 else: vertical_spacing = 0.03 if has_marginal_y: - if args["marginal_y"] == "histogram" or ( - "color" in args and args["color"] - ): + if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): main_size = 0.74 else: main_size = 0.84 @@ -1105,7 +1089,7 @@ def init_figure( for r in range(nrows): for c in range(1, ncols, 2): if specs[r][c] is not None: - specs[r][c]['r'] = horizontal_spacing * 2 + specs[r][c]["r"] = horizontal_spacing * 2 else: horizontal_spacing = 0.02 else: @@ -1118,13 +1102,20 @@ def init_figure( horizontal_spacing = 0.1 # Create figure with subplots - fig = make_subplots(rows=nrows, cols=ncols, specs=specs, - shared_xaxes='all', shared_yaxes='all', - row_titles=row_titles, column_titles=column_titles, - horizontal_spacing=horizontal_spacing, - vertical_spacing=vertical_spacing, - row_heights=row_heights, column_widths=column_widths, - start_cell='bottom-left') + fig = make_subplots( + rows=nrows, + cols=ncols, + specs=specs, + shared_xaxes="all", + shared_yaxes="all", + row_titles=row_titles, + column_titles=column_titles, + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing, + row_heights=row_heights, + column_widths=column_widths, + start_cell="bottom-left", + ) # Remove explicit font size of row/col titles so template can take over for annot in fig.layout.annotations: diff --git a/plotly_express/colors/colorbrewer.py b/plotly_express/colors/colorbrewer.py index e5205bd..fa2e2b9 100644 --- a/plotly_express/colors/colorbrewer.py +++ b/plotly_express/colors/colorbrewer.py @@ -456,4 +456,3 @@ def swatches(): "rgb(189,0,38)", "rgb(128,0,38)", ] - From 1f639bf68cf1b85d8d66c55366e99b6b98381a93 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Thu, 13 Jun 2019 13:09:53 -0400 Subject: [PATCH 5/5] Remove support for marginals inside facets. If faceting and marginal specified then faceting wins. It's still possible to facet columns and have marginals per columns or to facet rows and have marginals per row. --- plotly_express/_core.py | 101 ++++++++++++++++------------------------ 1 file changed, 41 insertions(+), 60 deletions(-) diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 459aab1..b5af750 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -329,10 +329,7 @@ def configure_cartesian_marginal_axes(args, fig, orders): fig.layout["barmode"] = "overlay" nrows = len(fig._grid_ref) - row_step = 2 if args["marginal_x"] else 1 - ncols = len(fig._grid_ref[0]) - col_step = 2 if args["marginal_y"] else 1 # Set y-axis titles and axis options in the left-most column for yaxis in fig.select_yaxes(col=1): @@ -344,30 +341,28 @@ def configure_cartesian_marginal_axes(args, fig, orders): # Configure axis ticks on marginal subplots if args["marginal_x"]: - for row in range(2, nrows + 1, row_step): - fig.update_yaxes( - showticklabels=False, - showgrid=args["marginal_x"] == "histogram", - row=row, - ) - fig.update_xaxes(showgrid=True, row=row) + fig.update_yaxes( + showticklabels=False, + showgrid=args["marginal_x"] == "histogram", + row=nrows, + ) + fig.update_xaxes(showgrid=True, row=nrows) if args["marginal_y"]: - for col in range(2, ncols + 1, col_step): - fig.update_xaxes( - showticklabels=False, - showgrid=args["marginal_y"] == "histogram", - col=col, - ) - fig.update_yaxes(showgrid=True, col=col) + fig.update_xaxes( + showticklabels=False, + showgrid=args["marginal_y"] == "histogram", + col=ncols, + ) + fig.update_yaxes(showgrid=True, col=ncols) # Add axis titles to non-marginal subplots y_title = get_decorated_label(args, args["y"], "y") - for row in range(1, nrows + 1, row_step): + for row in range(1, nrows): fig.update_yaxes(title_text=y_title, row=row, col=1) x_title = get_decorated_label(args, args["x"], "x") - for col in range(1, ncols + 1, col_step): + for col in range(1, ncols): fig.update_xaxes(title_text=x_title, row=1, col=col) # Configure axis type across all x-axes @@ -684,6 +679,13 @@ def apply_default_cascade(args): if args["color_discrete_sequence"] is None: args["color_discrete_sequence"] = qualitative.Plotly + # If both marginals and faceting are specified, faceting wins + if args.get('facet_col', None) and args.get('marginal_y', None): + args['marginal_y'] = None + + if args.get('facet_row', None) and args.get('marginal_x', None): + args['marginal_x'] = None + def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types @@ -831,6 +833,7 @@ def get_orderings(args, grouper, grouped): def make_figure(args, constructor, trace_patch={}, layout_patch={}): apply_default_cascade(args) + trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch ) @@ -927,17 +930,13 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Find row for trace, handling facet_row and marginal_x if m.facet == "row": - if has_marginal_x: - row = (m.val_map[val] - 1) * 2 + 2 - else: - row = m.val_map[val] - + row = m.val_map[val] trace._subplot_row_val = val else: - row = 1 - - if trace_spec.marginal == "x": - row -= 1 + if trace_spec.marginal == "x": + row = 2 + else: + row = 1 nrows = max(nrows, row) if row > 1: @@ -945,17 +944,13 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): # Find col for trace, handling facet_col and marginal_y if m.facet == "col": - if has_marginal_y: - col = (m.val_map[val] - 1) * 2 + 1 - else: - col = m.val_map[val] - + col = m.val_map[val] trace._subplot_col_val = val else: - col = 1 - - if trace_spec.marginal == "y": - col += 1 + if trace_spec.marginal == "y": + col = 2 + else: + col = 1 ncols = max(ncols, col) if col > 1: @@ -1021,7 +1016,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): trace, fig.layout, fig._grid_ref, - nrows - trace._subplot_row + 1, + trace._subplot_row, trace._subplot_col, ) @@ -1055,16 +1050,14 @@ def init_figure( else: specs[row0][col0] = {"type": trace.type} if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"): - if row0 % 2 == 0 or not has_marginal_x: - row_titles[row0] = ( - args["facet_row"] + "=" + str(trace._subplot_row_val) - ) + row_titles[row0] = ( + args["facet_row"] + "=" + str(trace._subplot_row_val) + ) if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"): - if col0 % 2 == 0 or not has_marginal_y: - column_titles[col0] = ( - args["facet_col"] + "=" + str(trace._subplot_col_val) - ) + column_titles[col0] = ( + args["facet_col"] + "=" + str(trace._subplot_col_val) + ) # Default row/column widths uniform column_widths = [1.0] * ncols @@ -1078,14 +1071,8 @@ def init_figure( else: main_size = 0.84 - row_heights = [main_size, 1 - main_size] * (nrows // 2) + row_heights = [main_size] * (nrows - 1) + [1 - main_size] vertical_spacing = 0.01 - - # Add padding to the top of marginal - for r in range(1, nrows, 2): - for c in range(ncols): - if specs[r][c] is not None: - specs[r][c]["t"] = vertical_spacing * 2 else: vertical_spacing = 0.03 @@ -1095,14 +1082,8 @@ def init_figure( else: main_size = 0.84 - column_widths = [main_size, 1 - main_size] * (ncols // 2) + column_widths = [main_size] * (ncols - 1) + [1 - main_size] horizontal_spacing = 0.005 - - # Add padding to the top of marginal - for r in range(nrows): - for c in range(1, ncols, 2): - if specs[r][c] is not None: - specs[r][c]["r"] = horizontal_spacing * 2 else: horizontal_spacing = 0.02 else: