diff --git a/plotly_express/__init__.py b/plotly_express/__init__.py index 370ebf5..01d899f 100644 --- a/plotly_express/__init__.py +++ b/plotly_express/__init__.py @@ -34,7 +34,6 @@ ) from ._core import ( # noqa: F401 - ExpressFigure, set_mapbox_access_token, defaults, get_trendline_results, diff --git a/plotly_express/_core.py b/plotly_express/_core.py index 339734e..b5af750 100644 --- a/plotly_express/_core.py +++ b/plotly_express/_core.py @@ -1,3 +1,4 @@ +from _plotly_future_ import renderer_defaults, v4_subplots import plotly.graph_objs as go from plotly.offline import init_notebook_mode, iplot import plotly.io as pio @@ -6,6 +7,12 @@ import math import pandas +from plotly.subplots import ( + make_subplots, + _set_trace_grid_reference, + _subplot_type_for_trace_type, +) + class PxDefaults(object): def __init__(self): @@ -38,25 +45,6 @@ def set_mapbox_access_token(token): MAPBOX_TOKEN = token -class ExpressFigure(go.Figure): - offline_initialized = False - """ - Boolean that starts out `False` and is set to `True` the first time the - `_ipython_display_()` method is called (by a Jupyter environment), to indicate that - subsequent calls to that method that `plotly.offline.init_notebook_mode()` has been - called once and should not be called again. - """ - - def __init__(self, *args, **kwargs): - super(ExpressFigure, self).__init__(*args, **kwargs) - - def _ipython_display_(self): - if not ExpressFigure.offline_initialized: - init_notebook_mode() - ExpressFigure.offline_initialized = True - iplot(self, show_link=False, auto_play=False) - - def get_trendline_results(fig): """ Extracts fit statistics for trendlines (when applied to figures generated with @@ -74,9 +62,17 @@ def get_trendline_results(fig): Mapping = namedtuple( "Mapping", - ["show_in_trace_name", "grouper", "val_map", "sequence", "updater", "variable"], + [ + "show_in_trace_name", + "grouper", + "val_map", + "sequence", + "updater", + "variable", + "facet", + ], ) -TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch"]) +TraceSpec = namedtuple("TraceSpec", ["constructor", "attrs", "trace_patch", "marginal"]) def get_label(args, column): @@ -110,6 +106,7 @@ def make_mapping(args, variable): sequence=[""], variable=variable, updater=(lambda trace, v: v), + facet=None, ) if variable == "facet_row" or variable == "facet_col": letter = "x" if variable == "facet_col" else "y" @@ -118,8 +115,9 @@ def make_mapping(args, variable): variable=letter, grouper=args[variable], val_map={}, - sequence=[letter + str(i) for i in range(1, 1000)], - updater=lambda trace, v: trace.update({letter + "axis": v}), + sequence=[i for i in range(1, 1000)], + updater=(lambda trace, v: v), + facet="row" if variable == "facet_row" else "col", ) (parent, variable) = variable.split(".") vprefix = variable @@ -136,6 +134,7 @@ def make_mapping(args, variable): val_map=args[vprefix + "_map"].copy(), sequence=args[vprefix + "_sequence"], updater=lambda trace, v: trace.update({parent: {variable: v}}), + facet=None, ) @@ -282,7 +281,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): return result, fit_results -def configure_axes(args, constructor, fig, axes, orders): +def configure_axes(args, constructor, fig, orders): configurators = { go.Scatter: configure_cartesian_axes, go.Scattergl: configure_cartesian_axes, @@ -302,126 +301,132 @@ def configure_axes(args, constructor, fig, axes, orders): go.Choropleth: configure_geo, } if constructor in configurators: - fig.update(layout=configurators[constructor](args, fig, axes, orders)) + configurators[constructor](args, fig, orders) -def set_cartesian_axis_opts(args, layout, letter, axis, orders): +def set_cartesian_axis_opts(args, axis, letter, orders): log_key = "log_" + letter range_key = "range_" + letter if log_key in args and args[log_key]: - layout[axis]["type"] = "log" + axis["type"] = "log" if range_key in args and args[range_key]: - layout[axis]["range"] = [math.log(r, 10) for r in args[range_key]] + axis["range"] = [math.log(r, 10) for r in args[range_key]] elif range_key in args and args[range_key]: - layout[axis]["range"] = args[range_key] + axis["range"] = args[range_key] if args[letter] in orders: - layout[axis]["categoryorder"] = "array" - layout[axis]["categoryarray"] = ( + axis["categoryorder"] = "array" + axis["categoryarray"] = ( orders[args[letter]] - if axis.startswith("x") + if isinstance(axis, go.layout.XAxis) else list(reversed(orders[args[letter]])) ) -def configure_cartesian_marginal_axes(args, orders): - layout = dict() +def configure_cartesian_marginal_axes(args, fig, orders): + if "histogram" in [args["marginal_x"], args["marginal_y"]]: - layout["barmode"] = "overlay" - for letter in ["x", "y"]: - layout[letter + "axis1"] = dict( - title=get_decorated_label(args, args[letter], letter) + fig.layout["barmode"] = "overlay" + + nrows = len(fig._grid_ref) + ncols = len(fig._grid_ref[0]) + + # Set y-axis titles and axis options in the left-most column + for yaxis in fig.select_yaxes(col=1): + set_cartesian_axis_opts(args, yaxis, "y", orders) + + # Set x-axis titles and axis options in the bottom-most row + for xaxis in fig.select_xaxes(row=1): + set_cartesian_axis_opts(args, xaxis, "x", orders) + + # Configure axis ticks on marginal subplots + if args["marginal_x"]: + fig.update_yaxes( + showticklabels=False, + showgrid=args["marginal_x"] == "histogram", + row=nrows, ) - set_cartesian_axis_opts(args, layout, letter, letter + "axis1", orders) - for letter in ["x", "y"]: - otherletter = "x" if letter == "y" else "y" - if args["marginal_" + letter]: - if args["marginal_" + letter] == "histogram" or ( - "color" in args and args["color"] - ): - main_size = 0.74 - else: - main_size = 0.84 - layout[otherletter + "axis1"]["domain"] = [0, main_size] - layout[otherletter + "axis1"]["showgrid"] = True - layout[otherletter + "axis2"] = { - "domain": [main_size + 0.005, 1], - "showticklabels": False, - } - return layout + fig.update_xaxes(showgrid=True, row=nrows) + if args["marginal_y"]: + fig.update_xaxes( + showticklabels=False, + showgrid=args["marginal_y"] == "histogram", + col=ncols, + ) + fig.update_yaxes(showgrid=True, col=ncols) -def configure_cartesian_axes(args, fig, axes, orders): - if ("marginal_x" in args and args["marginal_x"]) or ( - "marginal_y" in args and args["marginal_y"] - ): - return configure_cartesian_marginal_axes(args, orders) - - gap = 0.1 - layout = { - "annotations": [], - "grid": { - "xaxes": [], - "yaxes": [], - "xgap": gap, - "ygap": gap, - "xside": "bottom", - "yside": "left", - }, - } + # Add axis titles to non-marginal subplots + y_title = get_decorated_label(args, args["y"], "y") + for row in range(1, nrows): + fig.update_yaxes(title_text=y_title, row=row, col=1) - for letter, direction, row in (("x", "facet_col", False), ("y", "facet_row", True)): - for letter_number in [t[letter + "axis"] for t in fig.data]: - if letter_number not in layout["grid"][letter + "axes"]: - layout["grid"][letter + "axes"].append(letter_number) - axis = letter_number.replace(letter, letter + "axis") + x_title = get_decorated_label(args, args["x"], "x") + for col in range(1, ncols): + fig.update_xaxes(title_text=x_title, row=1, col=col) - layout[axis] = dict( - title=get_decorated_label(args, args[letter], letter) - ) - if len(letter_number) == 1: - set_cartesian_axis_opts(args, layout, letter, axis, orders) - else: - layout[axis]["matches"] = letter - log_key = "log_" + letter - if log_key in args and args[log_key]: - layout[axis]["type"] = "log" - - if args[direction]: - step = 1.0 / (len(layout["grid"][letter + "axes"]) - gap) - for key, value in axes[letter].items(): - i = int(value[1:]) - if row: - i = len(layout["grid"][letter + "axes"]) - i - else: - i -= 1 - layout["annotations"].append( - { - "xref": "paper", - "yref": "paper", - "showarrow": False, - "xanchor": "center", - "yanchor": "middle", - "text": args[direction] + "=" + str(key), - "x": 1.01 if row else step * (i + (0.5 - gap / 2)), - "y": step * (i + (0.5 - gap / 2)) if row else 1.02, - "textangle": 90 if row else 0, - } - ) - return layout + # Configure axis type across all x-axes + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") + + # Configure axis type across all y-axes + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") + + # Configure matching and axis type for marginal y-axes + matches_y = "y" + str(ncols + 1) + if args["marginal_x"]: + for row in range(2, nrows + 1, 2): + fig.update_yaxes(matches=matches_y, type=None, row=row) + if args["marginal_y"]: + for col in range(2, ncols + 1, 2): + fig.update_xaxes(matches="x2", type=None, col=col) -def configure_ternary_axes(args, fig, axes, orders): - return dict( - ternary=dict( - aaxis=dict(title=get_label(args, args["a"])), - baxis=dict(title=get_label(args, args["b"])), - caxis=dict(title=get_label(args, args["c"])), + +def configure_cartesian_axes(args, fig, orders): + if ("marginal_x" in args and args["marginal_x"]) or ( + "marginal_y" in args and args["marginal_y"] + ): + configure_cartesian_marginal_axes(args, fig, orders) + return + + # Set y-axis titles and axis options in the left-most column + y_title = get_decorated_label(args, args["y"], "y") + for yaxis in fig.select_yaxes(col=1): + yaxis.update(title_text=y_title) + set_cartesian_axis_opts(args, yaxis, "y", orders) + + # Set x-axis titles and axis options in the bottom-most row + x_title = get_decorated_label(args, args["x"], "x") + for xaxis in fig.select_xaxes(row=1): + xaxis.update(title_text=x_title) + set_cartesian_axis_opts(args, xaxis, "x", orders) + + # Configure axis type across all x-axes + if "log_x" in args and args["log_x"]: + fig.update_xaxes(type="log") + + # Configure axis type across all y-axes + if "log_y" in args and args["log_y"]: + fig.update_yaxes(type="log") + + return fig.layout + + +def configure_ternary_axes(args, fig, orders): + fig.update( + layout=dict( + ternary=dict( + aaxis=dict(title=get_label(args, args["a"])), + baxis=dict(title=get_label(args, args["b"])), + caxis=dict(title=get_label(args, args["c"])), + ) ) ) -def configure_polar_axes(args, fig, axes, orders): +def configure_polar_axes(args, fig, orders): layout = dict( polar=dict( angularaxis=dict(direction=args["direction"], rotation=args["start_angle"]), @@ -442,10 +447,10 @@ def configure_polar_axes(args, fig, axes, orders): else: if args["range_r"]: radialaxis["range"] = args["range_r"] - return layout + fig.update(layout=layout) -def configure_3d_axes(args, fig, axes, orders): +def configure_3d_axes(args, fig, orders): layout = dict( scene=dict( xaxis=dict(title=get_label(args, args["x"])), @@ -466,28 +471,32 @@ def configure_3d_axes(args, fig, axes, orders): if args[letter] in orders: axis["categoryorder"] = "array" axis["categoryarray"] = orders[args[letter]] - return layout - - -def configure_mapbox(args, fig, axes, orders): - return dict( - mapbox=dict( - accesstoken=MAPBOX_TOKEN, - center=dict( - lat=args["data_frame"][args["lat"]].mean(), - lon=args["data_frame"][args["lon"]].mean(), - ), - zoom=args["zoom"], + fig.update(layout=layout) + + +def configure_mapbox(args, fig, orders): + fig.update( + layout=dict( + mapbox=dict( + accesstoken=MAPBOX_TOKEN, + center=dict( + lat=args["data_frame"][args["lat"]].mean(), + lon=args["data_frame"][args["lon"]].mean(), + ), + zoom=args["zoom"], + ) ) ) -def configure_geo(args, fig, axes, orders): - return dict( - geo=dict( - center=args["center"], - scope=args["scope"], - projection=dict(type=args["projection"]), +def configure_geo(args, fig, orders): + fig.update( + layout=dict( + geo=dict( + center=args["center"], + scope=args["scope"], + projection=dict(type=args["projection"]), + ) ) ) @@ -551,7 +560,10 @@ def frame_args(duration): def make_trace_spec(args, constructor, attrs, trace_patch): - result = [TraceSpec(constructor, attrs, trace_patch)] + # Create base trace specification + result = [TraceSpec(constructor, attrs, trace_patch, None)] + + # Add marginal trace specifications for letter in ["x", "y"]: if "marginal_" + letter in args and args["marginal_" + letter]: trace_spec = None @@ -564,18 +576,21 @@ def make_trace_spec(args, constructor, attrs, trace_patch): constructor=go.Histogram, attrs=[letter, "marginal_" + letter], trace_patch=dict(opacity=0.5, bingroup=letter, **axis_map), + marginal=letter, ) elif args["marginal_" + letter] == "violin": trace_spec = TraceSpec( constructor=go.Violin, attrs=[letter, "hover_name", "hover_data"], - trace_patch=dict(scalegroup=letter, **axis_map), + trace_patch=dict(scalegroup=letter), + marginal=letter, ) elif args["marginal_" + letter] == "box": trace_spec = TraceSpec( constructor=go.Box, attrs=[letter, "hover_name", "hover_data"], - trace_patch=dict(notched=True, **axis_map), + trace_patch=dict(notched=True), + marginal=letter, ) elif args["marginal_" + letter] == "rug": symbols = {"x": "line-ns-open", "y": "line-ew-open"} @@ -589,8 +604,8 @@ def make_trace_spec(args, constructor, attrs, trace_patch): jitter=0, hoveron="points", marker={"symbol": symbols[letter]}, - **axis_map ), + marginal=letter, ) if "color" in attrs or "color" not in args: if "marker" not in trace_spec.trace_patch: @@ -598,9 +613,14 @@ def make_trace_spec(args, constructor, attrs, trace_patch): first_default_color = args["color_continuous_scale"][0] trace_spec.trace_patch["marker"]["color"] = first_default_color result.append(trace_spec) + + # Add trendline trace specifications if "trendline" in args and args["trendline"]: trace_spec = TraceSpec( - constructor=go.Scatter, attrs=["trendline"], trace_patch=dict(mode="lines") + constructor=go.Scatter, + attrs=["trendline"], + trace_patch=dict(mode="lines"), + marginal=None, ) if args["trendline_color_override"]: trace_spec.trace_patch["line"] = dict( @@ -659,8 +679,16 @@ def apply_default_cascade(args): if args["color_discrete_sequence"] is None: args["color_discrete_sequence"] = qualitative.Plotly + # If both marginals and faceting are specified, faceting wins + if args.get('facet_col', None) and args.get('marginal_y', None): + args['marginal_y'] = None + + if args.get('facet_row', None) and args.get('marginal_x', None): + args['marginal_x'] = None + def infer_config(args, constructor, trace_patch): + # Declare all supported attributes, across all plot types attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size"] + ["dimensions", "hover_name", "hover_data", "text", "error_x", "error_x_minus"] @@ -670,6 +698,8 @@ def infer_config(args, constructor, trace_patch): array_attrables = ["dimensions", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + # Validate that the strings provided as attribute values reference columns + # in the provided data_frame df_columns = args["data_frame"].columns for attr in attrables + group_attrables + ["color"]: @@ -693,10 +723,12 @@ def infer_config(args, constructor, trace_patch): attrs = [k for k in attrables if k in args] grouped_attrs = [] + # Compute sizeref sizeref = 0 if "size" in args and args["size"]: sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 + # Compute color attributes and grouping attributes if "color" in args: if "color_continuous_scale" in args: if "color_discrete_sequence" not in args: @@ -716,12 +748,15 @@ def infer_config(args, constructor, trace_patch): show_colorbar = bool("color" in attrs and args["color"]) + # Compute line_dash grouping attribute if "line_dash" in args: grouped_attrs.append("line.dash") + # Compute symbol grouping attribute if "symbol" in args: grouped_attrs.append("marker.symbol") + # Compute final trace patch trace_patch = trace_patch.copy() if constructor == go.Histogram2d: @@ -744,17 +779,22 @@ def infer_config(args, constructor, trace_patch): if "line_shape" in args: trace_patch["line"] = dict(shape=args["line_shape"]) + # Compute marginal attribute if "marginal" in args: position = "marginal_x" if args["orientation"] == "v" else "marginal_y" other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y" args[position] = args["marginal"] args[other_position] = None + # Compute applicable grouping attributes for k in group_attrables: if k in args: grouped_attrs.append(k) + # Create grouped mappings grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] + + # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) return trace_specs, grouped_mappings, sizeref, show_colorbar @@ -793,6 +833,7 @@ def get_orderings(args, grouper, grouped): def make_figure(args, constructor, trace_patch={}, layout_patch={}): apply_default_cascade(args) + trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch ) @@ -801,9 +842,15 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): orders, sorted_group_names = get_orderings(args, grouper, grouped) + has_marginal_x = bool(args.get("marginal_x", False)) + has_marginal_y = bool(args.get("marginal_y", False)) + + subplot_type = _subplot_type_for_trace_type(constructor().type) + trace_names_by_frame = {} frames = OrderedDict() trendline_rows = [] + nrows = ncols = 1 for group_name in sorted_group_names: group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0]) mapping_labels = OrderedDict() @@ -854,6 +901,11 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): if trace_spec.constructor not in [go.Parcats, go.Parcoords]: trace.update(hoverlabel=dict(namelength=0)) trace_names.add(trace_name) + + # Init subplot row/col + trace._subplot_row = 1 + trace._subplot_col = 1 + for i, m in enumerate(grouped_mappings): val = group_name[i] if val not in m.val_map: @@ -875,6 +927,34 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): trace.update(marker=dict(color=m.val_map[val])) else: raise + + # Find row for trace, handling facet_row and marginal_x + if m.facet == "row": + row = m.val_map[val] + trace._subplot_row_val = val + else: + if trace_spec.marginal == "x": + row = 2 + else: + row = 1 + + nrows = max(nrows, row) + if row > 1: + trace._subplot_row = row + + # Find col for trace, handling facet_col and marginal_y + if m.facet == "col": + col = m.val_map[val] + trace._subplot_col_val = val + else: + if trace_spec.marginal == "y": + col = 2 + else: + col = 1 + + ncols = max(ncols, col) + if col > 1: + trace._subplot_col = col if ( trace_specs[0].constructor == go.Histogram2dContour and trace_spec.constructor == go.Box @@ -920,13 +1000,119 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): layout_patch["margin"] = {"t": 60} if "size" in args and args["size"]: layout_patch["legend"]["itemsizing"] = "constant" - fig = ExpressFigure( - data=frame_list[0]["data"] if len(frame_list) > 0 else [], - layout=layout_patch, - frames=frame_list if len(frames) > 1 else [], + + fig = init_figure( + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y ) + + # Position traces in subplots + for frame in frame_list: + for trace in frame["data"]: + if isinstance(trace, go.Splom): + # Special case that is not compatible with make_subplots + continue + + _set_trace_grid_reference( + trace, + fig.layout, + fig._grid_ref, + trace._subplot_row, + trace._subplot_col, + ) + + # Add traces, layout and frames to figure + fig.add_traces(frame_list[0]["data"] if len(frame_list) > 0 else []) + fig.layout.update(layout_patch) + fig.frames = frame_list if len(frames) > 1 else [] + fig._px_trendlines = pandas.DataFrame(trendline_rows) - axes = {m.variable: m.val_map for m in grouped_mappings} - configure_axes(args, constructor, fig, axes, orders) + + configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) return fig + + +def init_figure( + args, subplot_type, frame_list, ncols, nrows, has_marginal_x, has_marginal_y +): + # Build subplot specs + specs = [[{}] * ncols for _ in range(nrows)] + column_titles = [None] * ncols + row_titles = [None] * nrows + for frame in frame_list: + for trace in frame["data"]: + row0 = nrows - trace._subplot_row + col0 = trace._subplot_col - 1 + + if isinstance(trace, go.Splom): + # Splom not compatible with make_subplots, treat as domain + specs[row0][col0] = {"type": "domain"} + else: + specs[row0][col0] = {"type": trace.type} + if args.get("facet_row", None) and hasattr(trace, "_subplot_row_val"): + row_titles[row0] = ( + args["facet_row"] + "=" + str(trace._subplot_row_val) + ) + + if args.get("facet_col", None) and hasattr(trace, "_subplot_col_val"): + column_titles[col0] = ( + args["facet_col"] + "=" + str(trace._subplot_col_val) + ) + + # Default row/column widths uniform + column_widths = [1.0] * ncols + row_heights = [1.0] * nrows + + # Build column_widths/row_heights + if subplot_type == "xy": + if has_marginal_x: + if args["marginal_x"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + row_heights = [main_size] * (nrows - 1) + [1 - main_size] + vertical_spacing = 0.01 + else: + vertical_spacing = 0.03 + + if has_marginal_y: + if args["marginal_y"] == "histogram" or ("color" in args and args["color"]): + main_size = 0.74 + else: + main_size = 0.84 + + column_widths = [main_size] * (ncols - 1) + [1 - main_size] + horizontal_spacing = 0.005 + else: + horizontal_spacing = 0.02 + else: + # Other subplot types: + # 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None + # + # We can customize subplot spacing per type once we enable faceting + # for all plot types + vertical_spacing = 0.1 + horizontal_spacing = 0.1 + + # Create figure with subplots + fig = make_subplots( + rows=nrows, + cols=ncols, + specs=specs, + shared_xaxes="all", + shared_yaxes="all", + row_titles=row_titles, + column_titles=column_titles, + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing, + row_heights=row_heights, + column_widths=column_widths, + start_cell="bottom-left", + ) + + # Remove explicit font size of row/col titles so template can take over + for annot in fig.layout.annotations: + annot.update(font=None) + + return fig diff --git a/plotly_express/colors/_swatches.py b/plotly_express/colors/_swatches.py index d5a4b71..a2417df 100644 --- a/plotly_express/colors/_swatches.py +++ b/plotly_express/colors/_swatches.py @@ -13,7 +13,7 @@ def _swatches(module_names, module_contents): if not (k.startswith("_") or k == "swatches") ] - return _core.ExpressFigure( + return go.Figure( data=[ go.Bar( orientation="h", diff --git a/plotly_express/colors/colorbrewer.py b/plotly_express/colors/colorbrewer.py index e5205bd..fa2e2b9 100644 --- a/plotly_express/colors/colorbrewer.py +++ b/plotly_express/colors/colorbrewer.py @@ -456,4 +456,3 @@ def swatches(): "rgb(189,0,38)", "rgb(128,0,38)", ] -