Skip to content

Commit

Permalink
adapt histograms to float/int (#2247)
Browse files Browse the repository at this point in the history
* adapt histograms to float/int

* changelog and black

* fix logic in bokeh

* define ticks from bin edges after calling hist

* black
  • Loading branch information
OriolAbril authored Jun 10, 2023
1 parent 61bda88 commit 88f7e2a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
### Maintenance and fixes

- Fixes for creating numpy object array ([2233](https://github.com/arviz-devs/arviz/pull/2233) and [2239](https://github.com/arviz-devs/arviz/pull/2239))
- Adapt histograms generated by plot_dist to input dtype ([2247](https://github.com/arviz-devs/arviz/pull/2247))

### Deprecation

Expand Down
3 changes: 2 additions & 1 deletion arviz/plots/backends/bokeh/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs, is_circular):
if hist_kwargs.pop("cumulative", False):
hist = np.cumsum(hist)
hist /= hist[-1]
if values.dtype.kind == "i":
edges = edges.astype(float) - 0.5

if is_circular:

Expand Down Expand Up @@ -174,7 +176,6 @@ def _histplot_bokeh_op(values, values2, rotated, ax, hist_kwargs, is_circular):
)

ax = set_bokeh_circular_ticks_labels(ax, hist, labels)

elif rotated:
ax.quad(top=edges[:-1], bottom=edges[1:], left=0, right=hist, **hist_kwargs)
else:
Expand Down
15 changes: 12 additions & 3 deletions arviz/plots/backends/matplotlib/distplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def plot_dist(
hist_kwargs.setdefault("color", color)
hist_kwargs.setdefault("label", label)
hist_kwargs.setdefault("rwidth", 0.9)
hist_kwargs.setdefault("align", "left")
hist_kwargs.setdefault("density", True)

if rotated:
Expand Down Expand Up @@ -151,12 +150,22 @@ def _histplot_mpl_op(values, values2, rotated, ax, hist_kwargs, is_circular):
if bins is None:
bins = get_bins(values)

if values.dtype.kind == "i":
hist_kwargs.setdefault("align", "left")
else:
hist_kwargs.setdefault("align", "mid")

n, bins, _ = ax.hist(np.asarray(values).flatten(), bins=bins, **hist_kwargs)

if values.dtype.kind == "i":
ticks = bins[:-1]
else:
ticks = (bins[1:] + bins[:-1]) / 2

if rotated:
ax.set_yticks(bins[:-1])
ax.set_yticks(ticks)
elif not is_circular:
ax.set_xticks(bins[:-1])
ax.set_xticks(ticks)

if is_circular:
ax.set_ylim(0, 1.5 * n.max())
Expand Down

0 comments on commit 88f7e2a

Please sign in to comment.