Skip to content

Commit

Permalink
fix issue when coords for default dims are provided (#2001)
Browse files Browse the repository at this point in the history
* fix issue when coords for default dims are provided

* update changelog
  • Loading branch information
OriolAbril authored Mar 21, 2022
1 parent eadc170 commit 479f2a7
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* Update attribute handling for InferenceData ([1357](https://github.com/arviz-devs/arviz/pull/1357))
* Fix R2 implementation ([1666](https://github.com/arviz-devs/arviz/pull/1666))
* Added warning message in `plot_dist_comparison()` in case subplots go over the limit ([1688](https://github.com/arviz-devs/arviz/pull/1688))
* Fix coord value ignoring for default dims ([2001](https://github.com/arviz-devs/arviz/pull/2001))

### Deprecation

Expand Down
6 changes: 5 additions & 1 deletion arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,11 @@ def generate_dims_coords(
dim_name = dims[idx]
if dim_name not in coords:
coords[dim_name] = np.arange(index_origin, dim_len + index_origin)
coords = {key: coord for key, coord in coords.items() if any(key == dim for dim in dims)}
coords = {
key: coord
for key, coord in coords.items()
if any(key == dim for dim in dims + default_dims)
}
return dims, coords


Expand Down
20 changes: 20 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,26 @@ def test_dims_coords():
assert len(coords["x_dim_2"]) == 5


def test_dims_coords_default_dims():
shape = 4, 7
var_name = "x"
dims, coords = generate_dims_coords(
shape,
var_name,
dims=["dim1", "dim2"],
coords={"chain": ["a", "b", "c"]},
default_dims=["chain", "draw"],
)
assert "dim1" in dims
assert "dim2" in dims
assert "chain" not in dims
assert "draw" not in dims
assert len(coords["dim1"]) == 4
assert len(coords["dim2"]) == 7
assert len(coords["chain"]) == 3
assert "draw" not in coords


def test_dims_coords_extra_dims():
shape = 4, 20
var_name = "x"
Expand Down

0 comments on commit 479f2a7

Please sign in to comment.