From ad6a9a04ff56160f5da4aaa0c7c87f4e202b3039 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Thu, 29 Jul 2021 16:24:58 -0400 Subject: [PATCH 01/37] adding keyword args --- arviz/data/inference_data.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index ed1dd3f301..b4680c2f9f 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -314,7 +314,7 @@ def items(self) -> "InferenceData.InferenceDataItemsView": return InferenceData.InferenceDataItemsView(self) @staticmethod - def from_netcdf(filename: str) -> "InferenceData": + def from_netcdf(filename: str, **kwargs) -> "InferenceData": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. @@ -326,6 +326,8 @@ def from_netcdf(filename: str) -> "InferenceData": ---------- filename : str location of netcdf file + kwargs : + Keyword arguments to be passed into xarray.open_dataset Returns ------- @@ -337,7 +339,7 @@ def from_netcdf(filename: str) -> "InferenceData": data_groups = list(data.groups) for group in data_groups: - with xr.open_dataset(filename, group=group) as data: + with xr.open_dataset(filename, group=group, **kwargs) as data: if rcParams["data.load"] == "eager": groups[group] = data.load() else: From 95c9a340046b0041a76ac71f5b2a3d078be3d75f Mon Sep 17 00:00:00 2001 From: mortonjt Date: Thu, 29 Jul 2021 20:09:55 -0400 Subject: [PATCH 02/37] pushing intermediate commits to pass group_kwargs to load_arviz_data --- arviz/data/datasets.py | 4 ++-- arviz/data/inference_data.py | 13 +++++++++---- arviz/data/io_netcdf.py | 4 ++-- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py index cf72abd06e..6ed6319e14 100644 --- a/arviz/data/datasets.py +++ b/arviz/data/datasets.py @@ -199,7 +199,7 @@ def _sha256(path): return sha256hash.hexdigest() -def load_arviz_data(dataset=None, data_home=None): +def load_arviz_data(dataset=None, data_home=None, group_kwargs=None): """Load a local or remote pre-made dataset. Run with no parameters to get a list of all available models. @@ -245,7 +245,7 @@ def load_arviz_data(dataset=None, data_home=None): "file may be corrupted. Run `arviz.clear_data_home()` and try " "again, or please open an issue.".format(file_path, checksum, remote.checksum) ) - return from_netcdf(file_path) + return from_netcdf(file_path, group_kwargs) else: if dataset is None: return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index b4680c2f9f..e30ceacff1 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -314,7 +314,7 @@ def items(self) -> "InferenceData.InferenceDataItemsView": return InferenceData.InferenceDataItemsView(self) @staticmethod - def from_netcdf(filename: str, **kwargs) -> "InferenceData": + def from_netcdf(filename: str, group_kwargs: dict = None) -> "InferenceData": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. @@ -326,20 +326,25 @@ def from_netcdf(filename: str, **kwargs) -> "InferenceData": ---------- filename : str location of netcdf file - kwargs : - Keyword arguments to be passed into xarray.open_dataset + group_kwargs : dict of dict + Keyword arguments to be passed into each call of `xarray.open_dataset`. 
Returns ------- InferenceData object """ groups = {} + try: with nc.Dataset(filename, mode="r") as data: data_groups = list(data.groups) for group in data_groups: - with xr.open_dataset(filename, group=group, **kwargs) as data: + if group_kwargs is not None and group in group_kwargs: + group_kws = group_kwargs[group] + else: + group_kws = {} + with xr.open_dataset(filename, group=group, **group_kws) as data: if rcParams["data.load"] == "eager": groups[group] = data.load() else: diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py index 4216a2ab2e..12e1dc701b 100644 --- a/arviz/data/io_netcdf.py +++ b/arviz/data/io_netcdf.py @@ -4,7 +4,7 @@ from .inference_data import InferenceData -def from_netcdf(filename): +def from_netcdf(filename, group_kwargs: dict = None): """Load netcdf file back into an arviz.InferenceData. Parameters @@ -22,7 +22,7 @@ def from_netcdf(filename): of loaded into memory. This behaviour is regulated by the value of ``az.rcParams["data.load"]``. """ - return InferenceData.from_netcdf(filename) + return InferenceData.from_netcdf(filename, group_kwargs) def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None): From ffcafdf40421fce51ca484fd15d2e35f7889ceb3 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 22 Aug 2021 19:55:08 -0400 Subject: [PATCH 03/37] adding regex and unittest --- arviz/data/datasets.py | 4 ++-- arviz/data/inference_data.py | 22 +++++++++++++++++----- arviz/data/io_netcdf.py | 8 ++++++-- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py index 6ed6319e14..2c54540d81 100644 --- a/arviz/data/datasets.py +++ b/arviz/data/datasets.py @@ -199,7 +199,7 @@ def _sha256(path): return sha256hash.hexdigest() -def load_arviz_data(dataset=None, data_home=None, group_kwargs=None): +def load_arviz_data(dataset=None, data_home=None, group_kwargs=None, regex=False): """Load a local or remote pre-made dataset. Run with no parameters to get a list of all available models. @@ -245,7 +245,7 @@ def load_arviz_data(dataset=None, data_home=None, group_kwargs=None): "file may be corrupted. Run `arviz.clear_data_home()` and try " "again, or please open an issue.".format(file_path, checksum, remote.checksum) ) - return from_netcdf(file_path, group_kwargs) + return from_netcdf(file_path, group_kwargs, regex) else: if dataset is None: return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index e30ceacff1..e13219c39a 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -9,6 +9,7 @@ from copy import deepcopy from datetime import datetime from html import escape +import re from typing import ( TYPE_CHECKING, Any, @@ -314,7 +315,8 @@ def items(self) -> "InferenceData.InferenceDataItemsView": return InferenceData.InferenceDataItemsView(self) @staticmethod - def from_netcdf(filename: str, group_kwargs: dict = None) -> "InferenceData": + def from_netcdf(filename: str, group_kwargs: dict = None, + regex: bool = False) -> "InferenceData": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. @@ -328,6 +330,8 @@ def from_netcdf(filename: str, group_kwargs: dict = None) -> "InferenceData": location of netcdf file group_kwargs : dict of dict Keyword arguments to be passed into each call of `xarray.open_dataset`. + regex : str + Specifies where regex search should be used to extend the keyword arguments. 
Returns ------- @@ -340,10 +344,18 @@ def from_netcdf(filename: str, group_kwargs: dict = None) -> "InferenceData": data_groups = list(data.groups) for group in data_groups: - if group_kwargs is not None and group in group_kwargs: - group_kws = group_kwargs[group] - else: - group_kws = {} + # if group_kwargs is not None and group in group_kwargs: + # group_kws = group_kwargs[group] + # else: + # group_kws = {} + + group_kws = {} + if group_kwargs is not None and regex is False: + group_kws = group_kwargs.get(group, {}) + if group_kwargs is not None and regex is True: + for key, kws in group_kwargs.items(): + if re.search(key, group): + group_kws = kws with xr.open_dataset(filename, group=group, **group_kws) as data: if rcParams["data.load"] == "eager": groups[group] = data.load() diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py index 12e1dc701b..69d2ce137a 100644 --- a/arviz/data/io_netcdf.py +++ b/arviz/data/io_netcdf.py @@ -4,13 +4,17 @@ from .inference_data import InferenceData -def from_netcdf(filename, group_kwargs: dict = None): +def from_netcdf(filename, group_kwargs: dict = None, regex=False): """Load netcdf file back into an arviz.InferenceData. Parameters ---------- filename : str name or path of the file to load trace + group_kwargs : dict of dict + Keyword arguments to be passed into each call of `xarray.open_dataset`. + regex : str + Specifies where regex search should be used to extend the keyword arguments. Returns ------- @@ -22,7 +26,7 @@ def from_netcdf(filename, group_kwargs: dict = None): of loaded into memory. This behaviour is regulated by the value of ``az.rcParams["data.load"]``. """ - return InferenceData.from_netcdf(filename, group_kwargs) + return InferenceData.from_netcdf(filename, group_kwargs, regex) def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None): From 69c4743f5321b06f4da26ba230a04fda773f6491 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 22 Aug 2021 20:18:54 -0400 Subject: [PATCH 04/37] adding unittests --- arviz/tests/base_tests/test_data.py | 31 +++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index 1c0809284e..f2f8166901 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -1271,6 +1271,37 @@ def test_empty_inference_data_object(self): os.remove(filepath) assert not os.path.exists(filepath) + def test_dask_chunk_group_kwds(self): + from dask.distributed import Client + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + group_kwargs = { + 'posterior': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}}, + 'posterior_predictive': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0': 2}} + } + centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (2, 2, 2, 2, 2)), + ('w_dim_0', (10,))] + self.assertListEqual(list(centered_data.chunks.items()), exp) + client.close() + + def test_dask_chunk_group_regex(self): + from dask.distributed import Client + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + group_kwargs = { + "posterior.*": {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}} + } + centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs) + exp = [('chain', 
(4,)), + ('draw', (500,)), + ('true_w_dim_0', (2, 2, 2, 2, 2)), + ('w_dim_0', (10,))] + self.assertListEqual(list(centered_data.chunks.items()), exp) + client.close() + class TestJSON: def test_json_converters(self, models): From c41ac0af3a36c76f525986a4c17f207ff4db5451 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 22 Aug 2021 20:28:03 -0400 Subject: [PATCH 05/37] updating changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 808aa92733..281205869d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ * Added ability to plot HDI contours to `plot_kde` with the new `hdi_probs` parameter. ([1665](https://github.com/arviz-devs/arviz/pull/1665)) * Add dtype parsing and setting in all Stan converters ([1632](https://github.com/arviz-devs/arviz/pull/1632)) * Add option to specify colors for each element in ppc_plot ([1769](https://github.com/arviz-devs/arviz/pull/1769)) +* Enable dask chunking information to be passed to `InferenceData.from_netcdf` ([1749](https://github.com/arviz-devs/arviz/pull/1749)) ### Maintenance and fixes * Fix conversion for numpyro models with ImproperUniform latent sites ([1713](https://github.com/arviz-devs/arviz/pull/1713)) From 8c5616c217e15178db7a9fd9de67a66fe184e4ab Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 23 Aug 2021 13:31:27 -0600 Subject: [PATCH 06/37] Update arviz/data/inference_data.py Co-authored-by: Oriol Abril-Pla --- arviz/data/inference_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index e13219c39a..399180e946 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -328,8 +328,8 @@ def from_netcdf(filename: str, group_kwargs: dict = None, ---------- filename : str location of netcdf file - group_kwargs : dict of dict - Keyword arguments to be passed into each call of `xarray.open_dataset`. + group_kwargs : dict of {str: dict}, optional + Keyword arguments to be passed into each call of {func}`xarray.open_dataset`. regex : str Specifies where regex search should be used to extend the keyword arguments. From 8d547c86b49829fe19f9cf3b4f36c0ae63a833db Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Mon, 23 Aug 2021 13:31:44 -0600 Subject: [PATCH 07/37] Update arviz/data/inference_data.py Co-authored-by: Oriol Abril-Pla --- arviz/data/inference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 399180e946..c65fba4761 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -330,7 +330,7 @@ def from_netcdf(filename: str, group_kwargs: dict = None, location of netcdf file group_kwargs : dict of {str: dict}, optional Keyword arguments to be passed into each call of {func}`xarray.open_dataset`. - regex : str + regex : bool, default False Specifies where regex search should be used to extend the keyword arguments. 
Returns From b84b057cae2443f734cbebb46a86d5e7643e23b3 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 23 Aug 2021 15:41:08 -0400 Subject: [PATCH 08/37] docstrings and black --- arviz/data/datasets.py | 15 ++++++++++++-- arviz/data/inference_data.py | 5 +++-- arviz/tests/base_tests/test_data.py | 31 ----------------------------- 3 files changed, 16 insertions(+), 35 deletions(-) diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py index 2c54540d81..943945c7d0 100644 --- a/arviz/data/datasets.py +++ b/arviz/data/datasets.py @@ -199,7 +199,7 @@ def _sha256(path): return sha256hash.hexdigest() -def load_arviz_data(dataset=None, data_home=None, group_kwargs=None, regex=False): +def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs): """Load a local or remote pre-made dataset. Run with no parameters to get a list of all available models. @@ -218,9 +218,20 @@ def load_arviz_data(dataset=None, data_home=None, group_kwargs=None, regex=False data_home : str, optional Where to save remote datasets + regex : bool, optional + Specifies regex support for chunking information in + `arviz.io_netcdf.from_netcdf`. This feature is currently experimental. + See :meth:`arviz.io_netcdf.from_netcdf` + + **kwargs : dict of {str: dict}, optional + Keyword arguments to be passed into arviz.io_netcdf.from_netcdf`. + This feature is currently experimental. + See :meth:`arviz.io_netcdf.from_netcdf` + Returns ------- xarray.Dataset + """ if dataset in LOCAL_DATASETS: resource = LOCAL_DATASETS[dataset] @@ -245,7 +256,7 @@ def load_arviz_data(dataset=None, data_home=None, group_kwargs=None, regex=False "file may be corrupted. Run `arviz.clear_data_home()` and try " "again, or please open an issue.".format(file_path, checksum, remote.checksum) ) - return from_netcdf(file_path, group_kwargs, regex) + return from_netcdf(file_path, regex, group_kwargs=kwargs) else: if dataset is None: return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index c65fba4761..4f8d92720c 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -315,8 +315,9 @@ def items(self) -> "InferenceData.InferenceDataItemsView": return InferenceData.InferenceDataItemsView(self) @staticmethod - def from_netcdf(filename: str, group_kwargs: dict = None, - regex: bool = False) -> "InferenceData": + def from_netcdf( + filename: str, group_kwargs: dict = None, regex: bool = False + ) -> "InferenceData": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. 
diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index f2f8166901..1c0809284e 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -1271,37 +1271,6 @@ def test_empty_inference_data_object(self): os.remove(filepath) assert not os.path.exists(filepath) - def test_dask_chunk_group_kwds(self): - from dask.distributed import Client - Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") - group_kwargs = { - 'posterior': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}}, - 'posterior_predictive': {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0': 2}} - } - centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (10,))] - self.assertListEqual(list(centered_data.chunks.items()), exp) - client.close() - - def test_dask_chunk_group_regex(self): - from dask.distributed import Client - Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") - group_kwargs = { - "posterior.*": {'chunks': {'true_w_dim_0': 2, 'true_w_dim_0' : 2}} - } - centered_data = az.load_arviz_data("regression10d", group_kwargs=group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (10,))] - self.assertListEqual(list(centered_data.chunks.items()), exp) - client.close() - class TestJSON: def test_json_converters(self, models): From 1cdef7a4d6eec827fbc3b347df208b83f8a95f89 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 23 Aug 2021 15:49:22 -0400 Subject: [PATCH 09/37] fixing tests --- arviz/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py index 943945c7d0..9ef5bd3563 100644 --- a/arviz/data/datasets.py +++ b/arviz/data/datasets.py @@ -256,7 +256,7 @@ def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs): "file may be corrupted. 
Run `arviz.clear_data_home()` and try " "again, or please open an issue.".format(file_path, checksum, remote.checksum) ) - return from_netcdf(file_path, regex, group_kwargs=kwargs) + return from_netcdf(file_path, kwargs, regex) else: if dataset is None: return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items())) From 80e092c99f4f6b75cdaeae04cf71b75dcbf4fcb2 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 23 Aug 2021 15:49:37 -0400 Subject: [PATCH 10/37] adding dask test --- arviz/tests/base_tests/test_data_dask.py | 44 ++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 arviz/tests/base_tests/test_data_dask.py diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py new file mode 100644 index 0000000000..1ff3b403e3 --- /dev/null +++ b/arviz/tests/base_tests/test_data_dask.py @@ -0,0 +1,44 @@ +# pylint: disable=redefined-outer-name, no-member +import importlib +import arviz as az +from arviz.utils import Dask +import pytest + + +pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name + (importlib.util.find_spec("dask") is None) and not running_on_ci(), + reason="test requires dask which is not installed", +) + +class TestDataDask: + + def test_dask_chunk_group_kwds(self): + from dask.distributed import Client + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + group_kwargs = { + 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, + 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} + } + centered_data = az.load_arviz_data("regression10d", **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (2, 2, 2, 2, 2)), + ('w_dim_0', (10,))] + self.assertListEqual(list(centered_data.chunks.items()), exp) + client.close() + + def test_dask_chunk_group_regex(self): + from dask.distributed import Client + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + group_kwargs = { + "posterior.*": {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}} + } + centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (2, 2, 2, 2, 2)), + ('w_dim_0', (10,))] + self.assertListEqual(list(centered_data.chunks.items()), exp) + client.close() From 6fd613ea521d4d6ce91fcff7c6cf927e33ba33b5 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 23 Aug 2021 15:51:43 -0400 Subject: [PATCH 11/37] adding more experimental warnings --- arviz/data/inference_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 4f8d92720c..47ad81b65a 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -331,8 +331,10 @@ def from_netcdf( location of netcdf file group_kwargs : dict of {str: dict}, optional Keyword arguments to be passed into each call of {func}`xarray.open_dataset`. + This feature is currently experimental. regex : bool, default False Specifies where regex search should be used to extend the keyword arguments. + This feature is currently experimental. 
Returns ------- From 6384959557e3399c5fbc40659407df6efdb46216 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 24 Aug 2021 13:49:39 -0400 Subject: [PATCH 12/37] adding dask distributed dep --- requirements-optional.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-optional.txt b/requirements-optional.txt index ec7ef8ddd6..c7bb9e3390 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -2,4 +2,5 @@ numba bokeh>=1.4.0 ujson dask -zarr>=2.5.0 \ No newline at end of file +distributed +zarr>=2.5.0 From 9e99e350bc3075407719f26bfc808a3595d075f7 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 24 Aug 2021 16:07:46 -0400 Subject: [PATCH 13/37] fixing list equality --- arviz/tests/base_tests/test_data_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 1ff3b403e3..2b17e93d73 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -25,7 +25,7 @@ def test_dask_chunk_group_kwds(self): ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), ('w_dim_0', (10,))] - self.assertListEqual(list(centered_data.chunks.items()), exp) + assert list(centered_data.chunks.items()) == exp client.close() def test_dask_chunk_group_regex(self): @@ -40,5 +40,5 @@ def test_dask_chunk_group_regex(self): ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), ('w_dim_0', (10,))] - self.assertListEqual(list(centered_data.chunks.items()), exp) + assert list(centered_data.chunks.items()) == exp client.close() From 84844c33cfc0eb48a191c058113fa632ee1fd01c Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 24 Aug 2021 16:53:21 -0400 Subject: [PATCH 14/37] requirements --- requirements-optional.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-optional.txt b/requirements-optional.txt index c7bb9e3390..70f60c790d 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -2,5 +2,5 @@ numba bokeh>=1.4.0 ujson dask -distributed +"dask[distributed]" zarr>=2.5.0 From 41102ceb527bf13b1c084eb0a8c65ce0340e3c83 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 24 Aug 2021 16:57:06 -0400 Subject: [PATCH 15/37] remove quotes --- requirements-optional.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-optional.txt b/requirements-optional.txt index 70f60c790d..2bb206a62d 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -2,5 +2,5 @@ numba bokeh>=1.4.0 ujson dask -"dask[distributed]" +dask[distributed] zarr>=2.5.0 From 9ed77b40fa17e8ca7db96d0713e3e36d301ab4a8 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Fri, 27 Aug 2021 20:18:01 -0400 Subject: [PATCH 16/37] checking chunks on the wrong object :/ --- arviz/tests/base_tests/test_data_dask.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 2b17e93d73..b9d6fe90a1 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -24,8 +24,8 @@ def test_dask_chunk_group_kwds(self): exp = [('chain', (4,)), ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (10,))] - assert list(centered_data.chunks.items()) == exp + ('w_dim_0', (2, 2, 2, 2, 2))] + assert list(centered_data.posterior.chunks.items()) == exp client.close() def test_dask_chunk_group_regex(self): @@ -39,6 +39,6 @@ def test_dask_chunk_group_regex(self): exp 
= [('chain', (4,)), ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (10,))] - assert list(centered_data.chunks.items()) == exp + ('w_dim_0', (2, 2, 2, 2, 2))] + assert list(centered_data.posterior.chunks.items()) == exp client.close() From 40407d5f3e4804e07b59ba886d425deebfab6105 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sat, 28 Aug 2021 15:33:13 -0400 Subject: [PATCH 17/37] bump --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 281205869d..a8e744d97b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ * Add dtype parsing and setting in all Stan converters ([1632](https://github.com/arviz-devs/arviz/pull/1632)) * Add option to specify colors for each element in ppc_plot ([1769](https://github.com/arviz-devs/arviz/pull/1769)) * Enable dask chunking information to be passed to `InferenceData.from_netcdf` ([1749](https://github.com/arviz-devs/arviz/pull/1749)) +* Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support ([1749](https://github.com/arviz-devs/arviz/pull/1749)) ### Maintenance and fixes * Fix conversion for numpyro models with ImproperUniform latent sites ([1713](https://github.com/arviz-devs/arviz/pull/1713)) From ed433cf4ccabdcf1cf8fae50f99d17073ddcd475 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sat, 28 Aug 2021 16:23:17 -0400 Subject: [PATCH 18/37] bump debug, not sure why the results differ --- arviz/tests/base_tests/test_data_dask.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index b9d6fe90a1..1e5a5016f8 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -25,6 +25,7 @@ def test_dask_chunk_group_kwds(self): ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), ('w_dim_0', (2, 2, 2, 2, 2))] + print(list(centered_data.posterior.chunks.items())) assert list(centered_data.posterior.chunks.items()) == exp client.close() @@ -40,5 +41,6 @@ def test_dask_chunk_group_regex(self): ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2)), ('w_dim_0', (2, 2, 2, 2, 2))] + print(list(centered_data.posterior.chunks.items())) assert list(centered_data.posterior.chunks.items()) == exp client.close() From f7f9177f6595de162b488c916b84d441d2013b63 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sat, 28 Aug 2021 16:26:57 -0400 Subject: [PATCH 19/37] simplifying test with only 1 chunk --- arviz/tests/base_tests/test_data_dask.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 1e5a5016f8..97a1939906 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -15,7 +15,7 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): from dask.distributed import Client Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") group_kwargs = { 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} @@ -23,8 +23,8 @@ def test_dask_chunk_group_kwds(self): centered_data = az.load_arviz_data("regression10d", **group_kwargs) exp = [('chain', (4,)), ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (2, 2, 2, 2, 2))] + 
('true_w_dim_0', (10)), + ('w_dim_0', (10))] print(list(centered_data.posterior.chunks.items())) assert list(centered_data.posterior.chunks.items()) == exp client.close() @@ -32,15 +32,14 @@ def test_dask_chunk_group_kwds(self): def test_dask_chunk_group_regex(self): from dask.distributed import Client Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=4, n_workers=2, memory_limit="2GB") + client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") group_kwargs = { - "posterior.*": {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}} + "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} } centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) exp = [('chain', (4,)), ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2)), - ('w_dim_0', (2, 2, 2, 2, 2))] - print(list(centered_data.posterior.chunks.items())) + ('true_w_dim_0', (10)), + ('w_dim_0', (10))] assert list(centered_data.posterior.chunks.items()) == exp client.close() From 4f9aed8fe0aa4e703865801ec9790e16321bced6 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 29 Aug 2021 18:04:21 -0400 Subject: [PATCH 20/37] getting rid of dask client following xarray dask tests --- arviz/tests/base_tests/test_data_dask.py | 51 ++++++++++++------------ 1 file changed, 25 insertions(+), 26 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 97a1939906..38a85ba550 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -13,33 +13,32 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): - from dask.distributed import Client + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") - group_kwargs = { - 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, - 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} - } - centered_data = az.load_arviz_data("regression10d", **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (10)), - ('w_dim_0', (10))] - print(list(centered_data.posterior.chunks.items())) - assert list(centered_data.posterior.chunks.items()) == exp - client.close() + with dask.config.set(scheduler="synchronous") + group_kwargs = { + 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, + 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} + } + centered_data = az.load_arviz_data("regression10d", **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (10)), + ('w_dim_0', (10))] + print(list(centered_data.posterior.chunks.items())) + assert list(centered_data.posterior.chunks.items()) == exp + def test_dask_chunk_group_regex(self): - from dask.distributed import Client + with dask.config.set(scheduler="synchronous") Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") - group_kwargs = { - "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} - } - centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (10)), - ('w_dim_0', (10))] - assert list(centered_data.posterior.chunks.items()) == exp - client.close() + client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") + group_kwargs = { + "posterior.*": {'chunks': 
{'w_dim_0': 10, 'true_w_dim_0' : 10}} + } + centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (10)), + ('w_dim_0', (10))] + assert list(centered_data.posterior.chunks.items()) == exp From 64f906656e3abd6eb48158de8fc494d618f3794a Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 29 Aug 2021 18:09:48 -0400 Subject: [PATCH 21/37] premature commit. tests now passing locally --- arviz/tests/base_tests/test_data_dask.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 38a85ba550..876737310e 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -1,5 +1,6 @@ # pylint: disable=redefined-outer-name, no-member import importlib +import dask import arviz as az from arviz.utils import Dask import pytest @@ -15,7 +16,7 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - with dask.config.set(scheduler="synchronous") + with dask.config.set(scheduler="synchronous"): group_kwargs = { 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} @@ -23,22 +24,21 @@ def test_dask_chunk_group_kwds(self): centered_data = az.load_arviz_data("regression10d", **group_kwargs) exp = [('chain', (4,)), ('draw', (500,)), - ('true_w_dim_0', (10)), - ('w_dim_0', (10))] + ('true_w_dim_0', (2, 2, 2, 2, 2,)), + ('w_dim_0', (2, 2, 2, 2, 2,))] print(list(centered_data.posterior.chunks.items())) assert list(centered_data.posterior.chunks.items()) == exp def test_dask_chunk_group_regex(self): - with dask.config.set(scheduler="synchronous") - Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) - client = Client(threads_per_worker=1, n_workers=1, memory_limit="2GB") + with dask.config.set(scheduler="synchronous"): + Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) group_kwargs = { "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} } centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) exp = [('chain', (4,)), ('draw', (500,)), - ('true_w_dim_0', (10)), - ('w_dim_0', (10))] + ('true_w_dim_0', (10,)), + ('w_dim_0', (10,))] assert list(centered_data.posterior.chunks.items()) == exp From 7d2917938e8701d40c94d3c7b74640f4e75bd443 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 29 Aug 2021 18:26:39 -0400 Subject: [PATCH 22/37] removing running_on_cli: --- arviz/tests/base_tests/test_data_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 876737310e..3bfc9356ed 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -7,7 +7,7 @@ pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name - (importlib.util.find_spec("dask") is None) and not running_on_ci(), + importlib.util.find_spec("dask") is None, reason="test requires dask which is not installed", ) From a89294ce68bc6bba6d24717bf0a44e10b9d22dd2 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Sun, 29 Aug 2021 18:27:16 -0400 Subject: [PATCH 23/37] bleh fixing ordering of import --- arviz/tests/base_tests/test_data_dask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 3bfc9356ed..da9ab25908 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -1,9 +1,9 @@ # pylint: disable=redefined-outer-name, no-member +import pytest import importlib import dask import arviz as az from arviz.utils import Dask -import pytest pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name From d0e3426e8d8c3253a88eb911be470ef90f46f0c7 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 12:59:36 -0400 Subject: [PATCH 24/37] removing parallel flag --- arviz/tests/base_tests/test_data_dask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index da9ab25908..50462f0c58 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -1,6 +1,6 @@ # pylint: disable=redefined-outer-name, no-member -import pytest import importlib +import pytest import dask import arviz as az from arviz.utils import Dask @@ -15,7 +15,7 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): - Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) with dask.config.set(scheduler="synchronous"): group_kwargs = { 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, @@ -32,7 +32,7 @@ def test_dask_chunk_group_kwds(self): def test_dask_chunk_group_regex(self): with dask.config.set(scheduler="synchronous"): - Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]}) + Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) group_kwargs = { "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} } From ec26d0e6f10d4a583e82f39b61fea077ac646e56 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 13:06:39 -0400 Subject: [PATCH 25/37] another debug run --- arviz/tests/base_tests/test_data_dask.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 50462f0c58..dfa9904fe4 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -4,6 +4,7 @@ import dask import arviz as az from arviz.utils import Dask +import xarray pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name @@ -14,7 +15,8 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): - + print(dask.__version__) + print(xarray.__version__) Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) with dask.config.set(scheduler="synchronous"): group_kwargs = { From 39f5d7b1488073e2c8f4484c86409002da591802 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 13:34:04 -0400 Subject: [PATCH 26/37] guarantee ordering --- arviz/tests/base_tests/test_data_dask.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index dfa9904fe4..989c388131 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -28,8 +28,10 @@ def test_dask_chunk_group_kwds(self): ('draw', (500,)), ('true_w_dim_0', (2, 2, 2, 2, 2,)), ('w_dim_0', (2, 2, 2, 2, 2,))] - print(list(centered_data.posterior.chunks.items())) - assert list(centered_data.posterior.chunks.items()) == exp + res = list(centered_data.posterior.chunks.items()) + 
res.sort() + exp.sort() + assert res == exp def test_dask_chunk_group_regex(self): @@ -43,4 +45,7 @@ def test_dask_chunk_group_regex(self): ('draw', (500,)), ('true_w_dim_0', (10,)), ('w_dim_0', (10,))] - assert list(centered_data.posterior.chunks.items()) == exp + res = list(centered_data.posterior.chunks.items()) + res.sort() + exp.sort() + assert res == exp From 715d55c91e1fa37a19303010ad1bf5d4fe6faf13 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 14:02:33 -0400 Subject: [PATCH 27/37] adding temporary debugging checks --- arviz/data/inference_data.py | 10 +++++----- arviz/tests/base_tests/test_data_dask.py | 3 --- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 47ad81b65a..f4c2b84763 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -347,10 +347,6 @@ def from_netcdf( data_groups = list(data.groups) for group in data_groups: - # if group_kwargs is not None and group in group_kwargs: - # group_kws = group_kwargs[group] - # else: - # group_kws = {} group_kws = {} if group_kwargs is not None and regex is False: @@ -359,12 +355,16 @@ def from_netcdf( for key, kws in group_kwargs.items(): if re.search(key, group): group_kws = kws + print('DEBUG : group_kws', group_kws) with xr.open_dataset(filename, group=group, **group_kws) as data: if rcParams["data.load"] == "eager": groups[group] = data.load() else: groups[group] = data - return InferenceData(**groups) + print('DEBUG : group chunks', groups.chunks) + res = InferenceData(**groups) + print('DEBUG : chunks', res.chunks) + return res except OSError as e: # pylint: disable=invalid-name if e.errno == -101: raise type(e)( diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index 989c388131..b05976d540 100644 --- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -4,7 +4,6 @@ import dask import arviz as az from arviz.utils import Dask -import xarray pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name @@ -15,8 +14,6 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): - print(dask.__version__) - print(xarray.__version__) Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) with dask.config.set(scheduler="synchronous"): group_kwargs = { From 6627676555dc50cf461bc92848518bcd0535702c Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 14:15:17 -0400 Subject: [PATCH 28/37] bump --- arviz/data/inference_data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index f4c2b84763..6766718a62 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -361,9 +361,8 @@ def from_netcdf( groups[group] = data.load() else: groups[group] = data - print('DEBUG : group chunks', groups.chunks) + print('DEBUG : group chunks', data.chunks) res = InferenceData(**groups) - print('DEBUG : chunks', res.chunks) return res except OSError as e: # pylint: disable=invalid-name if e.errno == -101: From 9581d4f13732e8ce13b7d507760f1302bcd90269 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Tue, 31 Aug 2021 14:53:14 -0400 Subject: [PATCH 29/37] maybe we don't need the synchronize black --- arviz/tests/base_tests/test_data_dask.py | 54 ++++++++++++------------ 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py index b05976d540..b8e06b9681 100644 
--- a/arviz/tests/base_tests/test_data_dask.py +++ b/arviz/tests/base_tests/test_data_dask.py @@ -15,34 +15,32 @@ class TestDataDask: def test_dask_chunk_group_kwds(self): Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) - with dask.config.set(scheduler="synchronous"): - group_kwargs = { - 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, - 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} - } - centered_data = az.load_arviz_data("regression10d", **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2,)), - ('w_dim_0', (2, 2, 2, 2, 2,))] - res = list(centered_data.posterior.chunks.items()) - res.sort() - exp.sort() - assert res == exp + group_kwargs = { + 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, + 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} + } + centered_data = az.load_arviz_data("regression10d", **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (2, 2, 2, 2, 2,)), + ('w_dim_0', (2, 2, 2, 2, 2,))] + res = list(centered_data.posterior.chunks.items()) + res.sort() + exp.sort() + assert res == exp def test_dask_chunk_group_regex(self): - with dask.config.set(scheduler="synchronous"): - Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) - group_kwargs = { - "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} - } - centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (10,)), - ('w_dim_0', (10,))] - res = list(centered_data.posterior.chunks.items()) - res.sort() - exp.sort() - assert res == exp + Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) + group_kwargs = { + "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} + } + centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) + exp = [('chain', (4,)), + ('draw', (500,)), + ('true_w_dim_0', (10,)), + ('w_dim_0', (10,))] + res = list(centered_data.posterior.chunks.items()) + res.sort() + exp.sort() + assert res == exp From 7d03b15855c8111c3fac23da7e7bc745c63f8661 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 6 Sep 2021 11:39:25 -0400 Subject: [PATCH 30/37] removing tests for now --- arviz/tests/base_tests/test_data_dask.py | 46 ------------------------ 1 file changed, 46 deletions(-) delete mode 100644 arviz/tests/base_tests/test_data_dask.py diff --git a/arviz/tests/base_tests/test_data_dask.py b/arviz/tests/base_tests/test_data_dask.py deleted file mode 100644 index b8e06b9681..0000000000 --- a/arviz/tests/base_tests/test_data_dask.py +++ /dev/null @@ -1,46 +0,0 @@ -# pylint: disable=redefined-outer-name, no-member -import importlib -import pytest -import dask -import arviz as az -from arviz.utils import Dask - - -pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name - importlib.util.find_spec("dask") is None, - reason="test requires dask which is not installed", -) - -class TestDataDask: - - def test_dask_chunk_group_kwds(self): - Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) - group_kwargs = { - 'posterior': {'chunks': {'w_dim_0': 2, 'true_w_dim_0' : 2}}, - 'posterior_predictive': {'chunks': {'w_dim_0': 2, 'true_w_dim_0': 2}} - } - centered_data = az.load_arviz_data("regression10d", **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (2, 2, 2, 2, 2,)), - ('w_dim_0', (2, 2, 2, 2, 2,))] - res = list(centered_data.posterior.chunks.items()) - res.sort() - exp.sort() - assert res == exp - 
- - def test_dask_chunk_group_regex(self): - Dask.enable_dask(dask_kwargs={"output_dtypes": [float]}) - group_kwargs = { - "posterior.*": {'chunks': {'w_dim_0': 10, 'true_w_dim_0' : 10}} - } - centered_data = az.load_arviz_data("regression10d", regex=True, **group_kwargs) - exp = [('chain', (4,)), - ('draw', (500,)), - ('true_w_dim_0', (10,)), - ('w_dim_0', (10,))] - res = list(centered_data.posterior.chunks.items()) - res.sort() - exp.sort() - assert res == exp From 454f51a5a53fbdd5e11b0e79d1d4b8a9c95b5338 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 11 Oct 2021 22:47:10 -0400 Subject: [PATCH 31/37] updating argument --- arviz/data/io_netcdf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py index 69d2ce137a..eaabdc1b06 100644 --- a/arviz/data/io_netcdf.py +++ b/arviz/data/io_netcdf.py @@ -26,6 +26,8 @@ def from_netcdf(filename, group_kwargs: dict = None, regex=False): of loaded into memory. This behaviour is regulated by the value of ``az.rcParams["data.load"]``. """ + if group_kwargs is None: + group_kwargs = {} return InferenceData.from_netcdf(filename, group_kwargs, regex) From 6ec71ccaca0206ab39e8321e07f0748a877d1ae5 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 11 Oct 2021 22:56:21 -0400 Subject: [PATCH 32/37] removing debug statements again --- arviz/data/inference_data.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 6766718a62..0f9dc3e989 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -355,13 +355,11 @@ def from_netcdf( for key, kws in group_kwargs.items(): if re.search(key, group): group_kws = kws - print('DEBUG : group_kws', group_kws) with xr.open_dataset(filename, group=group, **group_kws) as data: if rcParams["data.load"] == "eager": groups[group] = data.load() else: groups[group] = data - print('DEBUG : group chunks', data.chunks) res = InferenceData(**groups) return res except OSError as e: # pylint: disable=invalid-name From 5de4d5df374de7fa83ccbcb21b93907c3b447fb8 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 11 Oct 2021 23:04:48 -0400 Subject: [PATCH 33/37] simplifying types --- arviz/data/inference_data.py | 2 +- arviz/data/io_netcdf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 0f9dc3e989..ef84c05781 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -316,7 +316,7 @@ def items(self) -> "InferenceData.InferenceDataItemsView": @staticmethod def from_netcdf( - filename: str, group_kwargs: dict = None, regex: bool = False + filename, group_kwargs=None, regex=False ) -> "InferenceData": """Initialize object from a netcdf file. diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py index eaabdc1b06..2928cc61f0 100644 --- a/arviz/data/io_netcdf.py +++ b/arviz/data/io_netcdf.py @@ -4,7 +4,7 @@ from .inference_data import InferenceData -def from_netcdf(filename, group_kwargs: dict = None, regex=False): +def from_netcdf(filename, group_kwargs=None, regex=False): """Load netcdf file back into an arviz.InferenceData. Parameters From 4a52f4e85930612fa31fda4dc48b2a6d18f302fb Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 11 Oct 2021 23:25:41 -0400 Subject: [PATCH 34/37] black? 
--- arviz/data/inference_data.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index ef84c05781..84e688f704 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -315,9 +315,7 @@ def items(self) -> "InferenceData.InferenceDataItemsView": return InferenceData.InferenceDataItemsView(self) @staticmethod - def from_netcdf( - filename, group_kwargs=None, regex=False - ) -> "InferenceData": + def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData": """Initialize object from a netcdf file. Expects that the file will have groups, each of which can be loaded by xarray. From f4114e26617d600c4dd5a317c984dea271294aeb Mon Sep 17 00:00:00 2001 From: mortonjt Date: Mon, 11 Oct 2021 23:54:55 -0400 Subject: [PATCH 35/37] bump --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a8e744d97b..8a7920a92f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ * Add option to specify colors for each element in ppc_plot ([1769](https://github.com/arviz-devs/arviz/pull/1769)) * Enable dask chunking information to be passed to `InferenceData.from_netcdf` ([1749](https://github.com/arviz-devs/arviz/pull/1749)) * Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support ([1749](https://github.com/arviz-devs/arviz/pull/1749)) +* Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support to enable parallel io ([1749](https://github.com/arviz-devs/arviz/pull/1749)) ### Maintenance and fixes * Fix conversion for numpyro models with ImproperUniform latent sites ([1713](https://github.com/arviz-devs/arviz/pull/1713)) From 79805d31eb23a59a357e207d722446e1c15990b9 Mon Sep 17 00:00:00 2001 From: mortonjt Date: Thu, 14 Oct 2021 22:15:14 -0400 Subject: [PATCH 36/37] rebase against main? --- .pylintrc | 1022 ++++++++++++++++---------------- arviz/stats/stats_utils.py | 1148 ++++++++++++++++++------------------ 2 files changed, 1085 insertions(+), 1085 deletions(-) diff --git a/.pylintrc b/.pylintrc index 2c5d44676b..a8d9237443 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,511 +1,511 @@ -[MASTER] - -# A comma-separated list of package or module names from where C extensions may -# be loaded. Extensions are loading into the active Python interpreter and may -# run arbitrary code -extension-pkg-whitelist= - -# Add files or directories to the blacklist. They should be base names, not -# paths. -ignore=CVS - -# Add files or directories matching the regex patterns to the blacklist. The -# regex matches against base names, not paths. -ignore-patterns= - -# Python code to execute, usually for sys.path manipulation such as -# pygtk.require(). -#init-hook= - -# Use multiple processes to speed up Pylint. -jobs=1 - -# List of plugins (as comma separated values of python modules names) to load, -# usually to register additional checkers. -load-plugins= - -# Pickle collected data for later comparisons. -persistent=yes - -# Specify a configuration file. -#rcfile= - -# When enabled, pylint would attempt to guess common misconfiguration and emit -# user-friendly hints instead of false-positive error messages -suggestion-mode=yes - -# Allow loading of arbitrary C extensions. Extensions are imported into the -# active Python interpreter and may run arbitrary code. -unsafe-load-any-extension=no - - -[MESSAGES CONTROL] - -# Only show warnings with the listed confidence levels. 
Leave empty to show -# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED -confidence= - -# Disable the message, report, category or checker with the given id(s). You -# can either give multiple identifiers separated by comma (,) or put this -# option multiple times (only on the command line, not in the configuration -# file where it should appear only once).You can also use "--disable=all" to -# disable everything first and then reenable specific checks. For example, if -# you want to run only the similarities checker, you can use "--disable=all -# --enable=similarities". If you want to run only the classes checker, but have -# no Warning level messages displayed, use"--disable=all --enable=classes -# --disable=W" -disable=missing-docstring, - no-else-return, - len-as-condition, - too-many-arguments, - too-many-locals, - too-many-branches, - too-many-statements, - no-self-use, - too-few-public-methods, - bad-continuation, - import-outside-toplevel, - no-else-continue, - unnecessary-comprehension, - unsubscriptable-object, - cyclic-import, - ungrouped-imports, - not-an-iterable, - no-member, - #TODO: Remove this once todos are done - fixme, - consider-using-with, - consider-using-f-string - - -# Enable the message, report, category or checker with the given id(s). You can -# either give multiple identifier separated by comma (,) or put this option -# multiple time (only on the command line, not in the configuration file where -# it should appear only once). See also the "--disable" option for examples. -enable=c-extension-no-member - - -[REPORTS] - -# Python expression which should return a note less than 10 (10 is the highest -# note). You have access to the variables errors warning, statement which -# respectively contain the number of errors / warnings messages and the total -# number of statements analyzed. This is used by the global evaluation report -# (RP0004). -evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) - -# Template used to display messages. This is a python new-style format string -# used to format the message information. See doc for all details -#msg-template= - -# Set the output format. Available formats are text, parseable, colorized, json -# and msvs (visual studio).You can also give a reporter class, eg -# mypackage.mymodule.MyReporterClass. -output-format=text - -# Tells whether to display a full report or only the messages -reports=no - -# Activate the evaluation score. -score=yes - - -[REFACTORING] - -# Maximum number of nested blocks for function / method body -max-nested-blocks=5 - -# Complete name of functions that never returns. When checking for -# inconsistent-return-statements if a never returning function is called then -# it will be considered as an explicit return statement and no message will be -# printed. -never-returning-functions=optparse.Values,sys.exit - - -[FORMAT] - -# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. -expected-line-ending-format= - -# Regexp for a line that is allowed to be longer than the limit. -ignore-long-lines=^\s*(# )??$ - -# Number of spaces of indent required inside a hanging or continued line. -indent-after-paren=4 - -# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 -# tab). -indent-string=' ' - -# Maximum number of characters on a single line. -max-line-length=100 - -# Maximum number of lines in a module -max-module-lines=1000 - -# List of optional constructs for which whitespace checking is disabled. 
`dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - -# Allow the body of a class to be on the same line as the declaration if body -# contains single statement. -single-line-class-stmt=no - -# Allow the body of an if to be on the same line as the test if there is no -# else. -single-line-if-stmt=no - - -[VARIABLES] - -# List of additional names supposed to be defined in builtins. Remember that -# you should avoid to define new builtins when possible. -additional-builtins= - -# Tells whether unused global variables should be treated as a violation. -allow-global-unused-variables=yes - -# List of strings which can identify a callback function by name. A callback -# name must start or end with one of those strings. -callbacks=cb_, - _cb - -# A regular expression matching the name of dummy variables (i.e. expectedly -# not used). -dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ - -# Argument names that match this expression will be ignored. Default to name -# with leading underscore -ignored-argument-names=_.*|^ignored_|^unused_ - -# Tells whether we should check for unused import in __init__ files. -init-import=no - -# List of qualified module names which can have objects that can redefine -# builtins. -redefining-builtins-modules=six.moves,past.builtins,future.builtins - - -[BASIC] - -# Naming style matching correct argument names -argument-naming-style=snake_case - -# Regular expression matching correct argument names. Overrides argument- -# naming-style -#argument-rgx= - -# Naming style matching correct attribute names -attr-naming-style=snake_case - -# Regular expression matching correct attribute names. Overrides attr-naming- -# style -#attr-rgx= - -# Bad variable names which should always be refused, separated by a comma -bad-names=foo, - bar, - baz, - toto, - tutu, - tata - -# Naming style matching correct class attribute names -class-attribute-naming-style=any - -# Regular expression matching correct class attribute names. Overrides class- -# attribute-naming-style -#class-attribute-rgx= - -# Naming style matching correct class names -class-naming-style=PascalCase - -# Regular expression matching correct class names. Overrides class-naming-style -#class-rgx= - -# Naming style matching correct constant names -const-naming-style=UPPER_CASE - -# Regular expression matching correct constant names. Overrides const-naming- -# style -#const-rgx= - -# Minimum line length for functions/classes that require docstrings, shorter -# ones are exempt. -docstring-min-length=-1 - -# Naming style matching correct function names -function-naming-style=snake_case - -# Regular expression matching correct function names. Overrides function- -# naming-style -#function-rgx= - -# Good variable names which should always be accepted, separated by a comma -good-names=b, - e, - i, - j, - k, - n, - m, - t, - q, - x, - y, - z, - ax, - bw, - df, - dx, - ex, - gs, - ic, - mu, - ok, - sd, - tr, - eta, - Run, - _log, - _ - - -# Include a hint for the correct naming format with invalid-name -include-naming-hint=no - -# Naming style matching correct inline iteration names -inlinevar-naming-style=any - -# Regular expression matching correct inline iteration names. 
Overrides -# inlinevar-naming-style -#inlinevar-rgx= - -# Naming style matching correct method names -method-naming-style=snake_case - -# Regular expression matching correct method names. Overrides method-naming- -# style -#method-rgx= - -# Naming style matching correct module names -module-naming-style=snake_case - -# Regular expression matching correct module names. Overrides module-naming- -# style -#module-rgx= - -# Colon-delimited sets of names that determine each other's naming style when -# the name regexes allow several styles. -name-group= - -# Regular expression which should only match function or class names that do -# not require a docstring. -no-docstring-rgx=^_ - -# List of decorators that produce properties, such as abc.abstractproperty. Add -# to this list to register other decorators that produce valid properties. -property-classes=abc.abstractproperty - -# Naming style matching correct variable names -variable-naming-style=snake_case - -# Regular expression matching correct variable names. Overrides variable- -# naming-style -#variable-rgx= - - -[LOGGING] - -# Logging modules to check that the string format arguments are in logging -# function parameter format -logging-modules=logging - - -[SIMILARITIES] - -# Ignore comments when computing similarities. -ignore-comments=yes - -# Ignore docstrings when computing similarities. -ignore-docstrings=yes - -# Ignore imports when computing similarities. -ignore-imports=no - -# Minimum lines number of a similarity. -min-similarity-lines=50 - - -[TYPECHECK] - -# List of decorators that produce context managers, such as -# contextlib.contextmanager. Add to this list to register other decorators that -# produce valid context managers. -contextmanager-decorators=contextlib.contextmanager - -# List of members which are set dynamically and missed by pylint inference -# system, and so shouldn't trigger E1101 when accessed. Python regular -# expressions are accepted. -generated-members= - -# Tells whether missing members accessed in mixin class should be ignored. A -# mixin class is detected if its name ends with "mixin" (case insensitive). -ignore-mixin-members=yes - -# This flag controls whether pylint should warn about no-member and similar -# checks whenever an opaque object is returned when inferring. The inference -# can return multiple potential results while evaluating a Python object, but -# some branches might not be evaluated, which results in partial inference. In -# that case, it might be useful to still emit no-member and other checks for -# the rest of the inferred objects. -ignore-on-opaque-inference=yes - -# List of class names for which member attributes should not be checked (useful -# for classes with dynamically set attributes). This supports the use of -# qualified names. -ignored-classes=optparse.Values,thread._local,_thread._local,netCDF4 - -# List of module names for which member attributes should not be checked -# (useful for modules/projects where namespaces are manipulated during runtime -# and thus existing member attributes cannot be deduced by static analysis. It -# supports qualified module names, as well as Unix pattern matching. -ignored-modules= - -# Show a hint with possible names when a member name was not found. The aspect -# of finding the hint is based on edit distance. -missing-member-hint=yes - -# The minimum edit distance a name should have in order to be considered a -# similar match for a missing member name. 
-missing-member-hint-distance=1 - -# The total number of similar names that should be taken in consideration when -# showing a hint for a missing member. -missing-member-max-choices=1 - - -[MISCELLANEOUS] - -# List of note tags to take in consideration, separated by a comma. -notes=FIXME, - XXX, - TODO - - -[SPELLING] - -# Limits count of emitted suggestions for spelling mistakes -max-spelling-suggestions=4 - -# Spelling dictionary name. Available dictionaries: none. To make it working -# install python-enchant package. -spelling-dict= - -# List of comma separated words that should not be checked. -spelling-ignore-words= - -# A path to a file that contains private dictionary; one word per line. -spelling-private-dict-file= - -# Tells whether to store unknown words to indicated private dictionary in -# --spelling-private-dict-file option instead of raising a message. -spelling-store-unknown-words=no - - -[CLASSES] - -# List of method names used to declare (i.e. assign) instance attributes. -defining-attr-methods=__init__, - __new__, - setUp - -# List of member names, which should be excluded from the protected access -# warning. -exclude-protected=_asdict, - _fields, - _replace, - _source, - _make - -# List of valid names for the first argument in a class method. -valid-classmethod-first-arg=cls - -# List of valid names for the first argument in a metaclass class method. -valid-metaclass-classmethod-first-arg=mcs - - -[IMPORTS] - -# Allow wildcard imports from modules that define __all__. -allow-wildcard-with-all=no - -# Analyse import fallback blocks. This can be used to support both Python 2 and -# 3 compatible code, which means that the block might have code that exists -# only in one or another interpreter, leading to false positives when analysed. -analyse-fallback-blocks=no - -# Deprecated modules which should not be used, separated by a comma -deprecated-modules=optparse,tkinter.tix - -# Create a graph of external dependencies in the given file (report RP0402 must -# not be disabled) -ext-import-graph= - -# Create a graph of every (i.e. internal and external) dependencies in the -# given file (report RP0402 must not be disabled) -import-graph= - -# Create a graph of internal dependencies in the given file (report RP0402 must -# not be disabled) -int-import-graph= - -# Force import order to recognize a module as part of the standard -# compatibility libraries. -known-standard-library= - -# Force import order to recognize a module as part of a third party library. -known-third-party=enchant - - -[DESIGN] - -# Maximum number of arguments for function / method -max-args=10 - -# Maximum number of attributes for a class (see R0902). -max-attributes=10 - -# Maximum number of boolean expressions in a if statement -max-bool-expr=5 - -# Maximum number of branch for function / method body -max-branches=12 - -# Maximum number of locals for function / method body -max-locals=15 - -# Maximum number of parents for a class (see R0901). -max-parents=7 - -# Maximum number of public methods for a class (see R0904). -max-public-methods=20 - -# Maximum number of return / yield for function / method body -max-returns=6 - -# Maximum number of statements in function / method body -max-statements=50 - -# Minimum number of public methods for a class (see R0903). -min-public-methods=2 - - -[EXCEPTIONS] - -# Exceptions that will emit a warning when being caught. 
Defaults to -# "Exception" -overgeneral-exceptions=Exception +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist= + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins= + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages +suggestion-mode=yes + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Disable the message, report, category or checker with the given id(s). You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once).You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use"--disable=all --enable=classes +# --disable=W" +disable=missing-docstring, + no-else-return, + len-as-condition, + too-many-arguments, + too-many-locals, + too-many-branches, + too-many-statements, + no-self-use, + too-few-public-methods, + bad-continuation, + import-outside-toplevel, + no-else-continue, + unnecessary-comprehension, + unsubscriptable-object, + cyclic-import, + ungrouped-imports, + not-an-iterable, + no-member, + #TODO: Remove this once todos are done + fixme, + consider-using-with, + consider-using-f-string + + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=c-extension-no-member + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + +# Set the output format. 
Available formats are text, parseable, colorized, json +# and msvs (visual studio).You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Activate the evaluation score. +score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + +# Complete name of functions that never returns. When checking for +# inconsistent-return-statements if a never returning function is called then +# it will be considered as an explicit return statement and no message will be +# printed. +never-returning-functions=optparse.Values,sys.exit + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines=^\s*(# )??$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=100 + +# Maximum number of lines in a module +max-module-lines=1000 + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma, + dict-separator + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_, + _cb + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,past.builtins,future.builtins + + +[BASIC] + +# Naming style matching correct argument names +argument-naming-style=snake_case + +# Regular expression matching correct argument names. Overrides argument- +# naming-style +#argument-rgx= + +# Naming style matching correct attribute names +attr-naming-style=snake_case + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style +#attr-rgx= + +# Bad variable names which should always be refused, separated by a comma +bad-names=foo, + bar, + baz, + toto, + tutu, + tata + +# Naming style matching correct class attribute names +class-attribute-naming-style=any + +# Regular expression matching correct class attribute names. 
Overrides class- +# attribute-naming-style +#class-attribute-rgx= + +# Naming style matching correct class names +class-naming-style=PascalCase + +# Regular expression matching correct class names. Overrides class-naming-style +#class-rgx= + +# Naming style matching correct constant names +const-naming-style=UPPER_CASE + +# Regular expression matching correct constant names. Overrides const-naming- +# style +#const-rgx= + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. +docstring-min-length=-1 + +# Naming style matching correct function names +function-naming-style=snake_case + +# Regular expression matching correct function names. Overrides function- +# naming-style +#function-rgx= + +# Good variable names which should always be accepted, separated by a comma +good-names=b, + e, + i, + j, + k, + n, + m, + t, + q, + x, + y, + z, + ax, + bw, + df, + dx, + ex, + gs, + ic, + mu, + ok, + sd, + tr, + eta, + Run, + _log, + _ + + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Naming style matching correct inline iteration names +inlinevar-naming-style=any + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style +#inlinevar-rgx= + +# Naming style matching correct method names +method-naming-style=snake_case + +# Regular expression matching correct method names. Overrides method-naming- +# style +#method-rgx= + +# Naming style matching correct module names +module-naming-style=snake_case + +# Regular expression matching correct module names. Overrides module-naming- +# style +#module-rgx= + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty + +# Naming style matching correct variable names +variable-naming-style=snake_case + +# Regular expression matching correct variable names. Overrides variable- +# naming-style +#variable-rgx= + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=50 + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. 
The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local,netCDF4 + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. +missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME, + XXX, + TODO + + +[SPELLING] + +# Limits count of emitted suggestions for spelling mistakes +max-spelling-suggestions=4 + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. +spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__, + __new__, + setUp + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict, + _fields, + _replace, + _source, + _make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of every (i.e. 
internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=10 + +# Maximum number of attributes for a class (see R0902). +max-attributes=10 + +# Maximum number of boolean expressions in a if statement +max-bool-expr=5 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of locals for function / method body +max-locals=15 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of statements in function / method body +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=2 + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/arviz/stats/stats_utils.py b/arviz/stats/stats_utils.py index 87c4626938..a2a92d48ec 100644 --- a/arviz/stats/stats_utils.py +++ b/arviz/stats/stats_utils.py @@ -1,574 +1,574 @@ -"""Stats-utility functions for ArviZ.""" -import warnings -from collections.abc import Sequence -from copy import copy as _copy -from copy import deepcopy as _deepcopy - -import numpy as np -import pandas as pd -from scipy.fftpack import next_fast_len -from scipy.interpolate import CubicSpline -from scipy.stats.mstats import mquantiles -from xarray import apply_ufunc - -from .. import _log -from ..utils import conditional_jit, conditional_vect, conditional_dask -from .density_utils import histogram as _histogram - - -__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"] - - -def autocov(ary, axis=-1): - """Compute autocovariance estimates for every lag for the input array. - - Parameters - ---------- - ary : Numpy array - An array containing MCMC samples - - Returns - ------- - acov: Numpy array same size as the input array - """ - axis = axis if axis > 0 else len(ary.shape) + axis - n = ary.shape[axis] - m = next_fast_len(2 * n) - - ary = ary - ary.mean(axis, keepdims=True) - - # added to silence tuple warning for a submodule - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - - ifft_ary = np.fft.rfft(ary, n=m, axis=axis) - ifft_ary *= np.conjugate(ifft_ary) - - shape = tuple( - slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape) - ) - cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape] - cov /= n - - return cov - - -def autocorr(ary, axis=-1): - """Compute autocorrelation using FFT for every lag for the input array. 
- - See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation - - Parameters - ---------- - ary : Numpy array - An array containing MCMC samples - - Returns - ------- - acorr: Numpy array same size as the input array - """ - corr = autocov(ary, axis=axis) - axis = axis = axis if axis > 0 else len(corr.shape) + axis - norm = tuple( - slice(None, None) if dim != axis else slice(None, 1) for dim, _ in enumerate(corr.shape) - ) - with np.errstate(invalid="ignore"): - corr /= corr[norm] - return corr - - -def make_ufunc( - func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None -): # noqa: D202 - """Make ufunc from a function taking 1D array input. - - Parameters - ---------- - func : callable - n_dims : int, optional - Number of core dimensions not broadcasted. Dimensions are skipped from the end. - At minimum n_dims > 0. - n_output : int, optional - Select number of results returned by `func`. - If n_output > 1, ufunc returns a tuple of objects else returns an object. - n_input : int, optional - Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called - with ``func(ary1, ary2, *args, **kwargs)`` - index : int, optional - Slice ndarray with `index`. Defaults to `Ellipsis`. - ravel : bool, optional - If true, ravel the ndarray before calling `func`. - check_shape: bool, optional - If false, do not check if the shape of the output is compatible with n_dims and - n_output. By default, True only for n_input=1. If n_input is larger than 1, the last - input array is used to check the shape, however, shape checking with multiple inputs - may not be correct. - - Returns - ------- - callable - ufunc wrapper for `func`. - """ - if n_dims < 1: - raise TypeError("n_dims must be one or higher.") - - if n_input == 1 and check_shape is None: - check_shape = True - elif check_shape is None: - check_shape = False - - def _ufunc(*args, out=None, out_shape=None, **kwargs): - """General ufunc for single-output function.""" - arys = args[:n_input] - n_dims_out = None - if out is None: - if out_shape is None: - out = np.empty(arys[-1].shape[:-n_dims]) - else: - out = np.empty((*arys[-1].shape[:-n_dims], *out_shape)) - n_dims_out = -len(out_shape) - elif check_shape: - if out.shape != arys[-1].shape[:-n_dims]: - msg = f"Shape incorrect for `out`: {out.shape}." - msg += f" Correct shape is {arys[-1].shape[:-n_dims]}" - raise TypeError(msg) - for idx in np.ndindex(out.shape[:n_dims_out]): - arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] - out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index] - return out - - def _multi_ufunc(*args, out=None, out_shape=None, **kwargs): - """General ufunc for multi-output function.""" - arys = args[:n_input] - element_shape = arys[-1].shape[:-n_dims] - if out is None: - if out_shape is None: - out = tuple(np.empty(element_shape) for _ in range(n_output)) - else: - out = tuple(np.empty((*element_shape, *out_shape[i])) for i in range(n_output)) - - elif check_shape: - raise_error = False - correct_shape = tuple(element_shape for _ in range(n_output)) - if isinstance(out, tuple): - out_shape = tuple(item.shape for item in out) - if out_shape != correct_shape: - raise_error = True - else: - raise_error = True - out_shape = "not tuple, type={type(out)}" - if raise_error: - msg = f"Shapes incorrect for `out`: {out_shape}." 
- msg += f" Correct shapes are {correct_shape}" - raise TypeError(msg) - for idx in np.ndindex(element_shape): - arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] - results = func(*arys_idx, *args[n_input:], **kwargs) - for i, res in enumerate(results): - out[i][idx] = np.asarray(res)[index] - return out - - if n_output > 1: - ufunc = _multi_ufunc - else: - ufunc = _ufunc - - update_docstring(ufunc, func, n_output) - return ufunc - - -@conditional_dask -def wrap_xarray_ufunc( - ufunc, - *datasets, - ufunc_kwargs=None, - func_args=None, - func_kwargs=None, - dask_kwargs=None, - **kwargs, -): - """Wrap make_ufunc with xarray.apply_ufunc. - - Parameters - ---------- - ufunc : callable - datasets : xarray.dataset - ufunc_kwargs : dict - Keyword arguments passed to `make_ufunc`. - - 'n_dims', int, by default 2 - - 'n_output', int, by default 1 - - 'n_input', int, by default len(datasets) - - 'index', slice, by default Ellipsis - - 'ravel', bool, by default True - func_args : tuple - Arguments passed to 'ufunc'. - func_kwargs : dict - Keyword arguments passed to 'ufunc'. - - 'out_shape', int, by default None - dask_kwargs : dict - Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`. - Use :meth:`~arviz.Dask.enable_dask` to set default kwargs. - **kwargs - Passed to xarray.apply_ufunc. - - Returns - ------- - xarray.dataset - """ - if ufunc_kwargs is None: - ufunc_kwargs = {} - ufunc_kwargs.setdefault("n_input", len(datasets)) - if func_args is None: - func_args = tuple() - if func_kwargs is None: - func_kwargs = {} - if dask_kwargs is None: - dask_kwargs = {} - - kwargs.setdefault( - "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets))) - ) - ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1])) - kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1)))) - - callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs) - - return apply_ufunc( - callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs - ) - - -def update_docstring(ufunc, func, n_output=1): - """Update ArviZ generated ufunc docstring.""" - module = "" - name = "" - docstring = "" - if hasattr(func, "__module__") and isinstance(func.__module__, str): - module += func.__module__ - if hasattr(func, "__name__"): - name += func.__name__ - if hasattr(func, "__doc__") and isinstance(func.__doc__, str): - docstring += func.__doc__ - ufunc.__doc__ += "\n\n" - if module or name: - ufunc.__doc__ += "This function is a ufunc wrapper for " - ufunc.__doc__ += module + "." 
+ name - ufunc.__doc__ += "\n" - ufunc.__doc__ += 'Call ufunc with n_args from xarray against "chain" and "draw" dimensions:' - ufunc.__doc__ += "\n\n" - input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))' - if n_output > 1: - output_core_dims = f" tuple([] for _ in range({n_output}))" - msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}, " - msg += f"output_core_dims={ output_core_dims})" - ufunc.__doc__ += msg - else: - output_core_dims = "" - msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims})" - ufunc.__doc__ += msg - ufunc.__doc__ += "\n\n" - ufunc.__doc__ += "For example: np.std(data, ddof=1) --> n_args=2" - if docstring: - ufunc.__doc__ += "\n\n" - ufunc.__doc__ += module - ufunc.__doc__ += name - ufunc.__doc__ += " docstring:" - ufunc.__doc__ += "\n\n" - ufunc.__doc__ += docstring - - -def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True): - """Stable logsumexp when b >= 0 and b is scalar. - - b_inv overwrites b unless b_inv is None. - """ - # check dimensions for result arrays - ary = np.asarray(ary) - if ary.dtype.kind == "i": - ary = ary.astype(np.float64) - dtype = ary.dtype.type - shape = ary.shape - shape_len = len(shape) - if isinstance(axis, Sequence): - axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis) - agroup = axis - else: - axis = axis if (axis is None) or (axis >= 0) else shape_len + axis - agroup = (axis,) - shape_max = ( - tuple(1 for _ in shape) - if axis is None - else tuple(1 if i in agroup else d for i, d in enumerate(shape)) - ) - # create result arrays - if out is None: - if not keepdims: - out_shape = ( - tuple() - if axis is None - else tuple(d for i, d in enumerate(shape) if i not in agroup) - ) - else: - out_shape = shape_max - out = np.empty(out_shape, dtype=dtype) - if b_inv == 0: - return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf - if b_inv is None and b == 0: - return np.full_like(out, -np.inf) if out.shape else -np.inf - ary_max = np.empty(shape_max, dtype=dtype) - # calculations - ary.max(axis=axis, keepdims=True, out=ary_max) - if copy: - ary = ary.copy() - ary -= ary_max - np.exp(ary, out=ary) - ary.sum(axis=axis, keepdims=keepdims, out=out) - np.log(out, out=out) - if b_inv is not None: - ary_max -= np.log(b_inv) - elif b: - ary_max += np.log(b) - out += ary_max.squeeze() if not keepdims else ary_max - # transform to scalar if possible - return out if out.shape else dtype(out) - - -def quantile(ary, q, axis=None, limit=None): - """Use same quantile function as R (Type 7).""" - if limit is None: - limit = tuple() - return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=limit) - - -def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None): - """Validate ndarray. - - Parameters - ---------- - ary : numpy.ndarray - check_nan : bool - Check if any value contains NaN. - check_shape : bool - Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape). - For 1D arrays (shape = (n,)) assumes chain equals 1. - nan_kwargs : dict - Valid kwargs are: - axis : int, - Defaults to None. - how : str, {"all", "any"} - Default to "any". - shape_kwargs : dict - Valid kwargs are: - min_chains : int - Defaults to 1. - min_draws : int - Defaults to 4. 
- - Returns - ------- - bool - """ - ary = np.asarray(ary) - - nan_error = False - draw_error = False - chain_error = False - - if check_nan: - if nan_kwargs is None: - nan_kwargs = {} - - isnan = np.isnan(ary) - axis = nan_kwargs.get("axis", None) - if nan_kwargs.get("how", "any").lower() == "all": - nan_error = isnan.all(axis) - else: - nan_error = isnan.any(axis) - - if (isinstance(nan_error, bool) and nan_error) or nan_error.any(): - _log.warning("Array contains NaN-value.") - - if check_shape: - shape = ary.shape - - if shape_kwargs is None: - shape_kwargs = {} - - min_chains = shape_kwargs.get("min_chains", 2) - min_draws = shape_kwargs.get("min_draws", 4) - error_msg = f"Shape validation failed: input_shape: {shape}, " - error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})" - - chain_error = ((min_chains > 1) and (len(shape) < 2)) or (shape[0] < min_chains) - draw_error = ((len(shape) < 2) and (shape[0] < min_draws)) or ( - (len(shape) > 1) and (shape[1] < min_draws) - ) - - if chain_error or draw_error: - _log.warning(error_msg) - - return nan_error | chain_error | draw_error - - -def get_log_likelihood(idata, var_name=None): - """Retrieve the log likelihood dataarray of a given variable.""" - if ( - not hasattr(idata, "log_likelihood") - and hasattr(idata, "sample_stats") - and hasattr(idata.sample_stats, "log_likelihood") - ): - warnings.warn( - "Storing the log_likelihood in sample_stats groups has been deprecated", - DeprecationWarning, - ) - return idata.sample_stats.log_likelihood - if not hasattr(idata, "log_likelihood"): - raise TypeError("log likelihood not found in inference data object") - if var_name is None: - var_names = list(idata.log_likelihood.data_vars) - if len(var_names) > 1: - raise TypeError( - f"Found several log likelihood arrays {var_names}, var_name cannot be None" - ) - return idata.log_likelihood[var_names[0]] - else: - try: - log_likelihood = idata.log_likelihood[var_name] - except KeyError as err: - raise TypeError(f"No log likelihood data named {var_name} found") from err - return log_likelihood - - -BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix - -{{0:{0}}} Estimate SE -{{scale}}_{{kind}} {{1:8.2f}} {{2:7.2f}} -p_{{kind:{1}}} {{3:8.2f}} -""" -POINTWISE_LOO_FMT = """------ - -Pareto k diagnostic values: - {{0:>{0}}} {{1:>6}} -(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% - (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% - (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% - (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% -""" -SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} - - -class ELPDData(pd.Series): # pylint: disable=too-many-ancestors - """Class to contain the data from elpd information criterion like waic or loo.""" - - def __str__(self): - """Print elpd data in a user friendly way.""" - kind = self.index[0] - - if kind not in ("loo", "waic"): - raise ValueError("Invalid ELPDData object") - - scale_str = SCALE_DICT[self[f"{kind}_scale"]] - padding = len(scale_str) + len(kind) + 1 - base = BASE_FMT.format(padding, padding - 2) - base = base.format( - "", - kind=kind, - scale=scale_str, - n_samples=self.n_samples, - n_points=self.n_data_points, - *self.values, - ) - - if self.warning: - base += "\n\nThere has been a warning during the calculation. Please check the results." 
- - if kind == "loo" and "pareto_k" in self: - bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf]) - counts, *_ = _histogram(self.pareto_k.values, bins) - extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) - extended = extended.format( - "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] - ) - base = "\n".join([base, extended]) - return base - - def __repr__(self): - """Alias to ``__str__``.""" - return self.__str__() - - def copy(self, deep=True): - """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.""" - copied_obj = pd.Series.copy(self) - for key in copied_obj.keys(): - if deep: - copied_obj[key] = _deepcopy(copied_obj[key]) - else: - copied_obj[key] = _copy(copied_obj[key]) - return ELPDData(copied_obj) - - -@conditional_jit -def stats_variance_1d(data, ddof=0): - a_a, b_b = 0, 0 - for i in data: - a_a = a_a + i - b_b = b_b + i * i - var = b_b / (len(data)) - ((a_a / (len(data))) ** 2) - var = var * (len(data) / (len(data) - ddof)) - return var - - -def stats_variance_2d(data, ddof=0, axis=1): - if data.ndim == 1: - return stats_variance_1d(data, ddof=ddof) - a_a, b_b = data.shape - if axis == 1: - var = np.zeros(a_a) - for i in range(a_a): - var[i] = stats_variance_1d(data[i], ddof=ddof) - return var - else: - var = np.zeros(b_b) - for i in range(b_b): - var[i] = stats_variance_1d(data[:, i], ddof=ddof) - return var - - -@conditional_vect -def _sqrt(a_a, b_b): - return (a_a + b_b) ** 0.5 - - -def _circfunc(samples, high, low, skipna): - samples = np.asarray(samples) - if skipna: - samples = samples[~np.isnan(samples)] - if samples.size == 0: - return np.nan - return _angle(samples, low, high, np.pi) - - -@conditional_vect -def _angle(samples, low, high, p_i=np.pi): - ang = (samples - low) * 2.0 * p_i / (high - low) - return ang - - -def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None): - ang = _circfunc(samples, high, low, skipna) - s_s = np.sin(ang).mean(axis=axis) - c_c = np.cos(ang).mean(axis=axis) - r_r = np.hypot(s_s, c_c) - return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r)) - - -def smooth_data(obs_vals, pp_vals): - """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.""" - x = np.linspace(0, 1, len(obs_vals)) - csi = CubicSpline(x, obs_vals) - obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals))) - - x = np.linspace(0, 1, pp_vals.shape[1]) - csi = CubicSpline(x, pp_vals, axis=1) - pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1])) - - return obs_vals, pp_vals +"""Stats-utility functions for ArviZ.""" +import warnings +from collections.abc import Sequence +from copy import copy as _copy +from copy import deepcopy as _deepcopy + +import numpy as np +import pandas as pd +from scipy.fftpack import next_fast_len +from scipy.interpolate import CubicSpline +from scipy.stats.mstats import mquantiles +from xarray import apply_ufunc + +from .. import _log +from ..utils import conditional_jit, conditional_vect, conditional_dask +from .density_utils import histogram as _histogram + + +__all__ = ["autocorr", "autocov", "ELPDData", "make_ufunc", "wrap_xarray_ufunc"] + + +def autocov(ary, axis=-1): + """Compute autocovariance estimates for every lag for the input array. 
+ + Parameters + ---------- + ary : Numpy array + An array containing MCMC samples + + Returns + ------- + acov: Numpy array same size as the input array + """ + axis = axis if axis > 0 else len(ary.shape) + axis + n = ary.shape[axis] + m = next_fast_len(2 * n) + + ary = ary - ary.mean(axis, keepdims=True) + + # added to silence tuple warning for a submodule + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + ifft_ary = np.fft.rfft(ary, n=m, axis=axis) + ifft_ary *= np.conjugate(ifft_ary) + + shape = tuple( + slice(None) if dim_len != axis else slice(0, n) for dim_len, _ in enumerate(ary.shape) + ) + cov = np.fft.irfft(ifft_ary, n=m, axis=axis)[shape] + cov /= n + + return cov + + +def autocorr(ary, axis=-1): + """Compute autocorrelation using FFT for every lag for the input array. + + See https://en.wikipedia.org/wiki/autocorrelation#Efficient_computation + + Parameters + ---------- + ary : Numpy array + An array containing MCMC samples + + Returns + ------- + acorr: Numpy array same size as the input array + """ + corr = autocov(ary, axis=axis) + axis = axis = axis if axis > 0 else len(corr.shape) + axis + norm = tuple( + slice(None, None) if dim != axis else slice(None, 1) for dim, _ in enumerate(corr.shape) + ) + with np.errstate(invalid="ignore"): + corr /= corr[norm] + return corr + + +def make_ufunc( + func, n_dims=2, n_output=1, n_input=1, index=Ellipsis, ravel=True, check_shape=None +): # noqa: D202 + """Make ufunc from a function taking 1D array input. + + Parameters + ---------- + func : callable + n_dims : int, optional + Number of core dimensions not broadcasted. Dimensions are skipped from the end. + At minimum n_dims > 0. + n_output : int, optional + Select number of results returned by `func`. + If n_output > 1, ufunc returns a tuple of objects else returns an object. + n_input : int, optional + Number of **array** inputs to func, i.e. ``n_input=2`` means that func is called + with ``func(ary1, ary2, *args, **kwargs)`` + index : int, optional + Slice ndarray with `index`. Defaults to `Ellipsis`. + ravel : bool, optional + If true, ravel the ndarray before calling `func`. + check_shape: bool, optional + If false, do not check if the shape of the output is compatible with n_dims and + n_output. By default, True only for n_input=1. If n_input is larger than 1, the last + input array is used to check the shape, however, shape checking with multiple inputs + may not be correct. + + Returns + ------- + callable + ufunc wrapper for `func`. + """ + if n_dims < 1: + raise TypeError("n_dims must be one or higher.") + + if n_input == 1 and check_shape is None: + check_shape = True + elif check_shape is None: + check_shape = False + + def _ufunc(*args, out=None, out_shape=None, **kwargs): + """General ufunc for single-output function.""" + arys = args[:n_input] + n_dims_out = None + if out is None: + if out_shape is None: + out = np.empty(arys[-1].shape[:-n_dims]) + else: + out = np.empty((*arys[-1].shape[:-n_dims], *out_shape)) + n_dims_out = -len(out_shape) + elif check_shape: + if out.shape != arys[-1].shape[:-n_dims]: + msg = f"Shape incorrect for `out`: {out.shape}." 
+ msg += f" Correct shape is {arys[-1].shape[:-n_dims]}" + raise TypeError(msg) + for idx in np.ndindex(out.shape[:n_dims_out]): + arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] + out[idx] = np.asarray(func(*arys_idx, *args[n_input:], **kwargs))[index] + return out + + def _multi_ufunc(*args, out=None, out_shape=None, **kwargs): + """General ufunc for multi-output function.""" + arys = args[:n_input] + element_shape = arys[-1].shape[:-n_dims] + if out is None: + if out_shape is None: + out = tuple(np.empty(element_shape) for _ in range(n_output)) + else: + out = tuple(np.empty((*element_shape, *out_shape[i])) for i in range(n_output)) + + elif check_shape: + raise_error = False + correct_shape = tuple(element_shape for _ in range(n_output)) + if isinstance(out, tuple): + out_shape = tuple(item.shape for item in out) + if out_shape != correct_shape: + raise_error = True + else: + raise_error = True + out_shape = "not tuple, type={type(out)}" + if raise_error: + msg = f"Shapes incorrect for `out`: {out_shape}." + msg += f" Correct shapes are {correct_shape}" + raise TypeError(msg) + for idx in np.ndindex(element_shape): + arys_idx = [ary[idx].ravel() if ravel else ary[idx] for ary in arys] + results = func(*arys_idx, *args[n_input:], **kwargs) + for i, res in enumerate(results): + out[i][idx] = np.asarray(res)[index] + return out + + if n_output > 1: + ufunc = _multi_ufunc + else: + ufunc = _ufunc + + update_docstring(ufunc, func, n_output) + return ufunc + + +@conditional_dask +def wrap_xarray_ufunc( + ufunc, + *datasets, + ufunc_kwargs=None, + func_args=None, + func_kwargs=None, + dask_kwargs=None, + **kwargs, +): + """Wrap make_ufunc with xarray.apply_ufunc. + + Parameters + ---------- + ufunc : callable + datasets : xarray.dataset + ufunc_kwargs : dict + Keyword arguments passed to `make_ufunc`. + - 'n_dims', int, by default 2 + - 'n_output', int, by default 1 + - 'n_input', int, by default len(datasets) + - 'index', slice, by default Ellipsis + - 'ravel', bool, by default True + func_args : tuple + Arguments passed to 'ufunc'. + func_kwargs : dict + Keyword arguments passed to 'ufunc'. + - 'out_shape', int, by default None + dask_kwargs : dict + Dask related kwargs passed to :func:`xarray:xarray.apply_ufunc`. + Use :meth:`~arviz.Dask.enable_dask` to set default kwargs. + **kwargs + Passed to xarray.apply_ufunc. 
+ + Returns + ------- + xarray.dataset + """ + if ufunc_kwargs is None: + ufunc_kwargs = {} + ufunc_kwargs.setdefault("n_input", len(datasets)) + if func_args is None: + func_args = tuple() + if func_kwargs is None: + func_kwargs = {} + if dask_kwargs is None: + dask_kwargs = {} + + kwargs.setdefault( + "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets))) + ) + ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1])) + kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1)))) + + callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs) + + return apply_ufunc( + callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **dask_kwargs, **kwargs + ) + + +def update_docstring(ufunc, func, n_output=1): + """Update ArviZ generated ufunc docstring.""" + module = "" + name = "" + docstring = "" + if hasattr(func, "__module__") and isinstance(func.__module__, str): + module += func.__module__ + if hasattr(func, "__name__"): + name += func.__name__ + if hasattr(func, "__doc__") and isinstance(func.__doc__, str): + docstring += func.__doc__ + ufunc.__doc__ += "\n\n" + if module or name: + ufunc.__doc__ += "This function is a ufunc wrapper for " + ufunc.__doc__ += module + "." + name + ufunc.__doc__ += "\n" + ufunc.__doc__ += 'Call ufunc with n_args from xarray against "chain" and "draw" dimensions:' + ufunc.__doc__ += "\n\n" + input_core_dims = 'tuple(("chain", "draw") for _ in range(n_args))' + if n_output > 1: + output_core_dims = f" tuple([] for _ in range({n_output}))" + msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims}, " + msg += f"output_core_dims={ output_core_dims})" + ufunc.__doc__ += msg + else: + output_core_dims = "" + msg = f"xr.apply_ufunc(ufunc, dataset, input_core_dims={input_core_dims})" + ufunc.__doc__ += msg + ufunc.__doc__ += "\n\n" + ufunc.__doc__ += "For example: np.std(data, ddof=1) --> n_args=2" + if docstring: + ufunc.__doc__ += "\n\n" + ufunc.__doc__ += module + ufunc.__doc__ += name + ufunc.__doc__ += " docstring:" + ufunc.__doc__ += "\n\n" + ufunc.__doc__ += docstring + + +def logsumexp(ary, *, b=None, b_inv=None, axis=None, keepdims=False, out=None, copy=True): + """Stable logsumexp when b >= 0 and b is scalar. + + b_inv overwrites b unless b_inv is None. 
+ """ + # check dimensions for result arrays + ary = np.asarray(ary) + if ary.dtype.kind == "i": + ary = ary.astype(np.float64) + dtype = ary.dtype.type + shape = ary.shape + shape_len = len(shape) + if isinstance(axis, Sequence): + axis = tuple(axis_i if axis_i >= 0 else shape_len + axis_i for axis_i in axis) + agroup = axis + else: + axis = axis if (axis is None) or (axis >= 0) else shape_len + axis + agroup = (axis,) + shape_max = ( + tuple(1 for _ in shape) + if axis is None + else tuple(1 if i in agroup else d for i, d in enumerate(shape)) + ) + # create result arrays + if out is None: + if not keepdims: + out_shape = ( + tuple() + if axis is None + else tuple(d for i, d in enumerate(shape) if i not in agroup) + ) + else: + out_shape = shape_max + out = np.empty(out_shape, dtype=dtype) + if b_inv == 0: + return np.full_like(out, np.inf, dtype=dtype) if out.shape else np.inf + if b_inv is None and b == 0: + return np.full_like(out, -np.inf) if out.shape else -np.inf + ary_max = np.empty(shape_max, dtype=dtype) + # calculations + ary.max(axis=axis, keepdims=True, out=ary_max) + if copy: + ary = ary.copy() + ary -= ary_max + np.exp(ary, out=ary) + ary.sum(axis=axis, keepdims=keepdims, out=out) + np.log(out, out=out) + if b_inv is not None: + ary_max -= np.log(b_inv) + elif b: + ary_max += np.log(b) + out += ary_max.squeeze() if not keepdims else ary_max + # transform to scalar if possible + return out if out.shape else dtype(out) + + +def quantile(ary, q, axis=None, limit=None): + """Use same quantile function as R (Type 7).""" + if limit is None: + limit = tuple() + return mquantiles(ary, q, alphap=1, betap=1, axis=axis, limit=limit) + + +def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwargs=None): + """Validate ndarray. + + Parameters + ---------- + ary : numpy.ndarray + check_nan : bool + Check if any value contains NaN. + check_shape : bool + Check if array has correct shape. Assumes dimensions in order (chain, draw, *shape). + For 1D arrays (shape = (n,)) assumes chain equals 1. + nan_kwargs : dict + Valid kwargs are: + axis : int, + Defaults to None. + how : str, {"all", "any"} + Default to "any". + shape_kwargs : dict + Valid kwargs are: + min_chains : int + Defaults to 1. + min_draws : int + Defaults to 4. 
+ + Returns + ------- + bool + """ + ary = np.asarray(ary) + + nan_error = False + draw_error = False + chain_error = False + + if check_nan: + if nan_kwargs is None: + nan_kwargs = {} + + isnan = np.isnan(ary) + axis = nan_kwargs.get("axis", None) + if nan_kwargs.get("how", "any").lower() == "all": + nan_error = isnan.all(axis) + else: + nan_error = isnan.any(axis) + + if (isinstance(nan_error, bool) and nan_error) or nan_error.any(): + _log.warning("Array contains NaN-value.") + + if check_shape: + shape = ary.shape + + if shape_kwargs is None: + shape_kwargs = {} + + min_chains = shape_kwargs.get("min_chains", 2) + min_draws = shape_kwargs.get("min_draws", 4) + error_msg = f"Shape validation failed: input_shape: {shape}, " + error_msg += f"minimum_shape: (chains={min_chains}, draws={min_draws})" + + chain_error = ((min_chains > 1) and (len(shape) < 2)) or (shape[0] < min_chains) + draw_error = ((len(shape) < 2) and (shape[0] < min_draws)) or ( + (len(shape) > 1) and (shape[1] < min_draws) + ) + + if chain_error or draw_error: + _log.warning(error_msg) + + return nan_error | chain_error | draw_error + + +def get_log_likelihood(idata, var_name=None): + """Retrieve the log likelihood dataarray of a given variable.""" + if ( + not hasattr(idata, "log_likelihood") + and hasattr(idata, "sample_stats") + and hasattr(idata.sample_stats, "log_likelihood") + ): + warnings.warn( + "Storing the log_likelihood in sample_stats groups has been deprecated", + DeprecationWarning, + ) + return idata.sample_stats.log_likelihood + if not hasattr(idata, "log_likelihood"): + raise TypeError("log likelihood not found in inference data object") + if var_name is None: + var_names = list(idata.log_likelihood.data_vars) + if len(var_names) > 1: + raise TypeError( + f"Found several log likelihood arrays {var_names}, var_name cannot be None" + ) + return idata.log_likelihood[var_names[0]] + else: + try: + log_likelihood = idata.log_likelihood[var_name] + except KeyError as err: + raise TypeError(f"No log likelihood data named {var_name} found") from err + return log_likelihood + + +BASE_FMT = """Computed from {{n_samples}} by {{n_points}} log-likelihood matrix + +{{0:{0}}} Estimate SE +{{scale}}_{{kind}} {{1:8.2f}} {{2:7.2f}} +p_{{kind:{1}}} {{3:8.2f}} -""" +POINTWISE_LOO_FMT = """------ + +Pareto k diagnostic values: + {{0:>{0}}} {{1:>6}} +(-Inf, 0.5] (good) {{2:{0}d}} {{6:6.1f}}% + (0.5, 0.7] (ok) {{3:{0}d}} {{7:6.1f}}% + (0.7, 1] (bad) {{4:{0}d}} {{8:6.1f}}% + (1, Inf) (very bad) {{5:{0}d}} {{9:6.1f}}% +""" +SCALE_DICT = {"deviance": "deviance", "log": "elpd", "negative_log": "-elpd"} + + +class ELPDData(pd.Series): # pylint: disable=too-many-ancestors + """Class to contain the data from elpd information criterion like waic or loo.""" + + def __str__(self): + """Print elpd data in a user friendly way.""" + kind = self.index[0] + + if kind not in ("loo", "waic"): + raise ValueError("Invalid ELPDData object") + + scale_str = SCALE_DICT[self[f"{kind}_scale"]] + padding = len(scale_str) + len(kind) + 1 + base = BASE_FMT.format(padding, padding - 2) + base = base.format( + "", + kind=kind, + scale=scale_str, + n_samples=self.n_samples, + n_points=self.n_data_points, + *self.values, + ) + + if self.warning: + base += "\n\nThere has been a warning during the calculation. Please check the results." 
+ + if kind == "loo" and "pareto_k" in self: + bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf]) + counts, *_ = _histogram(self.pareto_k.values, bins) + extended = POINTWISE_LOO_FMT.format(max(4, len(str(np.max(counts))))) + extended = extended.format( + "Count", "Pct.", *[*counts, *(counts / np.sum(counts) * 100)] + ) + base = "\n".join([base, extended]) + return base + + def __repr__(self): + """Alias to ``__str__``.""" + return self.__str__() + + def copy(self, deep=True): + """Perform a pandas deep copy of the ELPDData plus a copy of the stored data.""" + copied_obj = pd.Series.copy(self) + for key in copied_obj.keys(): + if deep: + copied_obj[key] = _deepcopy(copied_obj[key]) + else: + copied_obj[key] = _copy(copied_obj[key]) + return ELPDData(copied_obj) + + +@conditional_jit +def stats_variance_1d(data, ddof=0): + a_a, b_b = 0, 0 + for i in data: + a_a = a_a + i + b_b = b_b + i * i + var = b_b / (len(data)) - ((a_a / (len(data))) ** 2) + var = var * (len(data) / (len(data) - ddof)) + return var + + +def stats_variance_2d(data, ddof=0, axis=1): + if data.ndim == 1: + return stats_variance_1d(data, ddof=ddof) + a_a, b_b = data.shape + if axis == 1: + var = np.zeros(a_a) + for i in range(a_a): + var[i] = stats_variance_1d(data[i], ddof=ddof) + return var + else: + var = np.zeros(b_b) + for i in range(b_b): + var[i] = stats_variance_1d(data[:, i], ddof=ddof) + return var + + +@conditional_vect +def _sqrt(a_a, b_b): + return (a_a + b_b) ** 0.5 + + +def _circfunc(samples, high, low, skipna): + samples = np.asarray(samples) + if skipna: + samples = samples[~np.isnan(samples)] + if samples.size == 0: + return np.nan + return _angle(samples, low, high, np.pi) + + +@conditional_vect +def _angle(samples, low, high, p_i=np.pi): + ang = (samples - low) * 2.0 * p_i / (high - low) + return ang + + +def _circular_standard_deviation(samples, high=2 * np.pi, low=0, skipna=False, axis=None): + ang = _circfunc(samples, high, low, skipna) + s_s = np.sin(ang).mean(axis=axis) + c_c = np.cos(ang).mean(axis=axis) + r_r = np.hypot(s_s, c_c) + return ((high - low) / 2.0 / np.pi) * np.sqrt(-2 * np.log(r_r)) + + +def smooth_data(obs_vals, pp_vals): + """Smooth data, helper function for discrete data in plot_pbv, loo_pit and plot_loo_pit.""" + x = np.linspace(0, 1, len(obs_vals)) + csi = CubicSpline(x, obs_vals) + obs_vals = csi(np.linspace(0.01, 0.99, len(obs_vals))) + + x = np.linspace(0, 1, pp_vals.shape[1]) + csi = CubicSpline(x, pp_vals, axis=1) + pp_vals = csi(np.linspace(0.01, 0.99, pp_vals.shape[1])) + + return obs_vals, pp_vals From 92fcbf0577b203c16d98bbea464ed5d5c43b82d7 Mon Sep 17 00:00:00 2001 From: "Oriol (ZBook)" Date: Sun, 14 Nov 2021 12:39:51 +0200 Subject: [PATCH 37/37] fix docs and changelog --- CHANGELOG.md | 4 +--- arviz/data/datasets.py | 6 ++---- arviz/data/inference_data.py | 4 +++- arviz/data/io_netcdf.py | 7 +++++-- requirements-optional.txt | 1 - 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a7920a92f..c4836cd00e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## v0.x.x Unreleased ### New features +* [experimental] Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support ([1749](https://github.com/arviz-devs/arviz/pull/1749)) ### Maintenance and fixes * Bokeh 3 compatibility. ([1919](https://github.com/arviz-devs/arviz/pull/1919)) @@ -32,9 +33,6 @@ * Added ability to plot HDI contours to `plot_kde` with the new `hdi_probs` parameter. 
([1665](https://github.com/arviz-devs/arviz/pull/1665))
 * Add dtype parsing and setting in all Stan converters ([1632](https://github.com/arviz-devs/arviz/pull/1632))
 * Add option to specify colors for each element in ppc_plot ([1769](https://github.com/arviz-devs/arviz/pull/1769))
-* Enable dask chunking information to be passed to `InferenceData.from_netcdf` ([1749](https://github.com/arviz-devs/arviz/pull/1749))
-* Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support ([1749](https://github.com/arviz-devs/arviz/pull/1749))
-* Enable dask chunking information to be passed to `InferenceData.from_netcdf` with regex support to enable parallel io ([1749](https://github.com/arviz-devs/arviz/pull/1749))
 
 ### Maintenance and fixes
 * Fix conversion for numpyro models with ImproperUniform latent sites ([1713](https://github.com/arviz-devs/arviz/pull/1713))
diff --git a/arviz/data/datasets.py b/arviz/data/datasets.py
index 9ef5bd3563..5f9e779f1e 100644
--- a/arviz/data/datasets.py
+++ b/arviz/data/datasets.py
@@ -220,13 +220,11 @@ def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs):
     regex : bool, optional
         Specifies regex support for chunking information in
-        `arviz.io_netcdf.from_netcdf`. This feature is currently experimental.
-        See :meth:`arviz.io_netcdf.from_netcdf`
+        :func:`arviz.from_netcdf`. This feature is currently experimental.
     **kwargs : dict of {str: dict}, optional
-        Keyword arguments to be passed into arviz.io_netcdf.from_netcdf`.
+        Keyword arguments to be passed to :func:`arviz.from_netcdf`.
         This feature is currently experimental.
-        See :meth:`arviz.io_netcdf.from_netcdf`
 
     Returns
     -------
diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py
index 84e688f704..de218dfb05 100644
--- a/arviz/data/inference_data.py
+++ b/arviz/data/inference_data.py
@@ -328,7 +328,9 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
         filename : str
             location of netcdf file
         group_kwargs : dict of {str: dict}, optional
-            Keyword arguments to be passed into each call of {func}`xarray.open_dataset`.
+            Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
+            The keys of the higher level should be group names or regex matching group
+            names, the inner dicts are passed to ``open_dataset``.
             This feature is currently experimental.
         regex : bool, default False
             Specifies where regex search should be used to extend the keyword arguments.
diff --git a/arviz/data/io_netcdf.py b/arviz/data/io_netcdf.py
index 2928cc61f0..93574a6cdd 100644
--- a/arviz/data/io_netcdf.py
+++ b/arviz/data/io_netcdf.py
@@ -11,8 +11,11 @@ def from_netcdf(filename, group_kwargs=None, regex=False):
     ----------
     filename : str
         name or path of the file to load trace
-    group_kwargs : dict of dict
-        Keyword arguments to be passed into each call of `xarray.open_dataset`.
+    group_kwargs : dict of {str: dict}
+        Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
+        The keys of the higher level should be group names or regex matching group
+        names, the inner dicts are passed to ``open_dataset``.
+        This feature is currently experimental.
     regex : str
         Specifies where regex search should be used to extend the keyword arguments.
 
diff --git a/requirements-optional.txt b/requirements-optional.txt
index 2bb206a62d..8eb8fd7874 100644
--- a/requirements-optional.txt
+++ b/requirements-optional.txt
@@ -1,6 +1,5 @@
 numba
 bokeh>=1.4.0
 ujson
-dask
 dask[distributed]
 zarr>=2.5.0
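
A minimal usage sketch of the interface this series ends up with: `from_netcdf(filename, group_kwargs=None, regex=False)`, where each inner dict is forwarded to `xarray.open_dataset` for the matching group. The file name, group names, and chunk sizes below are illustrative assumptions, not values taken from the patches; chunked loading also assumes dask is installed (see requirements-optional.txt).

    import arviz as az

    # Hypothetical file written earlier with idata.to_netcdf("model_output.nc");
    # each inner dict here is handed to xarray.open_dataset for that group,
    # so e.g. "chunks" turns the group's variables into dask-backed arrays.
    idata = az.from_netcdf(
        "model_output.nc",
        group_kwargs={
            "posterior": {"chunks": {"draw": 100}},
            "log_likelihood": {"chunks": {"draw": 100}},
        },
    )

    # With regex=True the outer keys are treated as patterns, so one entry can
    # cover several groups (e.g. posterior and posterior_predictive).
    idata = az.from_netcdf(
        "model_output.nc",
        group_kwargs={"posterior.*": {"chunks": {"draw": 100}}},
        regex=True,
    )

Per the `load_arviz_data` docstring above, the same keyword arguments can also be forwarded through `az.load_arviz_data(dataset, group_kwargs=..., regex=...)`, which passes them on to `from_netcdf`.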