Skip to content

Commit

Permalink
Sunburst improvements (#2133)
Browse files Browse the repository at this point in the history
* color column now appears in hover

* corrected bug: path column can be numeric
  • Loading branch information
emmanuelle authored Feb 4, 2020
1 parent 51fa1ee commit f7dc2be
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
11 changes: 9 additions & 2 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,7 @@ def build_dataframe(args, attrables, array_attrables):
def _check_dataframe_all_leaves(df):
df_sorted = df.sort_values(by=list(df.columns))
null_mask = df_sorted.isnull()
df_sorted = df_sorted.astype(str)
null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
for null_row_index in null_indices:
row = null_mask.iloc[null_row_index]
Expand Down Expand Up @@ -1055,8 +1056,9 @@ def process_dataframe_hierarchy(args):

if args["color"] and args["color"] in path:
series_to_copy = df[args["color"]]
args["color"] = str(args["color"]) + "additional_col_for_px"
df[args["color"]] = series_to_copy
new_col_name = args["color"] + "additional_col_for_color"
path = [new_col_name if x == args["color"] else x for x in path]
df[new_col_name] = series_to_copy
if args["hover_data"]:
for col_name in args["hover_data"]:
if col_name == args["color"]:
Expand Down Expand Up @@ -1160,6 +1162,11 @@ def aggfunc_continuous(x):
args["ids"] = "id"
args["names"] = "labels"
args["parents"] = "parent"
if args["color"]:
if not args["hover_data"]:
args["hover_data"] = [args["color"]]
else:
args["hover_data"].append(args["color"])
return args


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,22 @@ def test_sunburst_treemap_with_path_color():
# Hover info
df["hover"] = [el.lower() for el in vendors]
fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"])
custom = fig.data[0].customdata.ravel()
assert np.all(custom[:8] == df["hover"])
assert np.all(custom[8:] == "(?)")
custom = fig.data[0].customdata
assert np.all(custom[:8, 0] == df["hover"])
assert np.all(custom[8:, 0] == "(?)")
assert np.all(custom[:8, 1] == df["calls"])

# Discrete color
fig = px.sunburst(df, path=path, color="vendors")
assert len(np.unique(fig.data[0].marker.colors)) == 9

# Numerical column in path
df["regions"] = df["regions"].map({"North": 1, "South": 2})
path = ["total", "regions", "sectors", "vendors"]
fig = px.sunburst(df, path=path, values="values", color="calls")
colors = fig.data[0].marker.colors
assert np.all(np.array(colors[:8]) == np.array(calls))


def test_sunburst_treemap_with_path_non_rectangular():
vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None]
Expand Down

0 comments on commit f7dc2be

Please sign in to comment.