From 57fbca119129069d0cf1d60f54b0777bff277817 Mon Sep 17 00:00:00 2001 From: saal Date: Thu, 8 Aug 2024 14:50:16 -0700 Subject: [PATCH 1/4] Started updating tests but will probably fail --- .pre-commit-config.yaml | 10 +++ AUTHORS.rst | 2 - CHANGELOG.md | 16 ++++ CHANGELOG.rst | 3 - CODE_OF_CONDUCT.md | 111 ++++++++++++++++++++++++++ CODE_OF_CONDUCT.rst | 81 ------------------- CONTRIBUTING.rst => CONTRIBUTING.md | 11 +-- LICENSE | 4 +- MANIFEST.in | 12 --- README.md | 14 ++++ README.rst | 11 --- pyproject.toml | 61 ++++++++++++++ setup.py | 62 -------------- sphinx/conf.py | 4 +- src/anml/__about__.py | 10 ++- src/anml/data/component.py | 20 ++--- src/anml/data/prototype.py | 9 +-- src/anml/data/validator.py | 12 +-- src/anml/getter/prior.py | 29 ++++--- src/anml/getter/spline.py | 63 ++++++++------- src/anml/parameter/main.py | 79 +++++++++++------- src/anml/parameter/smoothmapping.py | 41 ++++------ src/anml/prior/main.py | 27 +++---- src/anml/prior/utils.py | 10 +-- src/anml/variable/main.py | 27 ++++--- src/anml/variable/spline.py | 26 +++--- tests/data/test_component.py | 7 +- tests/data/test_data_prototype.py | 7 +- tests/data/test_example.py | 3 +- tests/getter/test_spline.py | 11 ++- tests/parameter/test_main.py | 11 ++- tests/parameter/test_smoothmapping.py | 19 +++-- tests/prior/test_main.py | 14 ++-- tests/prior/test_utils.py | 17 ++-- tests/variable/test_main.py | 28 +++++-- tests/variable/test_spline.py | 35 +++++--- tox.ini | 7 -- 37 files changed, 490 insertions(+), 424 deletions(-) create mode 100644 .pre-commit-config.yaml delete mode 100644 AUTHORS.rst create mode 100644 CHANGELOG.md delete mode 100644 CHANGELOG.rst create mode 100644 CODE_OF_CONDUCT.md delete mode 100644 CODE_OF_CONDUCT.rst rename CONTRIBUTING.rst => CONTRIBUTING.md (83%) delete mode 100644 MANIFEST.in create mode 100644 README.md delete mode 100644 README.rst create mode 100644 pyproject.toml delete mode 100644 setup.py delete mode 100644 tox.ini diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..3c9fbd0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.2 + hooks: + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format \ No newline at end of file diff --git a/AUTHORS.rst b/AUTHORS.rst deleted file mode 100644 index 8c5f784..0000000 --- a/AUTHORS.rst +++ /dev/null @@ -1,2 +0,0 @@ -Authors -======= diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0313489 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,16 @@ +## [0.1.0] - 2024-07-25 + +### Added +- pyproject + +### Fixed +- updated to support xspline API changes since version 0.0.7 + +### Changes + + +## [0.0.0] - 2020-05-29 + +**0.0.0 - 05/29/2020** + +- Repo creation \ No newline at end of file diff --git a/CHANGELOG.rst b/CHANGELOG.rst deleted file mode 100644 index 3daaa8f..0000000 --- a/CHANGELOG.rst +++ /dev/null @@ -1,3 +0,0 @@ -**0.0.0 - 05/29/2020** - -- Repo creation \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..7ea5c47 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,111 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for +everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity +and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, +or sexual identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take +appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, +issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for +moderation decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when an individual is officially representing +the community in public spaces. Examples of representing our community include using an official e-mail address, posting +via an official social media account, or acting as an appointed representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible +for enforcement at +[INSERT CONTACT METHOD]. All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem +in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the +community. + +**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation +and an explanation of why the behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of actions. + +**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including +unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding +interactions in community spaces as well as external channels like social media. Violating these terms may lead to a +temporary or permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified +period of time. No public or private interaction with the people involved, including unsolicited interaction with those +enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate +behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at +[https://www.contributor-covenant.org/version/2/0/code_of_conduct.html][v2.0]. + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder][Mozilla CoC]. + +For answers to common questions about this code of conduct, see the FAQ at +[https://www.contributor-covenant.org/faq][FAQ]. Translations are available +at [https://www.contributor-covenant.org/translations][translations]. + +[homepage]: https://www.contributor-covenant.org + +[v2.0]: https://www.contributor-covenant.org/version/2/0/code_of_conduct.html + +[Mozilla CoC]: https://github.com/mozilla/diversity + +[FAQ]: https://www.contributor-covenant.org/faq + +[translations]: https://www.contributor-covenant.org/translations diff --git a/CODE_OF_CONDUCT.rst b/CODE_OF_CONDUCT.rst deleted file mode 100644 index 8aac068..0000000 --- a/CODE_OF_CONDUCT.rst +++ /dev/null @@ -1,81 +0,0 @@ -Code of Conduct -=============== - -Our Pledge ----------- - -In the interest of fostering an open and welcoming environment, we as -contributors and maintainers pledge to making participation in our project and -our community a harassment-free experience for everyone, regardless of age, body -size, disability, ethnicity, gender identity and expression, level of experience, -nationality, personal appearance, race, religion, or sexual identity and -orientation. - -Our Standards -------------- - -Examples of behavior that contributes to creating a positive environment -include: - -- Using welcoming and inclusive language -- Being respectful of differing viewpoints and experiences -- Gracefully accepting constructive criticism -- Focusing on what is best for the community -- Showing empathy towards other community members - -Examples of unacceptable behavior by participants include: - -- The use of sexualized language or imagery and unwelcome sexual attention or - advances -- Trolling, insulting/derogatory comments, and personal or political attacks -- Public or private harassment -- Publishing others' private information, such as a physical or electronic - address, without explicit permission -- Other conduct which could reasonably be considered inappropriate in a - professional setting - -Our Responsibilities --------------------- - -Project maintainers are responsible for clarifying the standards of acceptable -behavior and are expected to take appropriate and fair corrective action in -response to any instances of unacceptable behavior. - -Project maintainers have the right and responsibility to remove, edit, or -reject comments, commits, code, wiki edits, issues, and other contributions -that are not aligned to this Code of Conduct, or to ban temporarily or -permanently any contributor for other behaviors that they deem inappropriate, -threatening, offensive, or harmful. - -Scope ------ - -This Code of Conduct applies both within project spaces and in public spaces -when an individual is representing the project or its community. Examples of -representing a project or community include using an official project e-mail -address, posting via an official social media account, or acting as an appointed -representative at an online or offline event. Representation of a project may be -further defined and clarified by project maintainers. - -Enforcement ------------ - -Instances of abusive, harassing, or otherwise unacceptable behavior may be -reported by contacting the current maintainer (found in this repositories AUTHORS.rst). All -complaints will be reviewed and investigated and will result in a response that -is deemed necessary and appropriate to the circumstances. The project team is -obligated to maintain confidentiality with regard to the reporter of an incident. -Further details of specific enforcement policies may be posted separately. - -Project maintainers who do not follow or enforce the Code of Conduct in good -faith may face temporary or permanent repercussions as determined by other -members of the project's leadership. - -Attribution ------------ - -This Code of Conduct is adapted from the `Contributor Covenant`_, version 1.4, -available at `version`_. - -.. _Contributor Covenant: http://contributor-covenant.org -.. _version: http://contributor-covenant.org/version/1/4/ diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.md similarity index 83% rename from CONTRIBUTING.rst rename to CONTRIBUTING.md index a9b604c..d5eed01 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.md @@ -1,5 +1,4 @@ -Contributing -============ +# Contributing When contributing to this repository, please first discuss the change you wish to make via issue, email, or any other method with the owners of this @@ -8,8 +7,7 @@ repository before making a change. Please note we have a code of conduct, please follow it in all your interactions with the project. -Submitting Changes ------------------- +## Submitting Changes - Always make a new branch for your work. - Patches should be small to facilitate easier review. Sometimes this will @@ -17,6 +15,5 @@ Submitting Changes - Larger changes should be discussed in the project's GitHub issues page. - New features and significant bug fixes should be documented in the changelog. - You must have legal permission to distribute any code you contribute to - ``anml``, and it must be available under both the GNU - GPLv3 license. - + `anml`, and it must be available under both the GNU + GPLv3 license. \ No newline at end of file diff --git a/LICENSE b/LICENSE index b35e2ce..9e3236b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ BSD 2-Clause License -Copyright (c) 2022, IHME Math Sciences +Copyright (c) 2019-2024, IHME Math Sciences All rights reserved. Redistribution and use in source and binary forms, with or without @@ -22,4 +22,4 @@ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index a0a7254..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -include AUTHORS.rst -include CHANGELOG.rst -include CODE_OF_CONDUCT.rst -include CONTRIBUTING.rst -include LICENSE -include README.rst - -recursive-include docs * -prune docs/_build - -recursive-include src/anml *.py *.yaml -recursive-include tests *.py *txt *.yaml diff --git a/README.md b/README.md new file mode 100644 index 0000000..83fab34 --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +[![PyPI version](https://badge.fury.io/py/anml.svg)](https://badge.fury.io/py/anml) +![Python](https://img.shields.io/badge/python-3.6%2B-blue.svg) +[![Build Status](https://github.com/ihmeuw-msca/regmod/workflows/build/badge.svg)](https://github.com/ihmeuw-msca/regmod/actions) +[![GitHub](https://img.shields.io/github/license/ihmeuw-msca/anml)](./LICENSE) +[![docs](https://img.shields.io/badge/docs-here-green)](https://ihmeuw-msca.github.io/anml) +[![codecov](https://img.shields.io/codecov/c/github/ihmeuw-msca/anml)](https://codecov.io/gh/ihmeuw-msca/anml) +[![codacy](https://img.shields.io/codacy/grade/ae72a07785f5469eac234d1f6bdf555f)](https://app.codacy.com/gh/ihmeuw-msca/anml/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade) + + +# anml: A Nonlinear Modeling Library + +This is a nonlinear modeling library. + +**NOTE** This repository is under construction. 🚧 ⚠️ 👷 \ No newline at end of file diff --git a/README.rst b/README.rst deleted file mode 100644 index 8492999..0000000 --- a/README.rst +++ /dev/null @@ -1,11 +0,0 @@ -.. image:: https://github.com/ihmeuw-msca/regmod/workflows/build/badge.svg - :target: https://github.com/ihmeuw-msca/regmod/actions - -.. image:: https://badge.fury.io/py/anml.svg - :target: https://badge.fury.io/py/anml - -anml: A Nonlinear Modeling Library -================================== - -**NOTE** This repository is under construction. :construction: :warning: :construction_worker: -This is a nonlinear modeling library. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..839aaa9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,61 @@ +[build-system] +build-backend = "setuptools.build_meta" +requires = [ + "setuptools>=61", +] + +[project] +name = "anml" +version = "0.1.0" +description = "This is a nonlinear modeling library." +readme = "README.md" +license = { text = "BSD-2-Clause" } +requires-python = ">=3.6" +authors = [ + { name = "IHME Math Sciences", email = "ihme.math.sciences@gmail.com" }, +] + +dependencies =[ + "numpy<2.0.0", + "pandas", + "scipy", + "xspline", + "click" +] + +[project.urls] +homepage = "https://github.com/ihmeuw-msca/anml" + +[tool.pytest.ini_options] +testpaths = ["tests", "integration"] +addopts = "-v -ra -q" +log_cli = true +log_cli_level = "INFO" +log_format = "%(asctime)s %(levelname)s %(message)s" +log_date_format = "%Y-%m-%d %H:%M:%S" +minversion = "6.0" +filterwarnings = "ignore" + +[project.optional-dependencies] +docs = [ + "sphinx>=3.0.0", + "sphinx-autodoc-typehints", + "furo", + "sphinx-click", + "IPython", + "matplotlib" +] +test = [ + "pytest", + "pytest-mock" +] +dev = [ + "sphinx>=3.0.0", + "sphinx-autodoc-typehints", + "furo", + "sphinx-click", + "IPython", + "matplotlib", + "pytest", + "pytest-mock" +] diff --git a/setup.py b/setup.py deleted file mode 100644 index 92c32a2..0000000 --- a/setup.py +++ /dev/null @@ -1,62 +0,0 @@ -from pathlib import Path - -from setuptools import find_packages, setup - -if __name__ == "__main__": - base_dir = Path(__file__).parent - src_dir = base_dir / "src" - - about = {} - with (src_dir / "anml" / "__about__.py").open() as f: - exec(f.read(), about) - - with (base_dir / "README.rst").open() as f: - long_description = f.read() - - install_requirements = [ - "numpy>=1.21", - "pandas", - "scipy", - "xspline", - "click", - ] - - test_requirements = [ - "pytest", - "pytest-mock" - ] - - doc_requirements = [ - "sphinx>=3.0.0", - "sphinx-autodoc-typehints", - "furo", - "sphinx-click", - "IPython", - "matplotlib" - ] - - setup( - name=about["__title__"], - version=about["__version__"], - - description=about["__summary__"], - long_description=long_description, - license=about["__license__"], - url=about["__uri__"], - - author=about["__author__"], - author_email=about["__email__"], - - package_dir={"": "src"}, - packages=find_packages(where="src"), - include_package_data=True, - - install_requires=install_requirements, - tests_require=test_requirements, - extras_require={ - "docs": doc_requirements, - "test": test_requirements, - "dev": doc_requirements + test_requirements - }, - zip_safe=False, - ) diff --git a/sphinx/conf.py b/sphinx/conf.py index ad3175d..625d228 100644 --- a/sphinx/conf.py +++ b/sphinx/conf.py @@ -59,7 +59,7 @@ "NDArray": "NDArray", "DataFrame": "DataFrame", } -autodoc_member_order = 'bysource' +autodoc_member_order = "bysource" # Add any paths that contain templates here, relative to this directory. templates_path = ["_templates"] @@ -97,6 +97,6 @@ "color-brand-primary": "#6FD8D1", "color-brand-content": "#6FD8D1", "color-problematic": "#FA9F50", - "color-background-secondary": "#202020" + "color-background-secondary": "#202020", }, } diff --git a/src/anml/__about__.py b/src/anml/__about__.py index ca7d9b0..f433b45 100644 --- a/src/anml/__about__.py +++ b/src/anml/__about__.py @@ -1,6 +1,12 @@ __all__ = [ - "__title__", "__summary__", "__uri__", "__version__", "__author__", - "__email__", "__license__", "__copyright__", + "__title__", + "__summary__", + "__uri__", + "__version__", + "__author__", + "__email__", + "__license__", + "__copyright__", ] __title__ = "anml" diff --git a/src/anml/data/component.py b/src/anml/data/component.py index 96a0022..cf73fa6 100644 --- a/src/anml/data/component.py +++ b/src/anml/data/component.py @@ -38,10 +38,12 @@ class Component: """ - def __init__(self, - key: str, - validators: Optional[List[Validator]] = None, - default_value: Optional[Any] = None): + def __init__( + self, + key: str, + validators: Optional[List[Validator]] = None, + default_value: Optional[Any] = None, + ): self.key = key self.validators = validators self.default_value = default_value @@ -58,9 +60,9 @@ def validators(self, validators: Optional[List[Validator]]): if validators is None: self._validators = [] else: - if ((not isinstance(validators, Iterable)) or - (not all(isinstance(validator, Validator) - for validator in validators))): + if (not isinstance(validators, Iterable)) or ( + not all(isinstance(validator, Validator) for validator in validators) + ): raise TypeError("Validators must be a list of validator.") self._validators = list(validators) @@ -89,9 +91,7 @@ def attach(self, df: DataFrame): self._value = value def clear(self): - """Clear stored value. - - """ + """Clear stored value.""" self._value = None def __repr__(self) -> str: diff --git a/src/anml/data/prototype.py b/src/anml/data/prototype.py index aa2fe95..646c2e7 100644 --- a/src/anml/data/prototype.py +++ b/src/anml/data/prototype.py @@ -34,8 +34,9 @@ def components(self, components: Dict[str, Component]): if not isinstance(name, str): raise TypeError("Components key must be a string.") if not isinstance(component, Component): - raise TypeError(f"Components {name} value must be a instance " - "of Component") + raise TypeError( + f"Components {name} value must be a instance " "of Component" + ) self._components = components def attach(self, df: DataFrame): @@ -51,9 +52,7 @@ def attach(self, df: DataFrame): getattr(self, name).attach(df) def clear(self): - """Clear stored value for each component. - - """ + """Clear stored value for each component.""" for name in self.components.keys(): getattr(self, name).clear() diff --git a/src/anml/data/validator.py b/src/anml/data/validator.py index 8d4a48e..37612ee 100644 --- a/src/anml/data/validator.py +++ b/src/anml/data/validator.py @@ -19,9 +19,7 @@ def __repr__(self) -> str: class NoNans(Validator): - """Validate there is no 'nan's in the array. - - """ + """Validate there is no 'nan's in the array.""" def __call__(self, key: str, value: NDArray): if np.isnan(value).any(): @@ -29,9 +27,7 @@ def __call__(self, key: str, value: NDArray): class Positive(Validator): - """Validate there is no non-poisitive value in the array. - - """ + """Validate there is no non-poisitive value in the array.""" def __call__(self, key: str, value: NDArray): if (value <= 0).any(): @@ -39,9 +35,7 @@ def __call__(self, key: str, value: NDArray): class Unique(Validator): - """Validate all the values in the array are unique. - - """ + """Validate all the values in the array are unique.""" def __call__(self, key: str, value: NDArray): if len(np.unique(value)) < value.shape[0]: diff --git a/src/anml/getter/prior.py b/src/anml/getter/prior.py index 1a93f7b..cb95a61 100644 --- a/src/anml/getter/prior.py +++ b/src/anml/getter/prior.py @@ -106,12 +106,14 @@ class SplinePriorGetter: """ - def __init__(self, - prior: Prior, - size: int = 100, - order: int = 0, - domain: Tuple[float, float] = (0.0, 1.0), - domain_type: str = "rel"): + def __init__( + self, + prior: Prior, + size: int = 100, + order: int = 0, + domain: Tuple[float, float] = (0.0, 1.0), + domain_type: str = "rel", + ): self.prior = prior self.size = size self.order = order @@ -144,8 +146,9 @@ def order(self, order: int): def domain(self, domain: Tuple[float, float]): domain = tuple(domain) if len(domain) != 2: - raise ValueError("Domain must contains two numbers for lower and " - "upper bound.") + raise ValueError( + "Domain must contains two numbers for lower and " "upper bound." + ) domain_lb, domain_ub = domain if domain_lb > domain_ub: raise ValueError("Domain lb must be less than or equal to ub.") @@ -182,9 +185,11 @@ def get_prior(self, spline: XSpline) -> Prior: knots_lb, knots_ub = spline.knots[0], spline.knots[-1] domain_lb, domain_ub = self.domain if self.domain_type == "rel": - domain_lb = knots_lb + (knots_ub - knots_lb)*domain_lb - domain_ub = knots_lb + (knots_ub - knots_lb)*domain_ub + domain_lb = knots_lb + (knots_ub - knots_lb) * domain_lb + domain_ub = knots_lb + (knots_ub - knots_lb) * domain_ub points = np.linspace(domain_lb, domain_ub, self.size) - self.prior.mat = spline.design_dmat(points, order=self.order, - l_extra=True, r_extra=True) + self.prior.mat = spline.get_design_mat( + points, + order=self.order, + ) return self.prior diff --git a/src/anml/getter/spline.py b/src/anml/getter/spline.py index 87c7d9a..fe0a171 100644 --- a/src/anml/getter/spline.py +++ b/src/anml/getter/spline.py @@ -3,11 +3,12 @@ import numpy as np from numpy.typing import NDArray from xspline import XSpline +from typing import Optional class SplineGetter: - """Spline getter for :class:`XSpline` instance. Given the settings of the - spline, when attach the data it can infer the knots position, construct and + """Spline getter for :class:`XSpline` instance. Given the settings of the + spline, when attach the data it can infer the knots position, construct and return an instance of :class:`XSpline`. Parameters @@ -17,10 +18,10 @@ class SplineGetter: used differently. degree Degree of the spline. Default to be 3. - l_linear - If `True`, spline will use left linear tail. Default to be `False`. - r_linear - If `True`, spline will use right linear tail. Default to be `False`. + ldegree + Left extrapolation polynomial degree. + rdegree + Right extrapolation polynomial degree. include_first_basis If `True`, spline will include the first basis of the spline. Default to be `True`. @@ -47,34 +48,33 @@ class SplineGetter: """ - def __init__(self, - knots: NDArray, - degree: int = 3, - l_linear: bool = False, - r_linear: bool = False, - include_first_basis: bool = False, - knots_type: str = "abs"): + def __init__( + self, + knots: NDArray, + degree: int = 3, + ldegree: Optional[int] = None, + rdegree: Optional[int] = None, + knots_type: str = "abs", + ): self.knots = knots self.degree = degree - self.l_linear = l_linear - self.r_linear = r_linear - self.include_first_basis = include_first_basis + self.ldegree = ldegree + self.rdegree = rdegree self.knots_type = knots_type @knots_type.setter def knots_type(self, knots_type: str): if knots_type not in ["abs", "rel_domain", "rel_freq"]: - raise ValueError("Knots type must be one of 'abs', 'rel_domain' or 'rel_freq'.") + raise ValueError( + "Knots type must be one of 'abs', 'rel_domain' or 'rel_freq'." + ) self._knots_type = knots_type - @property - def num_spline_bases(self) -> int: - """Number of the spline bases. - - """ - inner_knots = self.knots[int(self.l_linear): - len(self.knots) - int(self.r_linear)] - return len(inner_knots) - 2 + self.degree + int(self.include_first_basis) + # @property + # def num_spline_bases(self) -> int: + # """Number of the spline bases.""" + # inner_knots = self.knots[self.ldegree : len(self.knots) - self.rdegree] + # return len(inner_knots) - 2 + self.degree def get_spline(self, data: NDArray) -> XSpline: """Get spline instance given data array. @@ -94,12 +94,13 @@ def get_spline(self, data: NDArray) -> XSpline: else: if self.knots_type == "rel_domain": lb, ub = data.min(), data.max() - knots = lb + self.knots*(ub - lb) + knots = lb + self.knots * (ub - lb) else: knots = np.quantile(data, self.knots) - return XSpline(knots, - self.degree, - l_linear=self.l_linear, - r_linear=self.r_linear, - include_first_basis=self.include_first_basis) + return XSpline( + knots, + self.degree, + ldegree=self.ldegree, + rdegree=self.rdegree, + ) diff --git a/src/anml/parameter/main.py b/src/anml/parameter/main.py index d9cce1f..433d317 100644 --- a/src/anml/parameter/main.py +++ b/src/anml/parameter/main.py @@ -81,11 +81,13 @@ class Parameter: """ - def __init__(self, - variables: List[Variable], - transform: Optional[SmoothMapping] = None, - offset: Optional[Union[str, Component]] = None, - priors: Optional[List[Prior]] = None): + def __init__( + self, + variables: List[Variable], + transform: Optional[SmoothMapping] = None, + offset: Optional[Union[str, Component]] = None, + priors: Optional[List[Prior]] = None, + ): self.variables = variables self.transform = transform self.offset = offset @@ -98,15 +100,18 @@ def __init__(self, def variables(self, variables: List[Variable]): variables = list(variables) if not all(isinstance(variable, Variable) for variable in variables): - raise TypeError("Parameter input variables must be a list of " - "instances of Variable.") + raise TypeError( + "Parameter input variables must be a list of " "instances of Variable." + ) self._variables = variables @transform.setter def transform(self, transform: Optional[SmoothMapping]): if transform is not None and not isinstance(transform, SmoothMapping): - raise TypeError("Parameter input transform must be an instance " - "of SmoothMapping or None.") + raise TypeError( + "Parameter input transform must be an instance " + "of SmoothMapping or None." + ) if transform is None: transform = Identity() self._transform = transform @@ -115,8 +120,10 @@ def transform(self, transform: Optional[SmoothMapping]): def offset(self, offset: Optional[Union[str, Component]]): if offset is not None: if not isinstance(offset, (str, Component)): - raise TypeError("Parameter input offset has to be a string or " - "an instance of Component.") + raise TypeError( + "Parameter input offset has to be a string or " + "an instance of Component." + ) if isinstance(offset, str): offset = Component(offset, validators=[NoNans()]) self._offset = offset @@ -125,8 +132,9 @@ def offset(self, offset: Optional[Union[str, Component]]): def priors(self, priors: Optional[List[Prior]]): priors = list(priors) if priors is not None else [] if not all(isinstance(prior, Prior) for prior in priors): - raise TypeError("Parameter input priors must be a list of " - "instances of Prior.") + raise TypeError( + "Parameter input priors must be a list of " "instances of Prior." + ) self._priors = priors @property @@ -149,9 +157,9 @@ def attach(self, df: DataFrame): """ if self.offset is not None: self.offset.attach(df) - self.design_mat = np.hstack([ - variable.get_design_mat(df) for variable in self.variables - ]) + self.design_mat = np.hstack( + [variable.get_design_mat(df) for variable in self.variables] + ) for prior_category in ["direct", "linear"]: for prior_type in ["UniformPrior", "GaussianPrior"]: getattr(self, f"_get_{prior_category}_prior")(prior_type) @@ -169,8 +177,12 @@ def _get_direct_prior(self, prior_type: str): Given name of the prior type. """ - params = np.hstack([variable.get_direct_prior_params(prior_type) - for variable in self.variables]) + params = np.hstack( + [ + variable.get_direct_prior_params(prior_type) + for variable in self.variables + ] + ) self.prior_dict["direct"][prior_type] = get_prior_type(prior_type)( params[0], params[1] ) @@ -188,8 +200,14 @@ def _get_linear_prior(self, prior_type: str): """ - params, mat = tuple(zip(*[variable.get_linear_prior_params(prior_type) - for variable in self.variables])) + params, mat = tuple( + zip( + *[ + variable.get_linear_prior_params(prior_type) + for variable in self.variables + ] + ) + ) params = np.hstack(params) mat = block_diag(*mat) @@ -208,10 +226,9 @@ def _get_linear_prior(self, prior_type: str): params[0], params[1], mat ) - def get_params(self, - x: NDArray, - df: Optional[DataFrame] = None, - order: int = 0) -> NDArray: + def get_params( + self, x: NDArray, df: Optional[DataFrame] = None, order: int = 0 + ) -> NDArray: """Compute and return the parameter. Denote :math:`x` as the coefficients, :math:`A` as the design matrix, :math:`z` as the offset, :math:`f` as the transformation function, the parameter :math:`p` can @@ -258,9 +275,9 @@ def get_params(self, return z if order == 1: return z[:, np.newaxis] * self.design_mat - return (z[:, np.newaxis, np.newaxis] * - (self.design_mat[..., np.newaxis] * - self.design_mat[:, np.newaxis, :])) + return z[:, np.newaxis, np.newaxis] * ( + self.design_mat[..., np.newaxis] * self.design_mat[:, np.newaxis, :] + ) def prior_objective(self, x: NDArray) -> float: """Objective function from the prior. @@ -323,6 +340,8 @@ def prior_hessian(self, x: NDArray) -> NDArray: return value def __repr__(self) -> str: - return (f"{type(self).__name__}(variables={self.variables}, " - f"transform={self.transform}, offset={self.offset}, " - f"priors={self.priors})") + return ( + f"{type(self).__name__}(variables={self.variables}, " + f"transform={self.transform}, offset={self.offset}, " + f"priors={self.priors})" + ) diff --git a/src/anml/parameter/smoothmapping.py b/src/anml/parameter/smoothmapping.py index 55ab55a..deeb494 100644 --- a/src/anml/parameter/smoothmapping.py +++ b/src/anml/parameter/smoothmapping.py @@ -7,7 +7,7 @@ class SmoothMapping(ABC): - """Smooth mapping that contains function, first and second derivative + """Smooth mapping that contains function, first and second derivative information. """ @@ -49,15 +49,11 @@ def __repr__(self) -> str: class Identity(SmoothMapping): - """Identity smooth mapping. - - """ + """Identity smooth mapping.""" @property def inverse(self) -> SmoothMapping: - """Inverse of :class:`Identity` is :class:`Identity`. - - """ + """Inverse of :class:`Identity` is :class:`Identity`.""" return Identity() def __call__(self, x: NDArray, order: int = 0) -> NDArray: @@ -70,15 +66,11 @@ def __call__(self, x: NDArray, order: int = 0) -> NDArray: class Exp(SmoothMapping): - """Exponential smooth mapping. - - """ + """Exponential smooth mapping.""" @property def inverse(self) -> SmoothMapping: - """Inverse of :class:`Exp` is :class:`Log`. - - """ + """Inverse of :class:`Exp` is :class:`Log`.""" return Log() def __call__(self, x: NDArray, order: int = 0) -> NDArray: @@ -98,9 +90,7 @@ class Log(SmoothMapping): @property def inverse(self) -> SmoothMapping: - """Inverse of :class:`Log` is :class:`Exp`. - - """ + """Inverse of :class:`Log` is :class:`Exp`.""" return Exp() def __call__(self, x: NDArray, order: int = 0) -> NDArray: @@ -123,9 +113,7 @@ class Expit(SmoothMapping): @property def inverse(self) -> SmoothMapping: - """Inverse of :class:`Expit` is :class:`Logit`. - - """ + """Inverse of :class:`Expit` is :class:`Logit`.""" return Logit() def __call__(self, x: NDArray, order: int = 0) -> NDArray: @@ -134,8 +122,8 @@ def __call__(self, x: NDArray, order: int = 0) -> NDArray: if order == 0: return 1 / (1 + z) elif order == 1: - return z / (1 + z)**2 - return z * (z - 1) / (z + 1)**3 + return z / (1 + z) ** 2 + return z * (z - 1) / (z + 1) ** 3 class Logit(SmoothMapping): @@ -152,18 +140,17 @@ class Logit(SmoothMapping): @property def inverse(self) -> SmoothMapping: - """Inverse of :class:`Logit` is :class:`Expit`. - - """ + """Inverse of :class:`Logit` is :class:`Expit`.""" return Expit() def __call__(self, x: NDArray, order: int = 0) -> NDArray: self._validate_order(order) if not ((x > 0).all() and (x < 1).all()): - raise ValueError("All values for logit function must be strictly " - "between 0 and 1.") + raise ValueError( + "All values for logit function must be strictly " "between 0 and 1." + ) if order == 0: return np.log(x / (1 - x)) elif order == 1: return 1 / (x * (1 - x)) - return (2 * x - 1) / (x * (1 - x))**2 + return (2 * x - 1) / (x * (1 - x)) ** 2 diff --git a/src/anml/prior/main.py b/src/anml/prior/main.py index be5c1af..472350c 100644 --- a/src/anml/prior/main.py +++ b/src/anml/prior/main.py @@ -44,9 +44,7 @@ class Prior: """ - def __init__(self, - params: List[ArrayLike], - mat: Optional[ArrayLike] = None): + def __init__(self, params: List[ArrayLike], mat: Optional[ArrayLike] = None): self.params = params self.mat = mat @@ -170,23 +168,19 @@ class GaussianPrior(Prior): """ - def __init__(self, - mean: ArrayLike, - sd: ArrayLike, - mat: Optional[ArrayLike] = None): + def __init__(self, mean: ArrayLike, sd: ArrayLike, mat: Optional[ArrayLike] = None): super().__init__([mean, sd], mat=mat) if not (self.params[1] > 0.0).all(): - raise ValueError("Gaussian prior standard deviations must be " - "positive.") + raise ValueError("Gaussian prior standard deviations must be " "positive.") self.mean = self.params[0] self.sd = self.params[1] def objective(self, x: NDArray) -> float: if self.mat is None: - return 0.5*np.sum(((x - self.mean) / self.sd)**2) + return 0.5 * np.sum(((x - self.mean) / self.sd) ** 2) if self.mat.size == 0: return 0.0 - return 0.5*np.sum(((self.mat.dot(x) - self.mean) / self.sd)**2) + return 0.5 * np.sum(((self.mat.dot(x) - self.mean) / self.sd) ** 2) def gradient(self, x: NDArray) -> NDArray: if self.mat is None: @@ -243,13 +237,12 @@ class UniformPrior(Prior): """ - def __init__(self, - lb: ArrayLike, - ub: ArrayLike, - mat: Optional[ArrayLike] = None): + def __init__(self, lb: ArrayLike, ub: ArrayLike, mat: Optional[ArrayLike] = None): super().__init__([lb, ub], mat=mat) if not (self.params[0] <= self.params[1]).all(): - raise ValueError("Uniform prior lower bounds have to be less than " - "or equal to the upper bounds.") + raise ValueError( + "Uniform prior lower bounds have to be less than " + "or equal to the upper bounds." + ) self.lb = self.params[0] self.ub = self.params[1] diff --git a/src/anml/prior/utils.py b/src/anml/prior/utils.py index 1fa4228..4bcfc45 100644 --- a/src/anml/prior/utils.py +++ b/src/anml/prior/utils.py @@ -27,10 +27,10 @@ def get_prior_type(prior_type: str) -> Type: return getattr(anml.prior.main, prior_type) -def filter_priors(priors: List[Prior], - prior_type: str, - with_mat: Optional[bool] = None) -> List[Prior]: - """Filter priors from a list of priors by their type and do they contain +def filter_priors( + priors: List[Prior], prior_type: str, with_mat: Optional[bool] = None +) -> List[Prior]: + """Filter priors from a list of priors by their type and do they contain linear map or not. Parameters @@ -42,7 +42,7 @@ def filter_priors(priors: List[Prior], Given prior type name. with_mat If the filtered priors are all contain a linear map. Default to `None`. - If `with_mat=None`, the final list will include priors that both + If `with_mat=None`, the final list will include priors that both contain or not contain the linear map. Returns diff --git a/src/anml/variable/main.py b/src/anml/variable/main.py index c176fbd..f7381d0 100644 --- a/src/anml/variable/main.py +++ b/src/anml/variable/main.py @@ -54,17 +54,19 @@ class Variable: """ - def __init__(self, - component: Union[str, Component], - priors: Optional[List[Prior]] = None): + def __init__( + self, component: Union[str, Component], priors: Optional[List[Prior]] = None + ): self.component = component self.priors = priors @component.setter def component(self, component: Union[str, Component]): if not isinstance(component, (str, Component)): - raise TypeError("Variable input component has to be a string or " - "an instance of Component.") + raise TypeError( + "Variable input component has to be a string or " + "an instance of Component." + ) if isinstance(component, str): component = Component(component, validators=[NoNans()]) self._component = component @@ -73,15 +75,14 @@ def component(self, component: Union[str, Component]): def priors(self, priors: Optional[List[Prior]]): priors = list(priors) if priors is not None else [] if not all(isinstance(prior, self._prior_types) for prior in priors): - raise TypeError("Variable input priors must be a list of " - "instances of Prior.") + raise TypeError( + "Variable input priors must be a list of " "instances of Prior." + ) self._priors = priors @property def size(self) -> Optional[int]: - """Size of the variable. - - """ + """Size of the variable.""" return 1 def attach(self, df: DataFrame): @@ -151,7 +152,7 @@ def get_direct_prior_params(self, prior_type: str) -> NDArray: def get_linear_prior_params(self, prior_type: str) -> Tuple[NDArray, NDArray]: """Get the linear prior parameters. The linear prior refers to the - priors that contain a linear map. If there is no linear prior in the + priors that contain a linear map. If there is no linear prior in the prior list, we will return empty arrays that match the size of the variable. @@ -183,4 +184,6 @@ def get_linear_prior_params(self, prior_type: str) -> Tuple[NDArray, NDArray]: return params, mat def __repr__(self) -> str: - return f"{type(self).__name__}(component={self.component}, priors={self.priors})" + return ( + f"{type(self).__name__}(component={self.component}, priors={self.priors})" + ) diff --git a/src/anml/variable/spline.py b/src/anml/variable/spline.py index fddbd96..086bd7a 100644 --- a/src/anml/variable/spline.py +++ b/src/anml/variable/spline.py @@ -54,23 +54,27 @@ class SplineVariable(Variable): """ _prior_types: Tuple[Type, ...] = SplineVariablePrior.__args__ - def __init__(self, - component: Union[str, Component], - spline: Union[XSpline, SplineGetter], - priors: Optional[List[SplineVariablePrior]] = None): + def __init__( + self, + component: Union[str, Component], + spline: Union[XSpline, SplineGetter], + priors: Optional[List[SplineVariablePrior]] = None, + ): super().__init__(component, priors) self.spline = spline @spline.setter def spline(self, spline: Union[XSpline, SplineGetter]): if not isinstance(spline, (XSpline, SplineGetter)): - raise TypeError("Spline variable input spline must be an instance " - "of XSpline or SplineGetter.") + raise TypeError( + "Spline variable input spline must be an instance " + "of XSpline or SplineGetter." + ) self._spline = spline - @property - def size(self) -> int: - return self.spline.num_spline_bases + # @property + # def size(self) -> int: + # return self.spline.num_spline_bases def attach(self, df: DataFrame): """Attach the data to variable. It will attach data to the component. @@ -91,4 +95,6 @@ def attach(self, df: DataFrame): def get_design_mat(self, df: DataFrame) -> NDArray: self.attach(df) - return self.spline.design_mat(self.component.value, l_extra=True, r_extra=True) + return self.spline.get_design_mat( + self.component.value, + ) diff --git a/tests/data/test_component.py b/tests/data/test_component.py index fb1892a..6d140d9 100644 --- a/tests/data/test_component.py +++ b/tests/data/test_component.py @@ -8,9 +8,7 @@ @pytest.fixture def df(): np.random.seed(123) - return pd.DataFrame({ - "col": np.random.randn(10) - }) + return pd.DataFrame({"col": np.random.randn(10)}) @pytest.mark.parametrize("key", ["col"]) @@ -29,8 +27,7 @@ def test_key_setter_illegal(key): def test_validators_setter_legal(validators): comp = Component("col", validators) assert isinstance(comp.validators, list) - assert all(isinstance(validator, Validator) - for validator in comp.validators) + assert all(isinstance(validator, Validator) for validator in comp.validators) @pytest.mark.parametrize("validators", [1, [1, 2, 3]]) diff --git a/tests/data/test_data_prototype.py b/tests/data/test_data_prototype.py index 5f31699..d4316fc 100644 --- a/tests/data/test_data_prototype.py +++ b/tests/data/test_data_prototype.py @@ -9,17 +9,14 @@ @pytest.fixture def df(): np.random.seed(123) - return DataFrame({ - "obs": np.random.randn(5), - "obs_se": np.random.rand(5) - }) + return DataFrame({"obs": np.random.randn(5), "obs_se": np.random.rand(5)}) @pytest.fixture def components(): return { "obs": Component("obs", [NoNans()]), - "obs_se": Component("obs_se", [NoNans(), Positive()]) + "obs_se": Component("obs_se", [NoNans(), Positive()]), } diff --git a/tests/data/test_example.py b/tests/data/test_example.py index a54969a..431db26 100644 --- a/tests/data/test_example.py +++ b/tests/data/test_example.py @@ -6,8 +6,7 @@ @pytest.fixture def df(): - return pd.DataFrame({"obs": np.random.randn(5), - "obs_se": np.ones(5)}) + return pd.DataFrame({"obs": np.random.randn(5), "obs_se": np.ones(5)}) @pytest.mark.parametrize("obs", ["obs"]) diff --git a/tests/getter/test_spline.py b/tests/getter/test_spline.py index 5358b0a..e67429d 100644 --- a/tests/getter/test_spline.py +++ b/tests/getter/test_spline.py @@ -12,16 +12,15 @@ def data(): @pytest.mark.parametrize("knots", [np.linspace(0.0, 1.0, 5)]) @pytest.mark.parametrize("degree", [3]) -@pytest.mark.parametrize("l_linear", [True, False]) -@pytest.mark.parametrize("r_linear", [True, False]) -@pytest.mark.parametrize("include_first_basis", [True, False]) +@pytest.mark.parametrize("ldegree", [0, 1]) +@pytest.mark.parametrize("rdegree", [0, 1]) @pytest.mark.parametrize("knots_type", ["rel_domain", "rel_freq", "abs"]) -def test_splinegetter(data, knots, degree, l_linear, r_linear, include_first_basis, knots_type): - splinegetter = SplineGetter(knots, degree, l_linear, r_linear, include_first_basis, knots_type) +def test_splinegetter(data, knots, degree, ldegree, rdegree, knots_type): + splinegetter = SplineGetter(knots, degree, ldegree, rdegree, knots_type) spline = splinegetter.get_spline(data) assert isinstance(spline, XSpline) - assert spline.num_spline_bases == splinegetter.num_spline_bases + # assert spline.num_spline_bases == splinegetter.num_spline_bases if knots_type.startswith("rel"): assert np.isclose(spline.knots[0], data.min()) assert np.isclose(spline.knots[-1], data.max()) diff --git a/tests/parameter/test_main.py b/tests/parameter/test_main.py index c6dce24..be7c511 100644 --- a/tests/parameter/test_main.py +++ b/tests/parameter/test_main.py @@ -65,9 +65,14 @@ def test_offset_illegal(variables, offset): Parameter(variables=variables, offset=offset) -@pytest.mark.parametrize("priors", [None, - [GaussianPrior(mean=np.zeros(2), sd=np.ones(2))], - [UniformPrior(lb=np.zeros(2), ub=np.ones(2))]]) +@pytest.mark.parametrize( + "priors", + [ + None, + [GaussianPrior(mean=np.zeros(2), sd=np.ones(2))], + [UniformPrior(lb=np.zeros(2), ub=np.ones(2))], + ], +) def test_priors_setter_legal(variables, priors): p = Parameter(variables=variables, priors=priors) if priors is None: diff --git a/tests/parameter/test_smoothmapping.py b/tests/parameter/test_smoothmapping.py index 009ebc5..4ef6067 100644 --- a/tests/parameter/test_smoothmapping.py +++ b/tests/parameter/test_smoothmapping.py @@ -10,15 +10,15 @@ def ad_jacobian(fun, x, out_shape=(), eps=1e-10): g = np.zeros((*out_shape, *x.shape)) if len(out_shape) == 0: for i in np.ndindex(x.shape): - c[i] += eps*1j - g[i] = fun(c).imag/eps - c[i] -= eps*1j + c[i] += eps * 1j + g[i] = fun(c).imag / eps + c[i] -= eps * 1j else: for j in np.ndindex(out_shape): for i in np.ndindex(x.shape): - c[i] += eps*1j - g[j][i] = fun(c)[j].imag/eps - c[i] -= eps*1j + c[i] += eps * 1j + g[j][i] = fun(c)[j].imag / eps + c[i] -= eps * 1j return g @@ -71,10 +71,9 @@ def test_log_illegal_input(x): log(x) -@pytest.mark.parametrize("x", [np.array([-1.0]), - np.array([0.0]), - np.array([1.0]), - np.array([2.0])]) +@pytest.mark.parametrize( + "x", [np.array([-1.0]), np.array([0.0]), np.array([1.0]), np.array([2.0])] +) def test_logit_illegal_input(x): logit = Logit() with pytest.raises(ValueError): diff --git a/tests/prior/test_main.py b/tests/prior/test_main.py index 6ab216e..9d244f7 100644 --- a/tests/prior/test_main.py +++ b/tests/prior/test_main.py @@ -8,15 +8,15 @@ def ad_jacobian(fun, x, out_shape=(), eps=1e-10): g = np.zeros((*out_shape, *x.shape)) if len(out_shape) == 0: for i in np.ndindex(x.shape): - c[i] += eps*1j - g[i] = fun(c).imag/eps - c[i] -= eps*1j + c[i] += eps * 1j + g[i] = fun(c).imag / eps + c[i] -= eps * 1j else: for j in np.ndindex(out_shape): for i in np.ndindex(x.shape): - c[i] += eps*1j - g[j][i] = fun(c)[j].imag/eps - c[i] -= eps*1j + c[i] += eps * 1j + g[j][i] = fun(c)[j].imag / eps + c[i] -= eps * 1j return g @@ -68,7 +68,7 @@ def test_gaussian_prior_init_illegal(mean, sd): def test_gaussian_prior_objective(x): mean, sd = 2.0, 0.5 prior = GaussianPrior(mean, sd) - assert prior.objective(x) == 0.5*np.sum(((x - mean) / sd)**2) + assert prior.objective(x) == 0.5 * np.sum(((x - mean) / sd) ** 2) @pytest.mark.parametrize("x", [np.array([1.0]), np.array([2.0])]) diff --git a/tests/prior/test_utils.py b/tests/prior/test_utils.py index 67a600d..7b8548d 100644 --- a/tests/prior/test_utils.py +++ b/tests/prior/test_utils.py @@ -6,15 +6,18 @@ @pytest.fixture def priors(): - return [GaussianPrior(mean=0.0, sd=1.0), - UniformPrior(lb=0.0, ub=1.0), - GaussianPrior(mean=0.0, sd=1.0, mat=np.identity(5)), - UniformPrior(lb=0.0, ub=1.0, mat=np.identity(5))] + return [ + GaussianPrior(mean=0.0, sd=1.0), + UniformPrior(lb=0.0, ub=1.0), + GaussianPrior(mean=0.0, sd=1.0, mat=np.identity(5)), + UniformPrior(lb=0.0, ub=1.0, mat=np.identity(5)), + ] -@pytest.mark.parametrize(("prior_type", "tr_prior_type"), - [("GaussianPrior", GaussianPrior), - ("UniformPrior", UniformPrior)]) +@pytest.mark.parametrize( + ("prior_type", "tr_prior_type"), + [("GaussianPrior", GaussianPrior), ("UniformPrior", UniformPrior)], +) def test_get_prior_type(prior_type, tr_prior_type): prior_type = get_prior_type(prior_type) assert prior_type == tr_prior_type diff --git a/tests/variable/test_main.py b/tests/variable/test_main.py index aa773f0..648c9ad 100644 --- a/tests/variable/test_main.py +++ b/tests/variable/test_main.py @@ -29,9 +29,9 @@ def test_component_setter_illegal(component): Variable(component) -@pytest.mark.parametrize("priors", [None, - [GaussianPrior(mean=0.0, sd=1.0)], - [UniformPrior(lb=0.0, ub=1.0)]]) +@pytest.mark.parametrize( + "priors", [None, [GaussianPrior(mean=0.0, sd=1.0)], [UniformPrior(lb=0.0, ub=1.0)]] +) def test_priors_setter_legal(priors): v = Variable("cov", priors=priors) if priors is None: @@ -80,8 +80,15 @@ def test_get_direct_uniform_prior_params(v, priors): assert np.allclose(params, np.array([[0.0], [1.0]])) -@pytest.mark.parametrize("priors", [[GaussianPrior(0.0, 1.0, np.ones((3, 1))), - GaussianPrior(0.0, 1.0, np.ones((3, 1)))]]) +@pytest.mark.parametrize( + "priors", + [ + [ + GaussianPrior(0.0, 1.0, np.ones((3, 1))), + GaussianPrior(0.0, 1.0, np.ones((3, 1))), + ] + ], +) def test_get_linear_gaussian_prior_params(v, priors): v.priors = priors params = v.get_linear_prior_params("GaussianPrior") @@ -89,8 +96,15 @@ def test_get_linear_gaussian_prior_params(v, priors): assert np.allclose(params[1], np.ones((6, 1))) -@pytest.mark.parametrize("priors", [[UniformPrior(0.0, 1.0, np.ones((3, 1))), - UniformPrior(0.0, 1.0, np.ones((3, 1)))]]) +@pytest.mark.parametrize( + "priors", + [ + [ + UniformPrior(0.0, 1.0, np.ones((3, 1))), + UniformPrior(0.0, 1.0, np.ones((3, 1))), + ] + ], +) def test_get_linear_uniform_prior_params(v, priors): v.priors = priors params = v.get_linear_prior_params("UniformPrior") diff --git a/tests/variable/test_spline.py b/tests/variable/test_spline.py index 827a945..3cdc5be 100644 --- a/tests/variable/test_spline.py +++ b/tests/variable/test_spline.py @@ -13,9 +13,13 @@ def data(): return pd.DataFrame({"cov": np.random.randn(5)}) -@pytest.mark.parametrize("spline", - [XSpline(knots=np.array([0.0, 0.5, 1.0]), degree=3), - SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3)]) +@pytest.mark.parametrize( + "spline", + [ + XSpline(knots=np.array([0.0, 0.5, 1.0]), degree=3), + SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3), + ], +) def test_spline_setter_legal(spline): sv = SplineVariable("cov", spline=spline) assert sv.spline.degree == 3 @@ -29,19 +33,22 @@ def test_spline_setter_illegal(spline): @pytest.mark.parametrize("knots", [np.linspace(0.0, 1.0, 5)]) @pytest.mark.parametrize("degree", [3]) -@pytest.mark.parametrize("l_linear", [True, False]) -@pytest.mark.parametrize("r_linear", [True, False]) -@pytest.mark.parametrize("include_first_basis", [True, False]) -def test_spline_size(knots, degree, l_linear, r_linear, include_first_basis): - splinegetter = SplineGetter(knots, degree, l_linear, r_linear, include_first_basis) - spline = XSpline(knots, degree, l_linear, r_linear, include_first_basis) +@pytest.mark.parametrize("ldegree", [0, 1]) +@pytest.mark.parametrize("rdegree", [0, 1]) +def test_spline_size(knots, degree, ldegree, rdegree, include_first_basis): + splinegetter = SplineGetter(knots, degree, ldegree, rdegree, include_first_basis) + spline = XSpline(knots, degree, ldegree, rdegree, include_first_basis) sv = SplineVariable("cov", spline=splinegetter) - assert sv.size == spline.num_spline_bases + # assert sv.size == spline.num_spline_bases -@pytest.mark.parametrize("spline", [SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3)]) -@pytest.mark.parametrize("priors", [[SplinePriorGetter(UniformPrior(0.0, 1.0), order=1)]]) +@pytest.mark.parametrize( + "spline", [SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3)] +) +@pytest.mark.parametrize( + "priors", [[SplinePriorGetter(UniformPrior(0.0, 1.0), order=1)]] +) def test_attach(spline, priors, data): sv = SplineVariable("cov", spline, priors=priors) sv.attach(data) @@ -49,7 +56,9 @@ def test_attach(spline, priors, data): assert isinstance(sv.priors[0], UniformPrior) -@pytest.mark.parametrize("spline", [SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3)]) +@pytest.mark.parametrize( + "spline", [SplineGetter(knots=np.array([0.0, 0.5, 1.0]), degree=3)] +) def test_get_design_mat(data, spline): sv = SplineVariable("cov", spline=spline) design_mat = sv.get_design_mat(data) diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 8e2d198..0000000 --- a/tox.ini +++ /dev/null @@ -1,7 +0,0 @@ -[tox] -envlist = py37 -[pytest] -xfail_strict = true -[testenv] -deps=-rrequirements.txt -commands=py.test From 5a6145a889f6d0631c5acec21256c0df870a7a5d Mon Sep 17 00:00:00 2001 From: saal Date: Fri, 11 Oct 2024 10:36:29 -0700 Subject: [PATCH 2/4] Updated var and getter Co-authored-by: Sameer --- src/anml/getter/spline.py | 26 +++++++++++++------------- src/anml/variable/spline.py | 19 ++++++++++++++++--- tests/variable/test_spline.py | 8 ++++---- 3 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/anml/getter/spline.py b/src/anml/getter/spline.py index fe0a171..b63541c 100644 --- a/src/anml/getter/spline.py +++ b/src/anml/getter/spline.py @@ -1,6 +1,5 @@ -from operator import attrgetter - import numpy as np +from operator import attrgetter from numpy.typing import NDArray from xspline import XSpline from typing import Optional @@ -22,12 +21,9 @@ class SplineGetter: Left extrapolation polynomial degree. rdegree Right extrapolation polynomial degree. - include_first_basis - If `True`, spline will include the first basis of the spline. Default - to be `True`. knots_type : {'abs', 'rel_domain', 'rel_freq'} Type of the spline knots. Can only be choosen from three options, - `'abs'`, `'rel_domian'` and `'rel_freq'`. When it is `'abs'` + `'abs'`, `'rel_domain'` and `'rel_freq'`. When it is `'abs'` which standards for absolute, the knots will be used as it is. When it is `rel_domain` which standards for relative domain, the knots requires to be between 0 and 1, and will be interpreted as the @@ -58,8 +54,8 @@ def __init__( ): self.knots = knots self.degree = degree - self.ldegree = ldegree - self.rdegree = rdegree + self.ldegree = min(ldegree if ldegree is not None else 0, len(knots) - 2) + self.rdegree = min(rdegree if rdegree is not None else 0, len(knots) - 2) self.knots_type = knots_type @knots_type.setter @@ -70,11 +66,15 @@ def knots_type(self, knots_type: str): ) self._knots_type = knots_type - # @property - # def num_spline_bases(self) -> int: - # """Number of the spline bases.""" - # inner_knots = self.knots[self.ldegree : len(self.knots) - self.rdegree] - # return len(inner_knots) - 2 + self.degree + @property + def num_spline_bases(self) -> int: + """Number of the spline bases.""" + ldegree = self.ldegree or 0 + rdegree = self.rdegree or 0 + + inner_knots = self.knots[ldegree : len(self.knots) - rdegree] + + return len(inner_knots) - 1 + self.degree def get_spline(self, data: NDArray) -> XSpline: """Get spline instance given data array. diff --git a/src/anml/variable/spline.py b/src/anml/variable/spline.py index 086bd7a..9a25a25 100644 --- a/src/anml/variable/spline.py +++ b/src/anml/variable/spline.py @@ -72,9 +72,22 @@ def spline(self, spline: Union[XSpline, SplineGetter]): ) self._spline = spline - # @property - # def size(self) -> int: - # return self.spline.num_spline_bases + @property + def size(self) -> int: + """Number of the spline bases.""" + if isinstance(self.spline, XSpline): + knots = self.spline.knots + degree = self.spline.degree + ldegree = self.spline.ldegree or 0 + rdegree = self.spline.rdegree or 0 + inner_knots = knots[ldegree : len(knots) - rdegree] + return len(inner_knots) - 1 + degree + + elif isinstance(self.spline, SplineGetter): + return self.spline.num_spline_bases + + else: + raise TypeError("Unknown spline type") def attach(self, df: DataFrame): """Attach the data to variable. It will attach data to the component. diff --git a/tests/variable/test_spline.py b/tests/variable/test_spline.py index 3cdc5be..6ebf08c 100644 --- a/tests/variable/test_spline.py +++ b/tests/variable/test_spline.py @@ -35,12 +35,12 @@ def test_spline_setter_illegal(spline): @pytest.mark.parametrize("degree", [3]) @pytest.mark.parametrize("ldegree", [0, 1]) @pytest.mark.parametrize("rdegree", [0, 1]) -def test_spline_size(knots, degree, ldegree, rdegree, include_first_basis): - splinegetter = SplineGetter(knots, degree, ldegree, rdegree, include_first_basis) - spline = XSpline(knots, degree, ldegree, rdegree, include_first_basis) +def test_spline_size(knots, degree, ldegree, rdegree): + splinegetter = SplineGetter(knots, degree, ldegree, rdegree) + spline = XSpline(knots, degree, ldegree, rdegree) sv = SplineVariable("cov", spline=splinegetter) - # assert sv.size == spline.num_spline_bases + assert sv.size == splinegetter.num_spline_bases @pytest.mark.parametrize( From 6576d1cf46db8483c5a5e6825e657b212bbdadd7 Mon Sep 17 00:00:00 2001 From: saal Date: Fri, 11 Oct 2024 10:44:12 -0700 Subject: [PATCH 3/4] Updated build python --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index b64e8bc..2e2c156 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,10 +7,10 @@ jobs: steps: - uses: actions/checkout@v2 - - name: Set up Python 3.8 + - name: Set up Python 3.12.4 uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: 3.12.4 - name: Install dependencies run: python -m pip install .[dev] --upgrade pip - name: Test with pytest From 1262f67afd947653f18393d779a99b9cedd6b936 Mon Sep 17 00:00:00 2001 From: saal Date: Fri, 11 Oct 2024 11:37:19 -0700 Subject: [PATCH 4/4] Updated version and tags --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 839aaa9..5672941 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ requires = [ [project] name = "anml" -version = "0.1.0" +version = "0.3.0" description = "This is a nonlinear modeling library." readme = "README.md" license = { text = "BSD-2-Clause" }