From 0e349a3b494e2101e5cd00bef2bebf0e2e323b14 Mon Sep 17 00:00:00 2001 From: SimonBoothroyd Date: Tue, 24 Oct 2023 08:46:54 -0400 Subject: [PATCH] Reset project in preparation of re-write (#42) --- .codecov.yml | 14 - .devcontainer/Dockerfile | 5 + .devcontainer/devcontainer.json | 4 + .github/CONTRIBUTING.md | 42 - .github/PULL_REQUEST_TEMPLATE.md | 12 - .github/dependabot.yml | 7 - .github/workflows/ci.yaml | 60 +- .github/workflows/lint.yaml | 47 - .gitignore | 9 +- .pre-commit-config.yaml | 38 + LICENSE | 2 +- MANIFEST.in | 6 - Makefile | 32 + descent/__init__.py | 12 +- descent/_version.py | 320 +++- descent/data/__init__.py | 3 - descent/data/data.py | 78 - descent/data/energy.py | 789 --------- descent/metrics.py | 17 - descent/models/__init__.py | 3 - descent/models/models.py | 18 - descent/models/smirnoff.py | 350 ---- descent/tests/__init__.py | 64 - descent/tests/conftest.py | 72 - descent/tests/data/__init__.py | 0 descent/tests/data/test_data.py | 48 - descent/tests/data/test_energy.py | 385 ----- descent/tests/geometric.py | 233 --- descent/tests/mocking/__init__.py | 0 descent/tests/mocking/qcdata.py | 136 -- descent/tests/mocking/systems.py | 45 - descent/tests/models/__init__.py | 0 descent/tests/models/test_smirnoff.py | 221 --- descent/tests/test_metrics.py | 26 - descent/tests/test_transforms.py | 47 - descent/tests/utilities/__init__.py | 0 descent/tests/utilities/test_smirnoff.py | 234 --- descent/tests/utilities/test_utilities.py | 13 - descent/transforms.py | 42 - descent/utilities/__init__.py | 3 - descent/utilities/smirnoff.py | 131 -- descent/utilities/utilities.py | 21 - devtools/conda-envs/meta.yaml | 41 - devtools/envs/base.yaml | 43 + examples/README.md | 3 + examples/energy-and-gradient.ipynb | 551 ------- integration-tests/test_energy_training.py | 185 --- pyproject.toml | 56 + setup.cfg | 48 - setup.py | 36 - versioneer.py | 1822 --------------------- 51 files changed, 448 insertions(+), 5926 deletions(-) delete mode 100644 .codecov.yml create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json delete mode 100644 .github/CONTRIBUTING.md delete mode 100644 .github/PULL_REQUEST_TEMPLATE.md delete mode 100644 .github/dependabot.yml delete mode 100644 .github/workflows/lint.yaml create mode 100644 .pre-commit-config.yaml delete mode 100644 MANIFEST.in create mode 100644 Makefile delete mode 100644 descent/data/__init__.py delete mode 100644 descent/data/data.py delete mode 100644 descent/data/energy.py delete mode 100644 descent/metrics.py delete mode 100644 descent/models/__init__.py delete mode 100644 descent/models/models.py delete mode 100644 descent/models/smirnoff.py delete mode 100644 descent/tests/data/__init__.py delete mode 100644 descent/tests/data/test_data.py delete mode 100644 descent/tests/data/test_energy.py delete mode 100644 descent/tests/geometric.py delete mode 100644 descent/tests/mocking/__init__.py delete mode 100644 descent/tests/mocking/qcdata.py delete mode 100644 descent/tests/mocking/systems.py delete mode 100644 descent/tests/models/__init__.py delete mode 100644 descent/tests/models/test_smirnoff.py delete mode 100644 descent/tests/test_metrics.py delete mode 100644 descent/tests/test_transforms.py delete mode 100644 descent/tests/utilities/__init__.py delete mode 100644 descent/tests/utilities/test_smirnoff.py delete mode 100644 descent/tests/utilities/test_utilities.py delete mode 100644 descent/transforms.py delete mode 100644 descent/utilities/__init__.py delete mode 100644 
descent/utilities/smirnoff.py delete mode 100644 descent/utilities/utilities.py delete mode 100644 devtools/conda-envs/meta.yaml create mode 100644 devtools/envs/base.yaml create mode 100644 examples/README.md delete mode 100644 examples/energy-and-gradient.ipynb delete mode 100644 integration-tests/test_energy_training.py create mode 100644 pyproject.toml delete mode 100644 setup.cfg delete mode 100644 setup.py delete mode 100644 versioneer.py diff --git a/.codecov.yml b/.codecov.yml deleted file mode 100644 index a3ed7f4..0000000 --- a/.codecov.yml +++ /dev/null @@ -1,14 +0,0 @@ -# Codecov configuration to make it a bit less noisy -coverage: - status: - patch: false - project: - default: - threshold: 50% -comment: - layout: "header" - require_changes: false - branches: null - behavior: default - flags: null - paths: null \ No newline at end of file diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 0000000..06a579b --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,5 @@ +FROM --platform=linux/x86_64 condaforge/mambaforge:latest + +RUN apt update \ + && apt install -y git make build-essential \ + && rm -rf /var/lib/apt/lists/* diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..863c899 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,4 @@ +{ + "build": { "dockerfile": "Dockerfile" }, + "postCreateCommand": "make env" +} \ No newline at end of file diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md deleted file mode 100644 index accc283..0000000 --- a/.github/CONTRIBUTING.md +++ /dev/null @@ -1,42 +0,0 @@ -# How to contribute - -We welcome contributions from external contributors, and this document -describes how to merge code changes into this descent. - -## Getting Started - -* Make sure you have a [GitHub account](https://github.com/signup/free). -* [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub. -* On your local machine, - [clone](https://help.github.com/articles/cloning-a-repository/) your fork of - the repository. - -## Making Changes - -* Add some really awesome code to your local fork. It's usually a [good - idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-main-branch/) - to make changes on a - [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/) - with the branch name relating to the feature you are going to add. -* When you are ready for others to examine and comment on your new feature, - navigate to your fork of descent on GitHub and open a [pull - request](https://help.github.com/articles/using-pull-requests/) (PR). Note that - after you launch a PR from one of your fork's branches, all - subsequent commits to that branch will be added to the open pull request - automatically. Each commit added to the PR will be validated for - mergability, compilation and test suite compliance; the results of these tests - will be visible on the PR page. -* If you're providing a new feature, you must add test cases and documentation. -* When the code is ready to go, make sure you run the test suite using pytest. -* When you're ready to be considered for merging, check the "Ready to go" - box on the PR page to let the descent devs know that the changes are complete. - The code will not be merged until this box is checked, the continuous - integration returns checkmarks, - and multiple core developers give "Approved" reviews. 
- -# Additional Resources - -* [General GitHub documentation](https://help.github.com/) -* [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/) -* [A guide to contributing to software packages](http://www.contribution-guide.org) -* [Thinkful PR example](http://www.thinkful.com/learn/github-pull-request-tutorial/#Time-to-Submit-Your-First-PR) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md deleted file mode 100644 index c772b96..0000000 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ /dev/null @@ -1,12 +0,0 @@ -## Description -Provide a brief description of the PR's purpose here. - -## Todos -Notable points that this PR has either accomplished or will accomplish. - - [ ] TODO 1 - -## Questions -- [ ] Question1 - -## Status -- [ ] Ready to go \ No newline at end of file diff --git a/.github/dependabot.yml b/.github/dependabot.yml deleted file mode 100644 index ae992e3..0000000 --- a/.github/dependabot.yml +++ /dev/null @@ -1,7 +0,0 @@ -version: 2 -updates: - - - package-ecosystem: "github-actions" - directory: "/" - schedule: - interval: "daily" diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8a82511..08f3e25 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,62 +1,32 @@ name: CI +concurrency: + group: ${{ github.ref }} + cancel-in-progress: true + on: - push: - branches: - - "main" - pull_request: - branches: - - "main" - schedule: - - cron: "0 0 * * *" + push: { branches: [ "main" ] } + pull_request: { branches: [ "main" ] } jobs: test: - name: ${{ matrix.os }} python=${{ matrix.python-version }} - runs-on: ${{ matrix.os }} - - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest] - python-version: [3.7] + runs-on: ubuntu-latest + container: condaforge/mambaforge:latest steps: - - uses: actions/checkout@v2.3.4 - - - name: Setup Conda Environment - uses: conda-incubator/setup-miniconda@v2.1.1 - with: - python-version: ${{ matrix.python-version }} - environment-file: devtools/conda-envs/meta.yaml - - channels: conda-forge,defaults - - activate-environment: test - auto-update-conda: true - auto-activate-base: false - show-channel-urls: true - - - name: Install Package - shell: bash -l {0} - run: | - python setup.py develop --no-deps - - - name: Conda Environment Information - shell: bash -l {0} - run: | - conda info - conda list + - uses: actions/checkout@v3.3.0 - name: Run Tests - shell: bash -l {0} run: | - pytest -v --cov=descent --cov-report=xml --color=yes descent/tests/ + apt update && apt install -y git make + + make env + make lint + make test - name: CodeCov - uses: codecov/codecov-action@v2.0.3 + uses: codecov/codecov-action@v3.1.1 with: file: ./coverage.xml flags: unittests - name: codecov-${{ matrix.os }}-py${{ matrix.python-version }} diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml deleted file mode 100644 index 9c227bb..0000000 --- a/.github/workflows/lint.yaml +++ /dev/null @@ -1,47 +0,0 @@ -name: lint - -on: - push: - branches: - - "main" - pull_request: - branches: - - "main" - -jobs: - - lint: - - runs-on: ubuntu-latest - - steps: - - - uses: actions/checkout@v2.3.4 - - uses: actions/setup-python@v2.2.2 - with: - python-version: '3.8' - - name: Install the package - run: | - python setup.py develop --no-deps - - - name: Install isort / flake8 / black - run: | - pip install isort flake8 black - - - name: Run isort - run: | - isort --recursive --check-only descent - isort --recursive --check-only examples - 
isort --recursive --check-only integration-tests - - - name: Run black - run: | - black descent --check - black examples --check - black integration-tests --check - - - name: Run flake8 - run: | - flake8 descent - flake8 examples - flake8 integration-tests diff --git a/.gitignore b/.gitignore index 3a17862..f2e103a 100644 --- a/.gitignore +++ b/.gitignore @@ -105,8 +105,11 @@ ENV/ # There are reports this comes from LLVM profiling, but also Xcode 9. *profraw +# PyCharm +.idea + # OSX -*DS_Store +.DS_Store -# PyCharm -.idea \ No newline at end of file +# Local development +scratch diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..b710de2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,38 @@ +repos: + - repo: local + hooks: + - id: isort + name: "[Package] Import formatting" + language: system + entry: isort + files: \.py$ + + - id: black + name: "[Package] Code formatting" + language: system + entry: black + files: \.py$ + + - id: flake8 + name: "[Package] Linting" + language: system + entry: flake8 + files: \.py$ + + - id: isort-examples + name: "[Examples] Import formatting" + language: system + entry: nbqa isort + files: examples/.+\.ipynb$ + + - id: black-examples + name: "[Examples] Code formatting" + language: system + entry: nbqa black + files: examples/.+\.ipynb$ + + - id: flake8-examples + name: "[Examples] Linting" + language: system + entry: nbqa flake8 --ignore=E402 + files: examples/.+\.ipynb$ \ No newline at end of file diff --git a/LICENSE b/LICENSE index 24b1d62..5621e42 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,7 @@ MIT License -Copyright (c) 2021 Simon Boothroyd +Copyright (c) 2023 Simon Boothroyd Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 1eed402..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,6 +0,0 @@ -include LICENSE -include MANIFEST.in -include versioneer.py - -graft descent -global-exclude *.py[cod] __pycache__ *.so \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1f0ffdb --- /dev/null +++ b/Makefile @@ -0,0 +1,32 @@ +PACKAGE_NAME := descent +CONDA_ENV_RUN := conda run --no-capture-output --name $(PACKAGE_NAME) + +.PHONY: pip-install env lint format test test-examples + +pip-install: + $(CONDA_ENV_RUN) pip install --no-build-isolation --no-deps -e . + +env: + mamba create --name $(PACKAGE_NAME) + mamba env update --name $(PACKAGE_NAME) --file devtools/envs/base.yaml + $(CONDA_ENV_RUN) pip install --no-build-isolation --no-deps -e . 
+ $(CONDA_ENV_RUN) pre-commit install || true + +lint: + $(CONDA_ENV_RUN) isort --check-only $(PACKAGE_NAME) + $(CONDA_ENV_RUN) black --check $(PACKAGE_NAME) + $(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME) + $(CONDA_ENV_RUN) nbqa isort --check-only examples + $(CONDA_ENV_RUN) nbqa black --check examples + $(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples + +format: + $(CONDA_ENV_RUN) isort $(PACKAGE_NAME) + $(CONDA_ENV_RUN) black $(PACKAGE_NAME) + $(CONDA_ENV_RUN) flake8 $(PACKAGE_NAME) + $(CONDA_ENV_RUN) nbqa isort examples + $(CONDA_ENV_RUN) nbqa black examples + $(CONDA_ENV_RUN) nbqa flake8 --ignore=E402 examples + +test: + $(CONDA_ENV_RUN) pytest -v --cov=$(PACKAGE_NAME) --cov-report=xml --color=yes $(PACKAGE_NAME)/tests/ diff --git a/descent/__init__.py b/descent/__init__.py index debcd33..f27a734 100755 --- a/descent/__init__.py +++ b/descent/__init__.py @@ -1,12 +1,10 @@ """ -DESCENT +descent -Optimize force field parameters against QC data using `pytorch` +Optimize classical force field parameters against reference data """ -from ._version import get_versions +from . import _version -versions = get_versions() -__version__ = versions["version"] -__git_revision__ = versions["full-revisionid"] -del get_versions, versions +__version__ = _version.get_versions()["version"] +__all__ = ["__version__"] diff --git a/descent/_version.py b/descent/_version.py index c24acb3..e7eaf40 100644 --- a/descent/_version.py +++ b/descent/_version.py @@ -4,19 +4,22 @@ # directories (produced by setup.py build) will contain a much shorter file # that just contains the computed version number. -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) +# This file is released into the public domain. +# Generated by versioneer-0.29 +# https://github.com/python-versioneer/python-versioneer """Git implementation of _version.py.""" import errno +import functools import os import re import subprocess import sys +from typing import Any, Callable, Dict, List, Optional, Tuple -def get_keywords(): +def get_keywords() -> Dict[str, str]: """Get the keywords needed to look up the version information.""" # these strings will be replaced by git during git-archive. 
# setup.py/versioneer.py will grep for the variable names, so they must @@ -32,8 +35,15 @@ def get_keywords(): class VersioneerConfig: """Container for Versioneer configuration parameters.""" + VCS: str + style: str + tag_prefix: str + parentdir_prefix: str + versionfile_source: str + verbose: bool -def get_config(): + +def get_config() -> VersioneerConfig: """Create, populate and return the VersioneerConfig() object.""" # these strings are filled in when 'setup.py versioneer' creates # _version.py @@ -41,7 +51,7 @@ def get_config(): cfg.VCS = "git" cfg.style = "pep440" cfg.tag_prefix = "" - cfg.parentdir_prefix = "None" + cfg.parentdir_prefix = "descent-" cfg.versionfile_source = "descent/_version.py" cfg.verbose = False return cfg @@ -51,14 +61,14 @@ class NotThisMethod(Exception): """Exception raised if a method is not valid for the current scenario.""" -LONG_VERSION_PY = {} -HANDLERS = {} +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" +def register_vcs_handler(vcs: str, method: str) -> Callable: # decorator + """Create decorator to mark a method as the handler of a VCS.""" - def decorate(f): + def decorate(f: Callable) -> Callable: """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} @@ -68,24 +78,39 @@ def decorate(f): return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): +def run_command( + commands: List[str], + args: List[str], + cwd: Optional[str] = None, + verbose: bool = False, + hide_stderr: bool = False, + env: Optional[Dict[str, str]] = None, +) -> Tuple[Optional[str], Optional[int]]: """Call the given command(s).""" assert isinstance(commands, list) - p = None - for c in commands: + process = None + + popen_kwargs: Dict[str, Any] = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: try: - dispcmd = str([c] + args) + dispcmd = str([command] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen( - [c] + args, + process = subprocess.Popen( + [command] + args, cwd=cwd, env=env, stdout=subprocess.PIPE, stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, ) break - except EnvironmentError: - e = sys.exc_info()[1] + except OSError as e: if e.errno == errno.ENOENT: continue if verbose: @@ -96,18 +121,20 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env= if verbose: print("unable to find command, tried %s" % (commands,)) return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: if verbose: print("unable to run %s (error)" % dispcmd) print("stdout was %s" % stdout) - return None, p.returncode - return stdout, p.returncode + return None, process.returncode + return stdout, process.returncode -def versions_from_parentdir(parentdir_prefix, root, verbose): +def versions_from_parentdir( + parentdir_prefix: str, + root: str, + verbose: bool, +) -> Dict[str, Any]: """Try to determine the version from the parent directory name. 
Source tarballs conventionally unpack into a directory that includes both @@ -116,7 +143,7 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): """ rootdirs = [] - for i in range(3): + for _ in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): return { @@ -126,9 +153,8 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): "error": None, "date": None, } - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level + rootdirs.append(root) + root = os.path.dirname(root) # up a level if verbose: print( @@ -139,41 +165,48 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): @register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): +def git_get_keywords(versionfile_abs: str) -> Dict[str, str]: """Extract version information from the given file.""" # the code embedded in _version.py can just fetch the value of these # keywords. When used from setup.py, we don't want to import _version.py, # so we do it with a regexp instead. This function is not used from # _version.py. - keywords = {} + keywords: Dict[str, str] = {} try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: pass return keywords @register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): +def git_versions_from_keywords( + keywords: Dict[str, str], + tag_prefix: str, + verbose: bool, +) -> Dict[str, Any]: """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") date = keywords.get("date") if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 # -like" string, which we must then edit to make compliant), because @@ -186,11 +219,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): if verbose: print("keywords are unexpanded, not using") raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) + refs = {r.strip() for r in refnames.strip("()").split(",")} # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. 
TAG = "tag: " - tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -198,8 +231,8 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # refs/heads/ and refs/tags/ prefixes that would let us distinguish # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "main". - tags = set([r for r in refs if re.search(r"\d", r)]) + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r"\d", r)} if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -208,6 +241,11 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue if verbose: print("picking %s" % r) return { @@ -230,7 +268,9 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): @register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): +def git_pieces_from_vcs( + tag_prefix: str, root: str, verbose: bool, runner: Callable = run_command +) -> Dict[str, Any]: """Get version from 'git describe' in the root of the source tree. This only gets called if the git-archive 'subst' keywords were *not* @@ -241,7 +281,14 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -249,7 +296,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command( + describe_out, rc = runner( GITS, [ "describe", @@ -258,7 +305,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): "--always", "--long", "--match", - "%s*" % tag_prefix, + f"{tag_prefix}[[:digit:]]*", ], cwd=root, ) @@ -266,16 +313,48 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if describe_out is None: raise NotThisMethod("'git describe' failed") describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) if full_out is None: raise NotThisMethod("'git rev-parse' failed") full_out = full_out.strip() - pieces = {} + pieces: Dict[str, Any] = {} pieces["long"] = full_out pieces["short"] = full_out[:7] # maybe improved later pieces["error"] = None + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] # TAG might have hyphens. git_describe = describe_out @@ -292,7 +371,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # TAG-NUM-gHEX mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: - # unparseable. Maybe git-describe is misbehaving? + # unparsable. Maybe git-describe is misbehaving? 
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces @@ -318,26 +397,27 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) - pieces["distance"] = int(count_out) # total number of commits + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ - 0 - ].strip() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces -def plus_or_dot(pieces): +def plus_or_dot(pieces: Dict[str, Any]) -> str: """Return a + if we don't already have one, else return a .""" if "+" in pieces.get("closest-tag", ""): return "." return "+" -def render_pep440(pieces): +def render_pep440(pieces: Dict[str, Any]) -> str: """Build up version string, with post-release "local version identifier". Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you @@ -361,23 +441,70 @@ def render_pep440(pieces): return rendered -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. +def render_pep440_branch(pieces: Dict[str, Any]) -> str: + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). Exceptions: - 1: no tags. 0.post.devDISTANCE + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] """ if pieces["closest-tag"]: rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver: str) -> Tuple[str, Optional[int]]: + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces: Dict[str, Any]) -> str: + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] else: # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] + rendered = "0.post0.dev%d" % pieces["distance"] return rendered -def render_pep440_post(pieces): +def render_pep440_post(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]+gHEX] . The ".dev0" means dirty. Note that .dev0 sorts backwards @@ -404,12 +531,41 @@ def render_pep440_post(pieces): return rendered -def render_pep440_old(pieces): +def render_pep440_post_branch(pieces: Dict[str, Any]) -> str: + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces: Dict[str, Any]) -> str: """TAG[.postDISTANCE[.dev0]] . The ".dev0" means dirty. - Eexceptions: + Exceptions: 1: no tags. 0.postDISTANCE[.dev0] """ if pieces["closest-tag"]: @@ -426,7 +582,7 @@ def render_pep440_old(pieces): return rendered -def render_git_describe(pieces): +def render_git_describe(pieces: Dict[str, Any]) -> str: """TAG[-DISTANCE-gHEX][-dirty]. Like 'git describe --tags --dirty --always'. @@ -446,7 +602,7 @@ def render_git_describe(pieces): return rendered -def render_git_describe_long(pieces): +def render_git_describe_long(pieces: Dict[str, Any]) -> str: """TAG-DISTANCE-gHEX[-dirty]. Like 'git describe --tags --dirty --always -long'. @@ -466,7 +622,7 @@ def render_git_describe_long(pieces): return rendered -def render(pieces, style): +def render(pieces: Dict[str, Any], style: str) -> Dict[str, Any]: """Render the given version pieces into the requested style.""" if pieces["error"]: return { @@ -482,10 +638,14 @@ def render(pieces, style): if style == "pep440": rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) elif style == "pep440-pre": rendered = render_pep440_pre(pieces) elif style == "pep440-post": rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) elif style == "pep440-old": rendered = render_pep440_old(pieces) elif style == "git-describe": @@ -504,7 +664,7 @@ def render(pieces, style): } -def get_versions(): +def get_versions() -> Dict[str, Any]: """Get version information or return default if unable to do so.""" # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have # __file__, we can work backwards from there to the root. 
Some diff --git a/descent/data/__init__.py b/descent/data/__init__.py deleted file mode 100644 index 342c428..0000000 --- a/descent/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from descent.data.data import Dataset, DatasetEntry - -__all__ = [DatasetEntry, Dataset] diff --git a/descent/data/data.py b/descent/data/data.py deleted file mode 100644 index e0c5903..0000000 --- a/descent/data/data.py +++ /dev/null @@ -1,78 +0,0 @@ -import abc -from typing import Generic, Iterator, Sequence, TypeVar, Union - -import torch.utils.data -from openff.interchange.components.interchange import Interchange -from smirnoffee.smirnoff import vectorize_system - -from descent.models import ParameterizationModel -from descent.models.models import VectorizedSystem - -T_co = TypeVar("T_co", covariant=True) - - -class DatasetEntry(abc.ABC): - """The base class for storing labels associated with an input datum, such as - an OpenFF interchange object or an Espaloma graph model.""" - - @property - def model_input(self) -> VectorizedSystem: - return self._model_input - - def __init__(self, model_input: Union[Interchange]): - """ - - Args: - model_input: The input that will be passed to the model being trained in - order to yield a vectorized view of a parameterised molecule. If the - input is an interchange object it will be vectorised prior to being - used as a model input. - """ - - self._model_input = ( - model_input - if not isinstance(model_input, Interchange) - else vectorize_system(model_input) - ) - - @abc.abstractmethod - def evaluate_loss(self, model: ParameterizationModel, **kwargs) -> torch.Tensor: - """Evaluates the contribution to the total loss function of the data stored - in this entry using a specified model. - - Args: - model: The model that will return vectorized view of a parameterised - molecule. - - Returns: - The loss contribution of this entry. - """ - raise NotImplementedError() - - def __call__(self, model: ParameterizationModel, **kwargs) -> torch.Tensor: - """Evaluate the objective using a specified model. - - Args: - model: The model that will return vectorized view of a parameterised - molecule. - - Returns: - The loss contribution of this entry. 
- """ - return self.evaluate_loss(model, **kwargs) - - -class Dataset(torch.utils.data.IterableDataset[T_co], Generic[T_co]): - r"""An class representing a :class:`Dataset`.""" - - def __init__(self, entries: Sequence): - self._entries = entries - - def __getitem__(self, index: int) -> T_co: - return self._entries[index] - - def __iter__(self) -> Iterator[T_co]: - return self._entries.__iter__() - - def __len__(self) -> int: - return len(self._entries) diff --git a/descent/data/energy.py b/descent/data/energy.py deleted file mode 100644 index 223d994..0000000 --- a/descent/data/energy.py +++ /dev/null @@ -1,789 +0,0 @@ -import functools -from collections import defaultdict -from multiprocessing import Pool -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union - -import torch -from openff.interchange.components.interchange import Interchange -from openff.toolkit.topology import Molecule, Topology -from openff.toolkit.typing.engines.smirnoff import ForceField -from openff.units import unit -from smirnoffee.geometry.internal import ( - cartesian_to_internal, - detect_internal_coordinates, -) -from smirnoffee.potentials.potentials import evaluate_vectorized_system_energy -from torch._vmap_internals import vmap -from torch.autograd import grad -from tqdm import tqdm -from typing_extensions import Literal - -from descent import metrics, transforms -from descent.data import Dataset, DatasetEntry -from descent.metrics import LossMetric -from descent.models import ParameterizationModel -from descent.transforms import LossTransform - -if TYPE_CHECKING: - - from openff.qcsubmit.results import OptimizationResultCollection - from qcportal.models import ObjectId - -_HARTREE_TO_KJ_MOL = ( - (1.0 * unit.hartree * unit.avogadro_constant) - .to(unit.kilojoule / unit.mole) - .magnitude -) -_INVERSE_BOHR_TO_ANGSTROM = (1.0 * unit.bohr ** -1).to(unit.angstrom ** -1).magnitude - - -class EnergyEntry(DatasetEntry): - """A object that stores reference energy, gradient and hessian labels for a molecule - in multiple conforms.""" - - def __init__( - self, - model_input: Union[Interchange], - conformers: torch.Tensor, - reference_energies: Optional[torch.Tensor] = None, - reference_gradients: Optional[torch.Tensor] = None, - gradient_coordinate_system: Literal["cartesian", "ric"] = "cartesian", - reference_hessians: Optional[torch.Tensor] = None, - hessian_coordinate_system: Literal["cartesian", "ric"] = "cartesian", - ): - """ - - Args: - reference_energies: The reference energies with shape=(n_conformers, 1) - and units of [kJ / mol]. - reference_gradients: The reference gradients with - shape=(n_conformers, n_atoms, 3) and units of [kJ / mol / A]. - gradient_coordinate_system: The coordinate system to project the QM and MM - gradients to before computing the loss metric. - reference_hessians: The reference gradients with - shape=(n_conformers, n_atoms * 3, n_atoms * 3) and units of - [kJ / mol / A^2]. - hessian_coordinate_system: The coordinate system to project the QM and MM - hessians to before computing the loss metric. 
- """ - - super(EnergyEntry, self).__init__(model_input) - - self._validate_inputs( - conformers, - reference_energies, - reference_gradients, - reference_hessians, - model_input, - ) - - self._conformers = conformers - - internal_coordinate_systems = { - key for key in [gradient_coordinate_system, hessian_coordinate_system] - } - self._inverse_b_matrices = { - coordinate_system.lower(): self._initialize_internal_coordinates( - coordinate_system, model_input.topology, reference_hessians is not None - ) - for coordinate_system in internal_coordinate_systems - if coordinate_system is not None - and coordinate_system.lower() != "cartesian" - } - - self._reference_energies = reference_energies - - if reference_hessians is not None: - - reference_hessians = self._project_hessians( - reference_hessians, reference_gradients, hessian_coordinate_system - ) - - self._reference_hessians = reference_hessians - self._hessian_coordinate_system = hessian_coordinate_system - - if reference_gradients is not None: - - reference_gradients = self._project_gradients( - reference_gradients, gradient_coordinate_system - ) - - self._reference_gradients = reference_gradients - self._gradient_coordinate_system = gradient_coordinate_system - - @classmethod - def _validate_inputs( - cls, - conformers: torch.Tensor, - reference_energies: Optional[torch.Tensor], - reference_gradients: Optional[torch.Tensor], - reference_hessians: Optional[torch.Tensor], - system: Interchange, - ): - """Validate the shapes of the input tensors.""" - - if system.topology.n_topology_molecules != 1: - raise NotImplementedError("only single molecules are supported") - - assert ( - len(conformers.shape) == 3 - ), "conformers must have shape=(n_conformers, n_atoms, 3)" - - n_conformers, n_atoms, _ = conformers.shape - - assert system.topology.n_topology_atoms == n_atoms, ( - "the number of atoms in the interchange must match the number in the " - "conformer" - ) - - if reference_energies is not None: - assert ( - n_conformers > 1 - ), "at least two conformers must be provided when training to energies" - - reference_tensors = ( - ([] if reference_energies is None else [reference_energies]) - + ([] if reference_gradients is None else [reference_gradients]) - + ([] if reference_hessians is None else [reference_hessians]) - ) - assert ( - len(reference_tensors) > 0 - ), "at least one type of reference data must be provided" - - assert all( - len(reference_tensor) == n_conformers - for reference_tensor in reference_tensors - ), ( - "the number of conformers and reference energies / " - "gradients / hessians must match" - ) - - assert reference_energies is None or reference_energies.shape == ( - n_conformers, - 1, - ), "reference energy tensor must have shape=(n_conformers, 1)" - - assert reference_gradients is None or reference_gradients.shape == ( - n_conformers, - n_atoms, - 3, - ), "reference gradient tensor must have shape=(n_conformers, n_atoms, 3)" - - assert reference_hessians is None or reference_hessians.shape == ( - n_conformers, - n_atoms * 3, - n_atoms * 3, - ), ( - "reference hessian tensor must have shape=(n_conformers, n_atoms * 3, " - "n_atoms * 3)" - ) - - def _initialize_internal_coordinates( - self, - coordinate_system: Literal["ric"], - topology: Topology, - compute_hessians: bool, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Computes the B, inverse G and B' matrices [1] used to project a set of - cartesian conformers, gradients and hessians into a particular coordinate system. 
- - Args: - coordinate_system: - topology: - compute_hessians: - - References: - 1. P. Pulay, G. Fogarasi, F. Pang, and J. E. Boggs J. Am. Chem. Soc. 1979, - 101, 10, 2550–2560 - - Returns: - The B, inverse G and B' matrices with shapes of: - - * (n_conformers, n_internal_coords, n_atoms * 3) - * (n_conformers, n_internal_coords, n_internal_coords) - * (n_conformers, n_internal_coords, n_atoms * 3, n_atoms * 3) - - respectively. - """ - - bond_tensor = torch.tensor( - [ - (bond.atom1_index, bond.atom2_index) - for bond in next(iter(topology.reference_molecules)).bonds - ] - ) - - b_matrices = [] - g_inverses = [] - - b_matrix_gradients = [] - - n_internal_degrees = 0 - - for conformer in self._conformers: - - conformer = conformer.detach().clone().requires_grad_() - - internal_coordinate_indices = detect_internal_coordinates( - conformer, bond_tensor, coordinate_system="ric" - ) - - internal_coordinates = torch.cat( - [ - ic_values[1].flatten() - for ic_values in cartesian_to_internal( - conformer, - ic_indices=internal_coordinate_indices, - coordinate_system=coordinate_system, - ).values() - ] - ) - internal_coordinate_indices = [ - atom_indices - for ic_type in internal_coordinate_indices - for atom_indices in internal_coordinate_indices[ic_type] - ] - - b_matrix = torch.zeros( - ( - len(internal_coordinate_indices), - conformer.shape[0] * conformer.shape[1], - ) - ) - b_matrix_gradient = ( - torch.zeros( - ( - len(internal_coordinate_indices), - *conformer.shape, - *conformer.shape, - ) - ) - if compute_hessians - else torch.tensor([]) - ) - - for row_index, (atom_indices, row) in enumerate( - zip(internal_coordinate_indices, internal_coordinates) - ): - - (gradient,) = grad(row, conformer, create_graph=True) - - # noinspection PyArgumentList - b_matrix[row_index] = gradient.flatten() - - if not compute_hessians: - continue - - # TODO: computing the hessians in this way is still ~4x5 times slower - # than geomeTRIC. Explicit equations should likely be used rather - # than auto-diffing. - gradient_subset = gradient[atom_indices].flatten() - basis_vectors = torch.eye(len(gradient_subset)) - - def get_vjp(v): - - return torch.autograd.grad( - gradient_subset, conformer, v, retain_graph=True - )[0] - - row_hessian = vmap(get_vjp)(basis_vectors).reshape( - (len(atom_indices), 3, *conformer.shape) - ) - - for i, atom_index in enumerate(atom_indices): - b_matrix_gradient[row_index, atom_index, :, :, :] = row_hessian[i] - - if len(b_matrix_gradient) > 0: - - b_matrix_gradient = b_matrix_gradient.reshape( - b_matrix.shape[0], b_matrix.shape[1], b_matrix.shape[1] - ) - - # rcond was selected here to match geomeTRIC = 0.9.7.2 - g_inverse = torch.pinverse(b_matrix @ b_matrix.T, rcond=1.0e-6) - - b_matrices.append(b_matrix.detach()) - g_inverses.append(g_inverse.detach()) - - b_matrix_gradients.append(b_matrix_gradient.detach()) - - n_internal_degrees = max(n_internal_degrees, len(g_inverse)) - - # We pad the tensors with zeros to ensure that they all have the same - # dimensions to allow easier batch calculations. This can occur when - # certain conformers contain different ICs, e.g. one with a planar N and - # another with a pyramidal N. 
- for j, tensor, i, tensors in ( - (j, tensor, i, tensors) - for i, tensors in enumerate((b_matrices, g_inverses, b_matrix_gradients)) - for j, tensor in enumerate(tensors) - ): - - if len(tensor) == n_internal_degrees: - continue - - n_pad = n_internal_degrees - tensor.shape[0] - - tensor = torch.cat((tensor, torch.zeros((n_pad, *tensor.shape[1:]))), dim=0) - - if i == 1: - tensor = torch.cat( - (tensor, torch.zeros((tensor.shape[0], n_pad))), dim=1 - ) - - tensors[j] = tensor - - return ( - torch.stack(b_matrices), - torch.stack(g_inverses), - torch.stack(b_matrix_gradients), - ) - - def _project_gradients( - self, gradients: torch.Tensor, coordinate_system: Literal["cartesian", "ric"] - ) -> torch.Tensor: - """Projects a set of gradients from cartesian to a specified coordinate - system.""" - - if coordinate_system.lower() == "cartesian": - return gradients - - b_matrix, g_inverse, _ = self._inverse_b_matrices[coordinate_system.lower()] - - # From doi:10.1002/(sici)1096-987x(19960115)17:1<49::aid-jcc5>3.0.co;2-0 - # Eqn (3) g_q = G^- @ B @ g_x - gradients = torch.bmm( - torch.bmm(g_inverse, b_matrix), - gradients.reshape((len(g_inverse), -1, 1)), - ) - - return gradients - - def _project_hessians( - self, - hessians: torch.Tensor, - gradients: torch.Tensor, - coordinate_system: Literal["cartesian", "ric"], - ) -> torch.Tensor: - """Projects a set of hessians from cartesian to a specified coordinate system.""" - - if coordinate_system.lower() == "cartesian": - return hessians - - assert ( - gradients is not None - ), "gradients must be provided if using internal coordinate hessians" - - b_matrix, g_inverse, b_matrix_gradient = self._inverse_b_matrices[ - coordinate_system.lower() - ] - - # See ``_initialize_gradients``. - gradients = torch.bmm( - torch.bmm(g_inverse, b_matrix), gradients.reshape((len(g_inverse), -1, 1)) - ) - hessian_delta = ( - hessians - - torch.bmm( - b_matrix_gradient.reshape( - len(gradients), g_inverse.shape[1], -1 - ).transpose(1, 2), - gradients, - ).reshape(hessians.shape) - ) - - hessians = g_inverse - - for matrix in [ - b_matrix, - hessian_delta, - b_matrix.transpose(1, 2), - g_inverse, - ]: - hessians = torch.bmm(hessians, matrix) - - return hessians - - def _evaluate_mm_energies( - self, - model: ParameterizationModel, - compute_gradients: bool = False, - compute_hessians: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Evaluate the perturbed MM energies, gradients and hessians of the system - associated with this entry. - - Args: - model: The model that will return vectorized view of a parameterised - molecule. 
- """ - - vectorized_system = model.forward(self._model_input) - conformers = self._conformers.detach().clone().requires_grad_() - - mm_energies, mm_gradients, mm_hessians = [], [], [] - - # TODO: replace with either vmap or vectorize smirnoffee - for conformer in conformers: - - mm_energy = evaluate_vectorized_system_energy(vectorized_system, conformer) - mm_energies.append(mm_energy) - - if not compute_gradients and not compute_hessians: - continue - - (mm_gradient,) = torch.autograd.grad( - mm_energy, conformer, create_graph=compute_gradients or compute_hessians - ) - mm_gradients.append(mm_gradient) - - if not compute_hessians: - continue - - # noinspection PyArgumentList - mm_hessian = torch.cat( - [ - torch.autograd.grad(value, conformer, retain_graph=True)[0] - for value in mm_gradient.flatten() - ] - ).reshape( - ( - conformer.shape[0] * conformer.shape[1], - conformer.shape[0] * conformer.shape[1], - ) - ) - mm_hessians.append(mm_hessian) - - return ( - torch.stack(mm_energies), - None if not compute_gradients else torch.stack(mm_gradients), - None if not compute_hessians else torch.stack(mm_hessians), - ) - - @staticmethod - def _evaluate_loss_contribution( - reference_tensor: torch.Tensor, - computed_tensor: torch.Tensor, - data_transforms: Union[LossTransform, List[LossTransform]], - data_metric: LossMetric, - ) -> torch.Tensor: - """Computes the loss contribution for a set of computed and reference labels. - - Args: - reference_tensor: The reference tensor. - computed_tensor: The computed tensor. - data_transforms: Transforms to apply to the reference and computed tensors. - data_metric: The loss metric (e.g. MSE) to compute. - - Returns: - The loss contribution. - """ - - transformed_reference_tensor = transforms.transform_tensor( - reference_tensor, data_transforms - ) - transformed_computed_tensor = transforms.transform_tensor( - computed_tensor, data_transforms - ) - - return data_metric(transformed_computed_tensor, transformed_reference_tensor) - - def evaluate_loss( - self, - model: ParameterizationModel, - energy_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - energy_metric: Optional[LossMetric] = None, - gradient_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - gradient_metric: Optional[LossMetric] = None, - hessian_transforms: Optional[Union[LossTransform, List[LossTransform]]] = None, - hessian_metric: Optional[LossMetric] = None, - ) -> torch.Tensor: - """ - - Args: - model: The model that will return vectorized view of a parameterised - molecule. - energy_transforms: Transforms to apply to the QM and MM energies - before computing the loss metric. By default - ``descent.transforms.relative(index=0)`` is used if no value is provided. - energy_metric: The loss metric (e.g. MSE) to compute from the QM and MM - energies. By default ``descent.metrics.mse()`` is used if no value is - provided. - gradient_transforms: Transforms to apply to the QM and MM gradients - before computing the loss metric. By default - ``descent.transforms.identity()`` is used if no value is provided. - gradient_metric: The loss metric (e.g. MSE) to compute from the QM and MM - gradients. By default ``descent.metrics.mse()`` is used if no value is - provided. - hessian_transforms: Transforms to apply to the QM and MM hessians - before computing the loss metric. By default - ``descent.transforms.identity()`` is used if no value is provided. - hessian_metric: The loss metric (e.g. MSE) to compute from the QM and MM - hessians. 
By default ``descent.metrics.mse()`` is used if no value is - provided. - """ - mm_energies, mm_gradients, mm_hessians = self._evaluate_mm_energies( - model, - compute_gradients=( - self._reference_hessians is not None - or self._reference_gradients is not None - ), - compute_hessians=self._reference_hessians is not None, - ) - - loss = torch.zeros(1) - - if self._reference_energies is not None: - - loss += self._evaluate_loss_contribution( - self._reference_energies, - mm_energies, - energy_transforms - if energy_transforms is not None - else transforms.relative(index=0), - energy_metric if energy_metric is not None else metrics.mse(), - ) - - if self._reference_gradients is not None: - - loss += self._evaluate_loss_contribution( - self._reference_gradients, - self._project_gradients(mm_gradients, self._gradient_coordinate_system), - gradient_transforms - if gradient_transforms is not None - else transforms.relative(index=0), - gradient_metric if gradient_metric is not None else metrics.mse(), - ) - - if self._reference_hessians is not None: - - loss += self._evaluate_loss_contribution( - self._reference_hessians, - self._project_hessians( - mm_hessians, mm_gradients, self._hessian_coordinate_system - ), - hessian_transforms - if hessian_transforms is not None - else transforms.identity(), - hessian_metric if hessian_metric is not None else metrics.mse(), - ) - - return loss - - -class EnergyDataset(Dataset[EnergyEntry]): - """A data set that stores reference energy, gradient and hessian labels.""" - - @classmethod - def _retrieve_gradient_and_hessians( - cls, - optimization_results: "OptimizationResultCollection", - include_gradients: bool, - include_hessians: bool, - verbose: bool = True, - ) -> Tuple[ - Dict[Tuple[str, "ObjectId"], torch.Tensor], - Dict[Tuple[str, "ObjectId"], torch.Tensor], - ]: - """Retrieves the hessians and gradients associated with a set of QC optimization - result records. - - Args: - optimization_results: The collection of result records whose matching - gradients and hessians should be retrieved where available. - include_gradients: Whether to retrieve gradient values. - include_hessians: Whether to retrieve hessian values. - verbose: Whether to log progress to the terminal. - - Returns: - The values of the gradients and hessians (if requested) stored in - dictionaries with keys of ``(server_address, molecule_id)``. - - Gradient tensors will have shape=(n_atoms, 3) and units of [kJ / mol / A] - and hessian shape=(n_atoms * 3, n_atoms * 3) and units of [kJ / mol / A^2]. 
- """ - - if not include_hessians and not include_gradients: - return {}, {} - - basic_result_collection = optimization_results.to_basic_result_collection( - driver=( - ([] if not include_gradients else ["gradient"]) - + ([] if not include_hessians else ["hessian"]) - ) - ) - - qc_gradients, qc_hessians = {}, {} - - for qc_record, _ in tqdm( - basic_result_collection.to_records(), - desc="Pulling gradient / hessian data", - disable=not verbose, - ): - - address = qc_record.client.address - - if qc_record.driver == "gradient" and include_gradients: - - qc_gradients[(address, qc_record.molecule)] = torch.from_numpy( - qc_record.return_result - * _HARTREE_TO_KJ_MOL - * _INVERSE_BOHR_TO_ANGSTROM - ).type(torch.float32) - - elif qc_record.driver == "hessian" and include_hessians: - - qc_hessians[(address, qc_record.molecule)] = torch.from_numpy( - qc_record.return_result - * _HARTREE_TO_KJ_MOL - * _INVERSE_BOHR_TO_ANGSTROM - * _INVERSE_BOHR_TO_ANGSTROM - ).type(torch.float32) - - return qc_gradients, qc_hessians - - @classmethod - def _from_grouped_results( - cls, - grouped_data: Tuple[ - str, - torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], - Optional[torch.Tensor], - ], - force_field: ForceField, - **kwargs, - ) -> "EnergyEntry": - - cmiles, conformers, qc_energies, qc_gradients, qc_hessians = grouped_data - - molecule = Molecule.from_mapped_smiles(cmiles, allow_undefined_stereo=True) - system = Interchange.from_smirnoff(force_field, molecule.to_topology()) - - return EnergyEntry( - system, - conformers, - reference_energies=qc_energies, - reference_gradients=qc_gradients, - reference_hessians=qc_hessians, - **kwargs, - ) - - @classmethod - def from_optimization_results( - cls, - optimization_results: "OptimizationResultCollection", - initial_force_field: ForceField, - include_energies: bool = True, - include_gradients: bool = False, - gradient_coordinate_system: Literal["cartesian", "ric"] = "cartesian", - include_hessians: bool = False, - hessian_coordinate_system: Literal["cartesian", "ric"] = "cartesian", - n_processes: int = 1, - verbose: bool = True, - ) -> "EnergyDataset": - """Creates a dataset of energy entries (one per unique molecule) from the - **final** structures a set of QC optimization results. - - Args: - optimization_results: The collection of result records. - initial_force_field: The force field that will be trained. - include_energies: Whether to include energies. - include_gradients: Whether to include gradients. - gradient_coordinate_system: The coordinate system to project the QM and MM - gradients to before computing the loss metric. - include_hessians: Whether to include hessians. - hessian_coordinate_system: The coordinate system to project the QM and MM - hessians to before computing the loss metric. - n_processes: The number of processes to parallelize this function across. - verbose: Whether to log progress to the terminal. - - Returns: - A dataset of the energy entries. 
- """ - - from simtk import unit as simtk_unit - - # Group the results by molecule ignoring stereochemistry - per_molecule_records = defaultdict(list) - - for qc_record, molecule in tqdm( - optimization_results.to_records(), - desc="Pulling main optimisation records", - disable=not verbose, - ): - - molecule: Molecule = molecule.canonical_order_atoms() - conformer = molecule.conformers[0].value_in_unit(simtk_unit.angstrom) - - smiles = molecule.to_smiles( - isomeric=True, explicit_hydrogens=True, mapped=True - ) - - per_molecule_records[smiles].append((qc_record, conformer)) - - qc_gradients, qc_hessians = cls._retrieve_gradient_and_hessians( - optimization_results, include_gradients, include_hessians, verbose - ) - - result_tensors = [] - - for cmiles, qc_records in per_molecule_records.items(): - - # noinspection PyTypeChecker - grouped_data = [ - ( - torch.from_numpy(conformer).type(torch.float32), - torch.tensor([qc_record.get_final_energy() * _HARTREE_TO_KJ_MOL]), - # There should always be a gradient associated with the record - # and so we choose to raise a key error when the record is missing - # rather than skipping the entry. - None - if not include_gradients - else qc_gradients[ - (qc_record.client.address, qc_record.final_molecule) - ], - qc_hessians.get( - (qc_record.client.address, qc_record.final_molecule), None - ), - ) - for qc_record, conformer in qc_records - ] - - conformers, qm_energies, qm_gradients, qm_hessians = zip(*grouped_data) - - result_tensors.append( - ( - cmiles, - torch.stack(conformers), - torch.stack(qm_energies) if include_energies else None, - torch.stack(qm_gradients) if include_gradients else None, - torch.stack(qm_hessians) if include_hessians else None, - ) - ) - - with Pool(n_processes) as pool: - - entries = list( - tqdm( - pool.imap( - functools.partial( - cls._from_grouped_results, - force_field=initial_force_field, - gradient_coordinate_system=gradient_coordinate_system - if include_gradients - else None, - hessian_coordinate_system=hessian_coordinate_system - if include_hessians - else None, - ), - result_tensors, - ), - total=len(result_tensors), - disable=not verbose, - desc="Building entries.", - ) - ) - - return cls(entries) diff --git a/descent/metrics.py b/descent/metrics.py deleted file mode 100644 index 4d79876..0000000 --- a/descent/metrics.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Common and composable loss metrics.""" -from typing import Callable, Tuple, Union - -import torch - -LossMetric = Callable[[torch.Tensor, torch.Tensor], torch.Tensor] - - -def mse(dim: Union[int, Tuple[int, ...]] = None) -> LossMetric: - - if dim is None: - dim = () - - def _mse(input_tensor: torch.Tensor, reference_tensor: torch.Tensor): - return (input_tensor - reference_tensor).square().mean(dim=dim) - - return _mse diff --git a/descent/models/__init__.py b/descent/models/__init__.py deleted file mode 100644 index 78cd1ef..0000000 --- a/descent/models/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from descent.models.models import ParameterizationModel - -__all__ = ["ParameterizationModel"] diff --git a/descent/models/models.py b/descent/models/models.py deleted file mode 100644 index e83283e..0000000 --- a/descent/models/models.py +++ /dev/null @@ -1,18 +0,0 @@ -from typing import Any, Dict, Tuple - -from smirnoffee.smirnoff import VectorizedHandler -from typing_extensions import Protocol, runtime_checkable - -VectorizedSystem = Dict[Tuple[str, str], VectorizedHandler] - - -@runtime_checkable -class ParameterizationModel(Protocol): - """The interface the 
parameterization models must implement."""
-
-    def forward(self, graph: Any) -> VectorizedSystem:
-        """Outputs a vectorised view of a parameterized molecule."""
-
-    def summarise(self):
-        """Print a summary of the status of this model, such as the differences between
-        the initial and current state during training."""
diff --git a/descent/models/smirnoff.py b/descent/models/smirnoff.py
deleted file mode 100644
index c94d5e6..0000000
--- a/descent/models/smirnoff.py
+++ /dev/null
@@ -1,350 +0,0 @@
-import io
-from collections import defaultdict
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch.nn
-from openff.interchange.models import PotentialKey
-from openff.toolkit.typing.engines.smirnoff import ForceField
-from smirnoffee.potentials import add_parameter_delta
-from typing_extensions import Literal
-
-from descent.models.models import VectorizedSystem
-from descent.utilities.smirnoff import perturb_force_field
-
-
-class SMIRNOFFModel(torch.nn.Module):
-    """A model for perturbing a set of SMIRNOFF parameters that have been applied to a
-    molecule via the ``openff-interchange`` package.
-    """
-
-    @property
-    def parameter_delta_ids(self) -> Tuple[Tuple[str, PotentialKey, str], ...]:
-        """
-        Returns:
-            The 'ids' of the force field parameters that will be perturbed by this model
-            where each 'id' is a tuple of ``(handler_type, smirks, attribute)``.
-        """
-
-        return tuple(
-            (handler_type, smirks, attribute)
-            for handler_type, parameter_ids in self._parameter_delta_ids.items()
-            for (smirks, attribute) in parameter_ids
-        )
-
-    def __init__(
-        self,
-        parameter_ids: List[Tuple[str, Union[str, PotentialKey], str]],
-        initial_force_field: Optional[ForceField],
-        covariance_tensor: Optional[torch.Tensor] = None,
-    ):
-        """
-        Args:
-            parameter_ids: A list of the 'ids' of the parameters that will be perturbed,
-                where each id should be a tuple of ``(handler_type, smirks, attribute)``.
-            initial_force_field: The (optional) force field used to initially
-                parameterize the molecules of interest.
-            covariance_tensor: A tensor that will be used to transform the
-                ``parameter_delta`` before it is used by the ``forward`` method such
-                that ``parameter_delta = self.covariance_tensor @ self.parameter_delta``.
-                It should have shape=``(n_parameter_ids, n_hidden)`` where ``n_hidden``
-                is the number of parameters prior to applying the transform.
-
-                This can be used to scale the values of parameters:
-
-                ``covariance_tensor = torch.eye(len(parameter_ids)) * 0.01``
-
-                or even define the covariance between parameters as in the case of BCCs:
-
-                ``covariance_tensor = torch.tensor([[1.0], [-1.0]])``
-
-                Usually ``n_hidden <= n_parameter_ids``.
-        """
-
-        super(SMIRNOFFModel, self).__init__()
-
-        self._initial_force_field = initial_force_field
-
-        self._parameter_delta_ids = defaultdict(list)
-
-        for handler_type, smirks, attribute in parameter_ids:
-
-            self._parameter_delta_ids[handler_type].append(
-                (
-                    smirks
-                    if isinstance(smirks, PotentialKey)
-                    else PotentialKey(id=smirks, associated_handler=handler_type),
-                    attribute,
-                )
-            )
-
-        # Convert the ids to a normal dictionary to make sure KeyErrors get raised again.
-        self._parameter_delta_ids: Dict[str, List[Tuple[PotentialKey, str]]] = {
-            **self._parameter_delta_ids
-        }
-
-        self._parameter_delta_indices = {}
-        counter = 0
-
-        for handler_type, handler_ids in self._parameter_delta_ids.items():
-
-            self._parameter_delta_indices[handler_type] = torch.arange(
-                counter, counter + len(handler_ids), dtype=torch.int64
-            )
-
-            counter += len(handler_ids)
-
-        self._covariance_tensor = covariance_tensor
-
-        assert covariance_tensor is None or (
-            covariance_tensor.ndim == 2
-            and covariance_tensor.shape[0] == len(parameter_ids)
-        ), "the ``covariance_tensor`` must have shape=``(n_parameter_ids, n_hidden)``"
-
-        # The trainable delta vector lives in the 'hidden' space, i.e. it has
-        # ``n_hidden = covariance_tensor.shape[1]`` elements when a covariance
-        # tensor is provided.
-        n_parameter_deltas = (
-            len(parameter_ids)
-            if covariance_tensor is None
-            else covariance_tensor.shape[1]
-        )
-
-        self.parameter_delta = torch.nn.Parameter(
-            torch.zeros(n_parameter_deltas).requires_grad_(), requires_grad=True
-        )
-
-    def forward(self, graph: VectorizedSystem) -> VectorizedSystem:
-        """Perturb the parameters of an already applied force field using this model's
-        current parameter 'deltas'.
-        """
-
-        if len(self._parameter_delta_ids) == 0:
-            return graph
-
-        parameter_delta = self.parameter_delta
-
-        if self._covariance_tensor is not None:
-            parameter_delta = self._covariance_tensor @ parameter_delta
-
-        output = {}
-
-        for (handler_type, handler_expression), vectorized_handler in graph.items():
-
-            handler_delta_ids = self._parameter_delta_ids.get(handler_type, [])
-
-            if len(handler_delta_ids) == 0:
-
-                output[(handler_type, handler_expression)] = vectorized_handler
-                continue
-
-            handler_delta_indices = self._parameter_delta_indices[handler_type]
-            handler_delta = parameter_delta[handler_delta_indices]
-
-            indices, handler_parameters, handler_parameter_ids = vectorized_handler
-
-            perturbed_parameters = add_parameter_delta(
-                handler_parameters,
-                handler_parameter_ids,
-                handler_delta,
-                handler_delta_ids,
-            )
-
-            output[(handler_type, handler_expression)] = (
-                indices,
-                perturbed_parameters,
-                handler_parameter_ids,
-            )
-
-        return output
-
-    def to_force_field(self) -> ForceField:
-        """Returns the current force field (i.e. initial_force_field + parameter_delta)
-        as an OpenFF force field object.
-        """
-
-        return perturb_force_field(
-            self._initial_force_field,
-            self.parameter_delta
-            if self._covariance_tensor is None
-            else self._covariance_tensor @ self.parameter_delta,
-            [
-                (handler_type, smirks, attribute)
-                for handler_type, handler_ids in self._parameter_delta_ids.items()
-                for (smirks, attribute) in handler_ids
-            ],
-        )
-
-    def summarise(
-        self,
-        parameter_id_type: Literal["smirks", "id"] = "smirks",
-        print_to_terminal: bool = True,
-    ) -> str:
-        """Generates a table summarising the differences between the initial and
-        current values of the parameters being trained.
-
-        Args:
-            parameter_id_type: The type of ID to show for each parameter. Currently
-                this can either be the unique ``'id'`` associated with the parameter or
-                the ``'smirks'`` pattern that encodes the chemical environment the
-                parameter is applied to.
-            print_to_terminal: Whether to print the summary to the terminal.
-
-        Returns:
-            A string containing the summary.
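-
-        Example:
-            A minimal usage sketch (``model`` is assumed to be a trained instance):
-
-            >>> summary = model.summarise(parameter_id_type="id", print_to_terminal=False)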
- """ - - from openff.units.simtk import from_simtk - - final_force_field = self.to_force_field() - - # Reshape the data into dictionaries to make tabulation easier - table_data = defaultdict(lambda: defaultdict(dict)) - attribute_units = {} - - for handler_type, potential_key, attribute in [ - (handler_type, potential_key, attribute) - for handler_type, parameter_ids in self._parameter_delta_ids.items() - for (potential_key, attribute) in parameter_ids - ]: - - smirks = potential_key.id - - attribute = ( - attribute - if potential_key.mult is None - else f"{attribute}{potential_key.mult}" - ) - - initial_value = from_simtk( - getattr( - self._initial_force_field[handler_type].parameters[smirks], - attribute, - ) - ) - final_value = from_simtk( - getattr(final_force_field[handler_type].parameters[smirks], attribute) - ) - - if (handler_type, attribute) not in attribute_units: - attribute_units[(handler_type, attribute)] = initial_value.units - - unit = attribute_units[(handler_type, attribute)] - - attribute = f"{attribute} ({unit:P~})" - - if parameter_id_type == "id": - - smirks = self._initial_force_field[handler_type].parameters[smirks].id - smirks = smirks if smirks is not None else "NO ID" - - table_data[handler_type][attribute][smirks] = ( - initial_value.to(unit).m, - final_value.to(unit).m, - ) - - # Construct the final return value: - return_value = io.StringIO() - - for handler_type, attribute_data in table_data.items(): - - print(f"\n{handler_type.center(80, '=')}\n", file=return_value) - - attribute_headers = sorted(attribute_data) - - attribute_widths = { - attribute: max( - [ - len(f"{value:.4f}") - for value_tuple in attribute_data[attribute].values() - for value in value_tuple - ] - ) - * 2 - + 1 - for attribute in attribute_headers - } - attribute_widths = { - # Make sure the width of the column - 1 is divisible by 2 - attribute: max(int((column_width - 1) / 2.0 + 0.5) * 2 + 1, 15) - for attribute, column_width in attribute_widths.items() - } - - smirks_width = max( - len(smirks) - for smirks_data in attribute_data.values() - for smirks in smirks_data - ) - - first_header = ( - " " * (smirks_width) - + " " - + " ".join( - [ - attribute.center(attribute_widths[attribute], " ") - for attribute in attribute_headers - ] - ) - ) - second_header = ( - " " * (smirks_width) - + " " - + " ".join( - [ - "INITIAL".center((column_width - 1) // 2, " ") - + " " - + "FINAL".center((column_width - 1) // 2, " ") - for attribute, column_width in attribute_widths.items() - ] - ) - ) - border = ( - "-" * smirks_width - + " " - + " ".join( - [ - "-" * attribute_widths[attribute] - for attribute in attribute_headers - ] - ) - ) - - smirks_data = defaultdict(dict) - - for attribute in attribute_data: - for smirks, value_tuple in attribute_data[attribute].items(): - smirks_data[smirks][attribute] = value_tuple - - print(border, file=return_value) - print(first_header, file=return_value) - print(second_header, file=return_value) - print(border, file=return_value) - - for smirks in sorted(smirks_data): - - def format_column(attr, value_tuple): - - if value_tuple is None: - return " " * attribute_widths[attr] - - value_width = (attribute_widths[attr] - 1) // 2 - return ( - f"{value_tuple[0]:.4f}".ljust(value_width, " ") - + " " - + f"{value_tuple[1]:.4f}".ljust(value_width, " ") - ) - - row = ( - f"{smirks.ljust(smirks_width)}" - + " " - + " ".join( - [ - format_column( - attribute, smirks_data[smirks].get(attribute, None) - ) - for attribute in attribute_headers - ] - ) - ) - - print(row, 
file=return_value) - - return_value = return_value.getvalue() - - if print_to_terminal: - print(return_value) - - return return_value diff --git a/descent/tests/__init__.py b/descent/tests/__init__.py index e83e4de..e69de29 100755 --- a/descent/tests/__init__.py +++ b/descent/tests/__init__.py @@ -1,64 +0,0 @@ -from typing import Callable, Union - -import numpy -from openff.toolkit.utils import string_to_unit, unit_to_string -from openff.units import unit -from simtk import unit as simtk_unit - - -def _compare_values( - a: Union[float, unit.Quantity, simtk_unit.Quantity], - b: Union[float, unit.Quantity, simtk_unit.Quantity], - predicate: Callable[ - [Union[float, numpy.ndarray], Union[float, numpy.ndarray]], bool - ], -) -> bool: - """Compare to values using a specified predicate taking units into account.""" - - if isinstance(a, simtk_unit.Quantity): - - expected_unit = unit.Unit(unit_to_string(a.unit)) - a = a.value_in_unit(a.unit) - - elif isinstance(a, unit.Quantity): - - expected_unit = a.units - a = a.to(expected_unit).magnitude - - else: - - expected_unit = None - - if isinstance(b, simtk_unit.Quantity): - - assert expected_unit is not None, "cannot compare quantity with unit-less." - b = b.value_in_unit(string_to_unit(f"{expected_unit:!s}")) - - elif isinstance(b, unit.Quantity): - - assert expected_unit is not None, "cannot compare quantity with unit-less." - b = b.to(expected_unit).magnitude - - else: - - assert expected_unit is None, "cannot compare quantity with unit-less." - - return predicate(a, b) - - -def is_close( - a: Union[float, unit.Quantity, simtk_unit.Quantity], - b: Union[float, unit.Quantity, simtk_unit.Quantity], -) -> bool: - """Compare whether two values are close taking units into account.""" - - return _compare_values(a, b, numpy.isclose) - - -def all_close( - a: Union[numpy.ndarray, unit.Quantity, simtk_unit.Quantity], - b: Union[numpy.ndarray, unit.Quantity, simtk_unit.Quantity], -) -> bool: - """Compare whether all elements in two array are close taking units into account.""" - - return _compare_values(a, b, numpy.allclose) diff --git a/descent/tests/conftest.py b/descent/tests/conftest.py index ecfe320..e69de29 100644 --- a/descent/tests/conftest.py +++ b/descent/tests/conftest.py @@ -1,72 +0,0 @@ -import pytest -import torch -from openff.interchange.components.interchange import Interchange -from openff.toolkit.topology import Molecule -from openff.toolkit.typing.engines.smirnoff import ForceField - - -@pytest.fixture(scope="module") -def default_force_field() -> ForceField: - """Returns the OpenFF 1.3.0 force field with constraints removed.""" - - force_field = ForceField("openff-1.0.0.offxml") - # force_field.deregister_parameter_handler("ToolkitAM1BCC") - force_field.deregister_parameter_handler("Constraints") - - return force_field - - -@pytest.fixture(scope="module") -def ethanol() -> Molecule: - """Returns an OpenFF ethanol molecule with a fixed atom order.""" - - return Molecule.from_mapped_smiles( - "[H:5][C:2]([H:6])([H:7])[C:3]([H:8])([H:9])[O:1][H:4]" - ) - - -@pytest.fixture(scope="module") -def ethanol_conformer(ethanol) -> torch.Tensor: - """Returns a conformer [A] of ethanol with an ordering which matches the - ``ethanol`` fixture.""" - - from simtk import unit as simtk_unit - - ethanol.generate_conformers(n_conformers=1) - conformer = ethanol.conformers[0].value_in_unit(simtk_unit.angstrom) - - return torch.from_numpy(conformer).type(torch.float) - - -@pytest.fixture(scope="module") -def ethanol_system(ethanol, default_force_field) -> 
Interchange: - """Returns a parametermized system of ethanol.""" - - return Interchange.from_smirnoff(default_force_field, ethanol.to_topology()) - - -@pytest.fixture(scope="module") -def formaldehyde() -> Molecule: - """Returns an OpenFF formaldehyde molecule with a fixed atom order..""" - - return Molecule.from_mapped_smiles("[H:3][C:1](=[O:2])[H:4]") - - -@pytest.fixture(scope="module") -def formaldehyde_conformer(formaldehyde) -> torch.Tensor: - """Returns a conformer [A] of formaldehyde with an ordering which matches the - ``formaldehyde`` fixture.""" - - from simtk import unit as simtk_unit - - formaldehyde.generate_conformers(n_conformers=1) - conformer = formaldehyde.conformers[0].value_in_unit(simtk_unit.angstrom) - - return torch.from_numpy(conformer).type(torch.float) - - -@pytest.fixture(scope="module") -def formaldehyde_system(formaldehyde, default_force_field) -> Interchange: - """Returns a parametermized system of formaldehyde.""" - - return Interchange.from_smirnoff(default_force_field, formaldehyde.to_topology()) diff --git a/descent/tests/data/__init__.py b/descent/tests/data/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/descent/tests/data/test_data.py b/descent/tests/data/test_data.py deleted file mode 100644 index f226ae9..0000000 --- a/descent/tests/data/test_data.py +++ /dev/null @@ -1,48 +0,0 @@ -import pytest - -from descent.data import Dataset, DatasetEntry -from descent.models.smirnoff import SMIRNOFFModel -from descent.tests.mocking.systems import generate_mock_hcl_system - - -class DummyEntry(DatasetEntry): - def evaluate_loss(self, model, **kwargs): - pass - - -class DummyDataset(Dataset[DummyEntry]): - pass - - -def test_call(monkeypatch): - - evaluate_called = False - evaluate_kwargs = {} - - class LocalEntry(DatasetEntry): - def evaluate_loss(self, model, **kwargs): - nonlocal evaluate_called - evaluate_called = True - evaluate_kwargs.update(kwargs) - - LocalEntry(generate_mock_hcl_system())(SMIRNOFFModel([], None), a="a", b=2) - - assert evaluate_called - assert evaluate_kwargs == {"a": "a", "b": 2} - - -def test_dataset(): - - model_input = generate_mock_hcl_system() - - dataset = DummyDataset(entries=[DummyEntry(model_input), DummyEntry(model_input)]) - - assert dataset[0] is not None - assert dataset[1] is not None - - with pytest.raises(IndexError): - assert dataset[2] - - assert len(dataset) == 2 - - assert all(isinstance(entry.model_input, dict) for entry in dataset) diff --git a/descent/tests/data/test_energy.py b/descent/tests/data/test_energy.py deleted file mode 100644 index 3074ae9..0000000 --- a/descent/tests/data/test_energy.py +++ /dev/null @@ -1,385 +0,0 @@ -import copy -from typing import Tuple - -import numpy -import pytest -import torch -from openff.toolkit.topology import Molecule -from openff.toolkit.typing.engines.smirnoff import ForceField -from smirnoffee.geometry.internal import detect_internal_coordinates - -from descent import metrics, transforms -from descent.data.energy import EnergyDataset, EnergyEntry -from descent.models.smirnoff import SMIRNOFFModel -from descent.tests.geometric import geometric_project_derivatives -from descent.tests.mocking.qcdata import mock_optimization_result_collection -from descent.tests.mocking.systems import generate_mock_hcl_system - - -@pytest.fixture() -def mock_hcl_conformers() -> torch.Tensor: - """Creates two mock conformers for HCl - one with a bond length of 1 A and another - with a bond length of 2 A""" - - return torch.tensor( - [[[-0.5, 0.0, 0.0], [0.5, 0.0, 
0.0]], [[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]] - ) - - -@pytest.fixture() -def mock_hcl_system(): - return generate_mock_hcl_system() - - -@pytest.fixture() -def mock_hcl_mm_values() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """A set of energies, gradients and hessians analytically computed using - for the ``mock_hcl_conformers`` and ``mock_hcl_system``. - """ - - return ( - torch.tensor([[0.0], [1.0]]), - torch.tensor( - [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [[-2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]] - ), - torch.tensor( - [ - [ - [+2.0, 0.0, 0.0, -2.0, 0.0, 0.0], - [+0.0, 0.0, 0.0, +0.0, 0.0, 0.0], - [+0.0, 0.0, 0.0, +0.0, 0.0, 0.0], - [-2.0, 0.0, 0.0, +2.0, 0.0, 0.0], - [+0.0, 0.0, 0.0, +0.0, 0.0, 0.0], - [+0.0, 0.0, 0.0, +0.0, 0.0, 0.0], - ], - [ - [+2.0, +0.0, +0.0, -2.0, +0.0, +0.0], - [+0.0, +1.0, +0.0, +0.0, -1.0, +0.0], - [+0.0, +0.0, +1.0, +0.0, +0.0, -1.0], - [-2.0, +0.0, +0.0, +2.0, +0.0, +0.0], - [+0.0, -1.0, +0.0, +0.0, +1.0, +0.0], - [+0.0, +0.0, -1.0, +0.0, +0.0, +1.0], - ], - ] - ), - ) - - -def test_initialize_internal_coordinates(): - """Test that the internal coordinate matrices can be correctly constructed and - padding when different conformers of a molecule have different numbers of internal - coordinates. See ``test_gradient_hessian_projection`` for a more rigorous - integration test. - """ - - topology = Molecule.from_mapped_smiles("[H:1][C:2]#[C:3][H:4]").to_topology() - - conformers = torch.tensor( - [ - [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 1.0, 0.0]], - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0], [3.0, 0.0, 0.0]], - ], - requires_grad=True, - ) - - entry = EnergyEntry.__new__(EnergyEntry) - entry._conformers = conformers - - b_matrix, g_inverse, b_matrix_gradient = entry._initialize_internal_coordinates( - "ric", topology, True - ) - - # The first conformer will have 6 ICs (3 bonds, 2 angles, 1 dihedral), while the - # second will have 7 as the dihedral is replaced with 2 linear angle terms. We - # should expect the first conformer then to have some zero paddings added to match - # the shape of the second conformer matricies. - assert b_matrix.shape == (2, 7, 12) - assert torch.allclose(b_matrix[0, 6], torch.tensor(0.0)) - assert not torch.allclose(b_matrix[1, 6], torch.tensor(0.0)) - - assert g_inverse.shape == (2, 7, 7) - assert torch.allclose(g_inverse[0, 6, :], torch.tensor(0.0)) - assert torch.allclose(g_inverse[0, :, 6], torch.tensor(0.0)) - - assert not torch.allclose(g_inverse[1, 6, :], torch.tensor(0.0)) - assert not torch.allclose(g_inverse[1, :, 6], torch.tensor(0.0)) - - assert b_matrix_gradient.shape == (2, 7, 12, 12) - assert torch.allclose(b_matrix_gradient[0, 6], torch.tensor(0.0)) - assert not torch.allclose(b_matrix_gradient[1, 6], torch.tensor(0.0)) - - -def test_gradient_hessian_projection(ethanol, ethanol_conformer, ethanol_system): - """An integration test of projecting a set of gradients and hessians onto internal - coordinates. The values are compared against the more established ``geomtric`` - package. 
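-
-    For reference, the gradient projection being tested has the form
-    ``g_q = G^-1 B g_x`` with ``G = B B^T``, where ``B`` is the Wilson B-matrix.
-    A rough single-conformer torch sketch (shapes only, not the implementation
-    under test):
-
-        g_q = g_inverse @ (b_matrix @ g_x.flatten())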
- """ - - conformer = torch.tensor([[[0.0, 1.0, 0.0]], [[1.0, 0.0, 0.0]]], requires_grad=True) - - x = 0.5 * conformer.square().sum() - - x.backward(retain_graph=True) - print(x.grad) - - x.backward() - print(x.grad) - - internal_coordinate_indices = detect_internal_coordinates( - ethanol_conformer, - torch.tensor([(bond.atom1_index, bond.atom2_index) for bond in ethanol.bonds]), - ) - - reference_gradients = torch.rand((1, ethanol.n_atoms, 3)) - reference_hessians = torch.rand((1, ethanol.n_atoms * 3, ethanol.n_atoms * 3)) - - expected_gradiant, expected_hessian = geometric_project_derivatives( - ethanol, - ethanol_conformer, - internal_coordinate_indices, - reference_gradients, - reference_hessians, - ) - - entry = EnergyEntry( - ethanol_system, - ethanol_conformer.reshape(1, len(ethanol_conformer), 3), - reference_gradients=reference_gradients, - gradient_coordinate_system="ric", - reference_hessians=reference_hessians, - hessian_coordinate_system="ric", - ) - - actual_gradiant = entry._reference_gradients.numpy() - actual_hessian = entry._reference_hessians.numpy() - - assert numpy.allclose( - actual_gradiant.reshape(expected_gradiant.shape), expected_gradiant, atol=1.0e-3 - ) - assert numpy.allclose( - actual_hessian.reshape(expected_hessian.shape), expected_hessian, atol=1.0e-3 - ) - - -@pytest.mark.parametrize("compute_gradients", [True, False]) -@pytest.mark.parametrize("compute_hessians", [True, False]) -def test_evaluate_mm_energies( - compute_gradients, - compute_hessians, - mock_hcl_conformers, - mock_hcl_system, - mock_hcl_mm_values, -): - - entry = EnergyEntry(mock_hcl_system, mock_hcl_conformers, torch.zeros((2, 1))) - - mm_energies, mm_gradients, mm_hessians = entry._evaluate_mm_energies( - SMIRNOFFModel([], None), compute_gradients, compute_hessians - ) - - expected_energies, expected_gradients, expected_hessians = mock_hcl_mm_values - - assert mm_energies.shape == expected_energies.shape - assert torch.allclose(mm_energies, expected_energies) - - if compute_gradients: - assert mm_gradients.shape == expected_gradients.shape - assert torch.allclose(mm_gradients, expected_gradients) - else: - assert mm_gradients is None - - if compute_hessians: - assert mm_hessians.shape == expected_hessians.shape - assert torch.allclose(mm_hessians, expected_hessians) - else: - assert mm_hessians is None - - -def test_evaluate_energies(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_values): - - expected_energies, *_ = mock_hcl_mm_values - expected_scale = torch.rand(1) - - entry = EnergyEntry( - mock_hcl_system, - mock_hcl_conformers, - reference_energies=expected_energies + torch.ones_like(expected_energies), - ) - - loss = entry.evaluate_loss( - SMIRNOFFModel([], None), - energy_transforms=lambda x: expected_scale * x, - energy_metric=metrics.mse(), - ) - - assert loss.shape == (1,) - assert torch.isclose(loss, expected_scale.square()) - - -def test_evaluate_gradients(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_values): - - expected_energies, expected_gradients, _ = mock_hcl_mm_values - expected_scale = torch.rand(1) - - entry = EnergyEntry( - mock_hcl_system, - mock_hcl_conformers, - # Set a reference energy to make sure gradient contributions don't - # bleed between loss functions - reference_energies=expected_energies, - reference_gradients=expected_gradients + torch.ones_like(expected_gradients), - ) - - loss = entry.evaluate_loss( - SMIRNOFFModel([], None), - gradient_transforms=lambda x: expected_scale * x, - gradient_metric=metrics.mse(()), - ) - - assert loss.shape == 
(1,) - assert torch.isclose(loss, expected_scale.square()) - - -def test_evaluate_hessians(mock_hcl_conformers, mock_hcl_system, mock_hcl_mm_values): - - expected_energies, expected_gradients, expected_hessians = mock_hcl_mm_values - expected_scale = torch.rand(1) - - entry = EnergyEntry( - mock_hcl_system, - mock_hcl_conformers, - # Set a reference energy to make sure gradient contributions don't - # bleed between loss functions - reference_energies=expected_energies, - reference_gradients=expected_gradients, - reference_hessians=expected_hessians + torch.ones_like(expected_hessians), - ) - - loss = entry.evaluate_loss( - SMIRNOFFModel([], None), - hessian_transforms=lambda x: expected_scale * x, - hessian_metric=metrics.mse(()), - ) - - assert loss.shape == (1,) - assert torch.isclose(loss, expected_scale.square()) - - -def test_evaluate_loss_contribution(): - - reference_tensor = torch.tensor([[1.0], [2.0]]) - computed_tensor = torch.tensor([[4.0], [8.0]]) - - loss = EnergyEntry._evaluate_loss_contribution( - reference_tensor, computed_tensor, transforms.relative(), metrics.mse() - ) - - assert torch.isclose(loss, torch.tensor((4.0 - 1.0) ** 2 * 0.5)) - - -def test_from_grouped_results(mock_hcl_conformers, mock_hcl_mm_values): - - mock_energies, mock_gradients, mock_hessians = mock_hcl_mm_values - - created_term = EnergyDataset._from_grouped_results( - ( - "[Cl:1][Cl:2]", - mock_hcl_conformers, - mock_energies, - mock_gradients, - mock_hessians, - ), - ForceField("openff_unconstrained-1.0.0.offxml"), - ) - - assert created_term._model_input is not None - - assert torch.allclose(created_term._conformers, mock_hcl_conformers) - assert torch.allclose(created_term._reference_energies, mock_energies) - - assert torch.allclose(created_term._reference_gradients, mock_gradients) - assert torch.allclose(created_term._reference_hessians, mock_hessians) - - -@pytest.mark.parametrize("include_energies", [True, False]) -@pytest.mark.parametrize("include_gradients", [True, False]) -@pytest.mark.parametrize("include_hessians", [True, False]) -def test_from_optimization_results( - monkeypatch, include_energies, include_gradients, include_hessians -): - - from simtk import unit as simtk_unit - - if not include_energies and not include_gradients and not include_hessians: - pytest.skip("unsupported combination") - - molecules = [] - - for smiles in ["C", "CC"]: - - molecule: Molecule = Molecule.from_smiles(smiles) - molecule.generate_conformers(n_conformers=1) - - for offset in [1.0, 2.0]: - shifted_molecule = copy.deepcopy(molecule) - shifted_molecule.conformers[0] += offset * simtk_unit.angstrom - - molecules.append(shifted_molecule) - - optimization_collection = mock_optimization_result_collection( - molecules, monkeypatch - ) - - energy_dataset = EnergyDataset.from_optimization_results( - optimization_collection, - initial_force_field=ForceField(), - include_energies=include_energies, - include_gradients=include_gradients, - gradient_coordinate_system="cartesian", - include_hessians=include_hessians, - hessian_coordinate_system="cartesian", - ) - - assert len(energy_dataset) == 2 - - for energy_entry, n_atoms in zip(energy_dataset, [5, 8]): - - if not include_energies: - assert energy_entry._reference_energies is None - else: - - assert energy_entry._reference_energies is not None - assert energy_entry._reference_energies.shape == (2, 1) - - assert not torch.allclose( - energy_entry._reference_energies, - torch.zeros_like(energy_entry._reference_energies), - ) - - if not include_gradients: - assert 
energy_entry._reference_gradients is None - else: - - assert energy_entry._reference_gradients is not None - assert energy_entry._reference_gradients.shape == (2, n_atoms, 3) - - assert not torch.allclose( - energy_entry._reference_gradients, - torch.zeros_like(energy_entry._reference_gradients), - ) - - if not include_hessians: - assert energy_entry._reference_hessians is None - else: - - assert energy_entry._reference_hessians is not None - assert energy_entry._reference_hessians.shape == ( - 2, - n_atoms * 3, - n_atoms * 3, - ) - - assert not torch.allclose( - energy_entry._reference_hessians, - torch.zeros_like(energy_entry._reference_hessians), - ) diff --git a/descent/tests/geometric.py b/descent/tests/geometric.py deleted file mode 100644 index c88661c..0000000 --- a/descent/tests/geometric.py +++ /dev/null @@ -1,233 +0,0 @@ -from collections import defaultdict -from typing import Dict, Tuple - -import torch -from openff.toolkit.topology import Molecule - - -def _geometric_internal_coordinate_to_indices(internal_coordinate): - """A utility method for converting a ``geometric`` internal coordinate into - a tuple of atom indices. - - Args: - internal_coordinate: The internal coordinate to convert. - - Returns: - A tuple of the relevant atom indices. - """ - - from geometric.internal import Angle, Dihedral, Distance, OutOfPlane - - if isinstance(internal_coordinate, Distance): - indices = (internal_coordinate.a, internal_coordinate.b) - elif isinstance(internal_coordinate, Angle): - indices = (internal_coordinate.a, internal_coordinate.b, internal_coordinate.c) - elif isinstance(internal_coordinate, (Dihedral, OutOfPlane)): - indices = ( - internal_coordinate.a, - internal_coordinate.b, - internal_coordinate.c, - internal_coordinate.d, - ) - else: - raise NotImplementedError() - - if indices[-1] > indices[0]: - indices = tuple(reversed(indices)) - - return indices - - -def geometric_hessian( - molecule: Molecule, - conformer: torch.Tensor, - internal_coordinates_indices: Dict[str, torch.Tensor], -) -> torch.Tensor: - """A helper method to project a set of gradients and hessians into internal - coordinates using ``geomTRIC``. - - Args: - molecule: The molecule of interest - conformer: The conformer of the molecule with units of [A] and shape=(n_atoms, 3) - internal_coordinates_indices: The indices of the atoms involved in each type - of internal coordinate. - - Returns: - The projected gradients and hessians. - """ - - from geometric.internal import Angle, Dihedral, Distance, OutOfPlane - from geometric.internal import PrimitiveInternalCoordinates as GeometricPRIC - from geometric.internal import ( - RotationA, - RotationB, - RotationC, - TranslationX, - TranslationY, - TranslationZ, - ) - from geometric.molecule import Molecule as GeometricMolecule - - geometric_molecule = GeometricMolecule() - geometric_molecule.Data = { - "resname": ["UNK"] * molecule.n_atoms, - "resid": [0] * molecule.n_atoms, - "elem": [atom.element.symbol for atom in molecule.atoms], - "bonds": [(bond.atom1_index, bond.atom2_index) for bond in molecule.bonds], - "name": molecule.name, - "xyzs": [conformer.detach().numpy()], - } - - geometric_coordinates = GeometricPRIC(geometric_molecule) - - geometric_coordinates.Internals = [ - internal - for internal in geometric_coordinates.Internals - if not isinstance( - internal, - (TranslationX, TranslationY, TranslationZ, RotationA, RotationB, RotationC), - ) - ] - - # We need to re-order the internal coordinates to generate those produced by - # smirnoffee. 
- ic_by_type = defaultdict(list) - - ic_type_to_name = { - Distance: "distances", - Angle: "angles", - Dihedral: "dihedrals", - OutOfPlane: "out-of-plane-angles", - } - - for internal_coordinate in geometric_coordinates.Internals: - - ic_by_type[ic_type_to_name[internal_coordinate.__class__]].append( - internal_coordinate - ) - - ordered_internals = [] - - for ic_type in internal_coordinates_indices: - - ic_by_index = { - _geometric_internal_coordinate_to_indices(ic): ic - for ic in ic_by_type[ic_type] - } - - for ic_indices in internal_coordinates_indices[ic_type]: - - ic_indices = tuple(int(i) for i in ic_indices) - - if ic_indices[-1] > ic_indices[0]: - ic_indices = tuple(reversed(ic_indices)) - - ordered_internals.append(ic_by_index[ic_indices]) - - geometric_coordinates.Internals = ordered_internals - - geometric_coordinates.derivatives(conformer.detach().numpy()) - return geometric_coordinates.second_derivatives(conformer.detach().numpy()) - - -def geometric_project_derivatives( - molecule: Molecule, - conformer: torch.Tensor, - internal_coordinates_indices: Dict[str, torch.Tensor], - reference_gradients: torch.Tensor, - reference_hessians: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - """A helper method to project a set of gradients and hessians into internal - coordinates using ``geomTRIC``. - - Args: - molecule: The molecule of interest - conformer: The conformer of the molecule with units of [A] and shape=(n_atoms, 3) - internal_coordinates_indices: The indices of the atoms involved in each type - of internal coordinate. - reference_gradients: The gradients to project. - reference_hessians: The hessians to project. - - Returns: - The projected gradients and hessians. - """ - - from geometric.internal import Angle, Dihedral, Distance, OutOfPlane - from geometric.internal import PrimitiveInternalCoordinates as GeometricPRIC - from geometric.internal import ( - RotationA, - RotationB, - RotationC, - TranslationX, - TranslationY, - TranslationZ, - ) - from geometric.molecule import Molecule as GeometricMolecule - - geometric_molecule = GeometricMolecule() - geometric_molecule.Data = { - "resname": ["UNK"] * molecule.n_atoms, - "resid": [0] * molecule.n_atoms, - "elem": [atom.element.symbol for atom in molecule.atoms], - "bonds": [(bond.atom1_index, bond.atom2_index) for bond in molecule.bonds], - "name": molecule.name, - "xyzs": [conformer.detach().numpy()], - } - - geometric_coordinates = GeometricPRIC(geometric_molecule) - - geometric_coordinates.Internals = [ - internal - for internal in geometric_coordinates.Internals - if not isinstance( - internal, - (TranslationX, TranslationY, TranslationZ, RotationA, RotationB, RotationC), - ) - ] - - # We need to re-order the internal coordinates to generate those produced by - # smirnoffee. 
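-    # For example, an angle stored by geomeTRIC under atom indices (2, 0, 4) is
-    # keyed as (4, 0, 2) below, since index tuples are reversed whenever the last
-    # index exceeds the first (an illustrative note; see
-    # _geometric_internal_coordinate_to_indices above).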
- ic_by_type = defaultdict(list) - - ic_type_to_name = { - Distance: "distances", - Angle: "angles", - Dihedral: "dihedrals", - OutOfPlane: "out-of-plane-angles", - } - - for internal_coordinate in geometric_coordinates.Internals: - - ic_by_type[ic_type_to_name[internal_coordinate.__class__]].append( - internal_coordinate - ) - - ordered_internals = [] - - for ic_type in internal_coordinates_indices: - - ic_by_index = { - _geometric_internal_coordinate_to_indices(ic): ic - for ic in ic_by_type[ic_type] - } - - for ic_indices in internal_coordinates_indices[ic_type]: - - ic_indices = tuple(int(i) for i in ic_indices) - - if ic_indices[-1] > ic_indices[0]: - ic_indices = tuple(reversed(ic_indices)) - - ordered_internals.append(ic_by_index[ic_indices]) - - geometric_coordinates.Internals = ordered_internals - - reference_gradients = reference_gradients.numpy().flatten() - reference_hessians = reference_hessians.numpy().reshape(molecule.n_atoms * 3, -1) - - xyz = conformer.detach().numpy() - - return ( - geometric_coordinates.calcGrad(xyz, reference_gradients), - geometric_coordinates.calcHess(xyz, reference_gradients, reference_hessians), - ) diff --git a/descent/tests/mocking/__init__.py b/descent/tests/mocking/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/descent/tests/mocking/qcdata.py b/descent/tests/mocking/qcdata.py deleted file mode 100644 index 54f650a..0000000 --- a/descent/tests/mocking/qcdata.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import List, Union - -import numpy -from openff.qcsubmit.results import ( - BasicResult, - BasicResultCollection, - OptimizationResult, - OptimizationResultCollection, -) -from openff.toolkit.topology import Molecule -from pydantic import BaseModel -from qcelemental.models import DriverEnum -from qcportal.models import ObjectId, OptimizationRecord, QCSpecification -from qcportal.models.records import RecordStatusEnum, ResultRecord - -DEFAULT_SERVER_ADDRESS = "http://localhost:1234/" - - -class _FractalClient(BaseModel): - - address: str - - -def mock_basic_result_collection( - molecules: List[Molecule], - drivers: Union[DriverEnum, List[DriverEnum]], - monkeypatch, -) -> BasicResultCollection: - - if not isinstance(drivers, list): - drivers = [drivers] - - collection = BasicResultCollection( - entries={ - DEFAULT_SERVER_ADDRESS: [ - BasicResult( - record_id=ObjectId(str(i + 1)), - cmiles=molecule.to_smiles(mapped=True), - inchi_key=molecule.to_inchikey(), - ) - for i, molecule in enumerate(molecules) - ] - } - ) - - def mock_return_result(molecule, driver): - - if driver == DriverEnum.gradient: - return numpy.random.random((molecule.n_atoms, 3)) - elif driver == DriverEnum.hessian: - return numpy.random.random((molecule.n_atoms * 3, molecule.n_atoms * 3)) - - raise NotImplementedError() - - monkeypatch.setattr( - BasicResultCollection, - "to_records", - lambda self: [ - ( - ResultRecord( - id=entry.record_id, - program="psi4", - driver=driver, - method="scf", - basis="sto-3g", - molecule=entry.record_id, - status=RecordStatusEnum.complete, - client=_FractalClient(address=address), - return_result=mock_return_result( - molecules[int(entry.record_id) - 1], driver - ), - ), - molecules[int(entry.record_id) - 1], - ) - for address, entries in self.entries.items() - for entry in entries - for driver in drivers - ], - ) - - return collection - - -def mock_optimization_result_collection( - molecules: List[Molecule], monkeypatch -) -> OptimizationResultCollection: - - collection = OptimizationResultCollection( - entries={ - 
DEFAULT_SERVER_ADDRESS: [ - OptimizationResult( - record_id=ObjectId(str(i + 1)), - cmiles=molecule.to_smiles(mapped=True), - inchi_key=molecule.to_inchikey(), - ) - for i, molecule in enumerate(molecules) - ] - } - ) - - monkeypatch.setattr( - OptimizationResultCollection, - "to_records", - lambda self: [ - ( - OptimizationRecord( - id=entry.record_id, - program="psi4", - qc_spec=QCSpecification( - driver=DriverEnum.gradient, - method="scf", - basis="sto-3g", - program="psi4", - ), - initial_molecule=ObjectId(entry.record_id), - final_molecule=ObjectId(entry.record_id), - status=RecordStatusEnum.complete, - energies=[numpy.random.random()], - client=_FractalClient(address=address), - ), - molecules[int(entry.record_id) - 1], - ) - for address, entries in self.entries.items() - for entry in entries - ], - ) - - monkeypatch.setattr( - OptimizationResultCollection, - "to_basic_result_collection", - lambda self, driver: mock_basic_result_collection( - molecules, driver, monkeypatch - ), - ) - - return collection diff --git a/descent/tests/mocking/systems.py b/descent/tests/mocking/systems.py deleted file mode 100644 index 3bacead..0000000 --- a/descent/tests/mocking/systems.py +++ /dev/null @@ -1,45 +0,0 @@ -from openff.interchange.components.interchange import Interchange -from openff.interchange.components.mdtraj import OFFBioTop -from openff.interchange.components.potentials import Potential -from openff.interchange.components.smirnoff import SMIRNOFFBondHandler -from openff.interchange.models import PotentialKey, TopologyKey -from openff.toolkit.topology import Molecule -from openff.units import unit - - -def generate_mock_hcl_system(bond_k=None, bond_length=None): - """Creates an interchange object for HCl that contains a single bond parameter with - l=1 A and k = 2 kJ / mol by default - """ - - system = Interchange() - - system.topology = OFFBioTop() - system.topology.copy_initializer( - Molecule.from_mapped_smiles("[H:1][Cl:2]").to_topology() - ) - - bond_k = ( - bond_k - if bond_k is not None - else 2.0 * unit.kilojoule / unit.mole / unit.angstrom ** 2 - ) - bond_length = bond_length if bond_length is not None else 1.0 * unit.angstrom - - system.add_handler( - "Bonds", - SMIRNOFFBondHandler( - slot_map={ - TopologyKey(atom_indices=(0, 1)): PotentialKey( - id="[#1:1]-[#17:2]", associated_handler="Bonds" - ) - }, - potentials={ - PotentialKey( - id="[#1:1]-[#17:2]", associated_handler="Bonds" - ): Potential(parameters={"k": bond_k, "length": bond_length}) - }, - ), - ) - - return system diff --git a/descent/tests/models/__init__.py b/descent/tests/models/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/descent/tests/models/test_smirnoff.py b/descent/tests/models/test_smirnoff.py deleted file mode 100644 index 51f871f..0000000 --- a/descent/tests/models/test_smirnoff.py +++ /dev/null @@ -1,221 +0,0 @@ -import pytest -import torch -from openff.interchange.components.interchange import Interchange -from openff.interchange.models import PotentialKey -from openff.toolkit.topology import Molecule -from openff.toolkit.typing.engines.smirnoff import ForceField -from smirnoffee.smirnoff import vectorize_system - -from descent.models import ParameterizationModel -from descent.models.smirnoff import SMIRNOFFModel -from descent.tests import is_close - - -@pytest.fixture() -def mock_force_field() -> ForceField: - - from simtk import unit as simtk_unit - - force_field = ForceField() - - parameters = { - "Bonds": [ - { - "smirks": "[#1:1]-[#9:2]", - "k": 1.0 * 
simtk_unit.kilojoules_per_mole / simtk_unit.angstrom ** 2, - "length": 2.0 * simtk_unit.angstrom, - }, - { - "smirks": "[#1:1]-[#17:2]", - "k": 3.0 * simtk_unit.kilojoules_per_mole / simtk_unit.angstrom ** 2, - "length": 4.0 * simtk_unit.angstrom, - }, - { - "smirks": "[#1:1]-[#8:2]", - "k": 5.0 * simtk_unit.kilojoules_per_mole / simtk_unit.angstrom ** 2, - "length": 6.0 * simtk_unit.angstrom, - }, - ], - "Angles": [ - { - "smirks": "[#1:1]-[#8:2]-[#1:3]", - "k": 1.0 * simtk_unit.kilojoules_per_mole / simtk_unit.degrees ** 2, - "angle": 2.0 * simtk_unit.degrees, - } - ], - } - - for handler_type, parameter_dicts in parameters.items(): - - handler = force_field.get_parameter_handler(handler_type) - - for parameter_dict in parameter_dicts: - handler.add_parameter(parameter_dict) - - return force_field - - -def test_model_matches_protocol(): - assert issubclass(SMIRNOFFModel, ParameterizationModel) - - -def test_model_init(mock_force_field): - - expected_parameter_ids = [ - ("Bonds", "[#1:1]-[#17:2]", "length"), - ("Angles", "[#1:1]-[#8:2]-[#1:3]", "angle"), - ("Bonds", "[#1:1]-[#9:2]", "k"), - ] - - model = SMIRNOFFModel(expected_parameter_ids, initial_force_field=mock_force_field) - - assert model._initial_force_field == mock_force_field - - assert {*model._parameter_delta_ids} == {"Bonds", "Angles"} - assert model._parameter_delta_ids["Bonds"] == [ - (PotentialKey(id="[#1:1]-[#17:2]", associated_handler="Bonds"), "length"), - (PotentialKey(id="[#1:1]-[#9:2]", associated_handler="Bonds"), "k"), - ] - assert model._parameter_delta_ids["Angles"] == [ - (PotentialKey(id="[#1:1]-[#8:2]-[#1:3]", associated_handler="Angles"), "angle") - ] - - assert torch.allclose(model._parameter_delta_indices["Bonds"], torch.tensor([0, 1])) - assert torch.allclose(model._parameter_delta_indices["Angles"], torch.tensor([2])) - - assert model.parameter_delta.shape == (3,) - - -def test_model_parameter_delta_ids(): - - input_parameter_ids = [ - ("Bonds", "[#1:1]-[#17:2]", "length"), - ("Angles", "[#1:1]-[#8:2]-[#1:3]", "angle"), - ("Bonds", "[#1:1]-[#9:2]", "k"), - ] - expected_parameter_ids = tuple( - ( - handler_type, - PotentialKey(id=smirks, associated_handler=handler_type), - attribute, - ) - for handler_type, smirks, attribute in [ - ("Bonds", "[#1:1]-[#17:2]", "length"), - ("Bonds", "[#1:1]-[#9:2]", "k"), - ("Angles", "[#1:1]-[#8:2]-[#1:3]", "angle"), - ] - ) - - model = SMIRNOFFModel(input_parameter_ids, None) - assert model.parameter_delta_ids == expected_parameter_ids - - -def test_model_forward_empty_input(): - - model = SMIRNOFFModel([("Bonds", "[#1:1]-[#17:2]", "length")], None) - assert model.forward({}) == {} - - -@pytest.mark.parametrize("covariance_tensor", [None, torch.tensor([[0.1]])]) -def test_model_forward(mock_force_field, covariance_tensor): - - molecule = Molecule.from_smiles("[H]Cl") - system = Interchange.from_smirnoff(mock_force_field, molecule.to_topology()) - - model = SMIRNOFFModel( - [("Bonds", "[#1:1]-[#17:2]", "length")], - None, - covariance_tensor=covariance_tensor, - ) - model.parameter_delta = torch.nn.Parameter( - model.parameter_delta + torch.tensor([1.0]), requires_grad=True - ) - - input_system = vectorize_system(system) - output_system = model.forward(input_system) - - assert torch.isclose( - output_system[("Bonds", "k/2*(r-length)**2")][1][0, 1], - torch.tensor( - 4.0 + 1.0 * (1.0 if covariance_tensor is None else covariance_tensor) - ), - ) - - -def test_model_forward_fixed_handler(mock_force_field): - """Test that forward works for the case where a system contains a 
handler that - contains no parameters being trained.""" - - molecule = Molecule.from_smiles("O") - system = Interchange.from_smirnoff(mock_force_field, molecule.to_topology()) - - model = SMIRNOFFModel([("Bonds", "[#1:1]-[#17:2]", "length")], None) - - input_system = vectorize_system(system) - output_system = model.forward(input_system) - - assert ("Angles", "k/2*(theta-angle)**2") in output_system - - assert len(output_system[("Angles", "k/2*(theta-angle)**2")][0]) == 1 - assert len(output_system[("Angles", "k/2*(theta-angle)**2")][1]) == 1 - assert len(output_system[("Angles", "k/2*(theta-angle)**2")][2]) == 1 - - -@pytest.mark.parametrize("covariance_tensor", [None, torch.tensor([[0.1]])]) -def test_model_to_force_field(mock_force_field, covariance_tensor): - """Test that forward works for the case where a system contains a handler that - contains no parameters being trained.""" - - from simtk import unit as simtk_unit - - model = SMIRNOFFModel( - [("Bonds", "[#1:1]-[#17:2]", "length")], - mock_force_field, - covariance_tensor=covariance_tensor, - ) - model.parameter_delta = torch.nn.Parameter( - model.parameter_delta + torch.tensor([1.0]), requires_grad=True - ) - - output_force_field = model.to_force_field() - - assert is_close( - output_force_field["Bonds"].parameters["[#1:1]-[#17:2]"].length, - (4.0 + 1.0 * (1.0 if covariance_tensor is None else float(covariance_tensor))) - * simtk_unit.angstrom, - ) - - -@pytest.mark.parametrize("parameter_id_type", ["id", "smirks"]) -def test_model_summarise(mock_force_field, parameter_id_type): - - mock_force_field["Bonds"].parameters["[#1:1]-[#17:2]"].id = "b1" - - model = SMIRNOFFModel( - [ - ("Bonds", "[#1:1]-[#17:2]", "length"), - ("Bonds", "[#1:1]-[#9:2]", "length"), - ("Bonds", "[#1:1]-[#17:2]", "k"), - ("Angles", "[#1:1]-[#8:2]-[#1:3]", "angle"), - ("Angles", "[#1:1]-[#8:2]-[#1:3]", "k"), - ], - mock_force_field, - ) - model.parameter_delta = torch.nn.Parameter(torch.tensor([1.00001] * 5)) - return_value = model.summarise(parameter_id_type=parameter_id_type) - - assert return_value is not None - assert len(return_value) > 0 - - assert "Bonds" in return_value - assert "Angles" in return_value - - assert " k (kJ/mol/Ų) length (Å) " in return_value - assert " angle (deg) k (kJ/deg²/mol)" in return_value - - if parameter_id_type == "smirks": - assert "[#1:1]-[#17:2]" in return_value - assert "b1" not in return_value - else: - assert "[#1:1]-[#17:2]" not in return_value - assert "b1" in return_value diff --git a/descent/tests/test_metrics.py b/descent/tests/test_metrics.py deleted file mode 100644 index 879e34f..0000000 --- a/descent/tests/test_metrics.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -import torch - -from descent import metrics - - -@pytest.mark.parametrize( - "dim, expected", - [ - (0, torch.tensor([13.0 / 2.0, 5.0 / 2.0])), - (1, torch.tensor([10.0 / 2.0, 8.0 / 2.0])), - ((), torch.tensor(18.0 / 4.0)), - ], -) -def test_mse(dim, expected): - - # 3 1 - # 2 2 - - input_a = torch.tensor([[1.0, 2.0], [5.0, 6.0]]) - input_b = torch.tensor([[4.0, 3.0], [7.0, 8.0]]) - - output = metrics.mse(dim=dim)(input_a, input_b) - - assert output.shape == expected.shape - assert torch.allclose(output, expected) diff --git a/descent/tests/test_transforms.py b/descent/tests/test_transforms.py deleted file mode 100644 index b704f08..0000000 --- a/descent/tests/test_transforms.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -import torch - -from descent import transforms -from descent.transforms import transform_tensor - - -def test_identity(): - 
- value = torch.rand(4) - output = transforms.identity()(value) - - assert output.shape == value.shape - assert torch.allclose(output, value) - - -@pytest.mark.parametrize( - "index, expected", - [(0, torch.tensor([0.0, 1.0, 2.0])), (1, torch.tensor([-1.0, 0.0, 1.0]))], -) -def test_relative(index, expected): - - value = torch.tensor([1.0, 2.0, 3.0]) - output = transforms.relative(index=index)(value) - - assert output.shape == expected.shape - assert torch.allclose(output, expected) - - -@pytest.mark.parametrize( - "transforms_to_apply, expected", - [ - (transforms.relative(0), torch.tensor([0.0, 1.0, 2.0])), - ([], torch.tensor([1.0, 2.0, 3.0])), - ( - [transforms.relative(0), transforms.relative(1)], - torch.tensor([-1.0, 0.0, 1.0]), - ), - ], -) -def test_transform_tensor(transforms_to_apply, expected): - - value = torch.tensor([1.0, 2.0, 3.0]) - output = transform_tensor(value, transforms_to_apply) - - assert output.shape == expected.shape - assert torch.allclose(output, expected) diff --git a/descent/tests/utilities/__init__.py b/descent/tests/utilities/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/descent/tests/utilities/test_smirnoff.py b/descent/tests/utilities/test_smirnoff.py deleted file mode 100644 index 5a3e1a3..0000000 --- a/descent/tests/utilities/test_smirnoff.py +++ /dev/null @@ -1,234 +0,0 @@ -import pytest -import torch -from openff.interchange.models import PotentialKey -from openff.toolkit.typing.engines.smirnoff import ForceField -from simtk import unit as simtk_unit - -from descent.data import DatasetEntry -from descent.tests import is_close -from descent.utilities.smirnoff import exercised_parameters, perturb_force_field - - -def test_perturb_force_field(): - - smirks = "[*:1]~[*:2]~[*:3]~[*:4]" - - initial_force_field = ForceField() - initial_force_field.get_parameter_handler("ProperTorsions") - - initial_force_field["ProperTorsions"].add_parameter( - { - "smirks": smirks, - "k": [0.1, 0.2] * simtk_unit.kilocalories_per_mole, - "phase": [0.0, 180.0] * simtk_unit.degree, - "periodicity": [1, 2], - "idivf": [2.0, 1.0], - } - ) - - perturbed_force_field = perturb_force_field( - initial_force_field, - torch.tensor([0.1, 0.2]), - [ - ("ProperTorsions", PotentialKey(id=smirks, mult=0), "k"), - ("ProperTorsions", PotentialKey(id=smirks, mult=1), "idivf"), - ], - ) - - assert is_close( - perturbed_force_field["ProperTorsions"].parameters[smirks].k1, - 0.1 * simtk_unit.kilocalories_per_mole + 0.1 * simtk_unit.kilojoules_per_mole, - ) - assert is_close( - perturbed_force_field["ProperTorsions"].parameters[smirks].idivf2, 1.2 - ) - - -@pytest.mark.parametrize( - "handlers_to_include," - "handlers_to_exclude," - "ids_to_include," - "ids_to_exclude," - "attributes_to_include," - "attributes_to_exclude," - "n_expected," - "expected_handlers," - "expected_potential_keys," - "expected_attributes", - [ - ( - None, - None, - None, - None, - None, - None, - 36, - {"Bonds", "Angles"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler=handler) - for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - for handler in ("Bonds", "Angles") - }, - {"k", "length", "angle"}, - ), - ( - ["Bonds"], - None, - None, - None, - None, - None, - 18, - {"Bonds"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler="Bonds") - for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - }, - {"k", "length"}, - ), - ( - None, - ["Bonds"], - None, - None, - None, - None, - 18, - {"Angles"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler="Angles") - 
for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - }, - {"k", "angle"}, - ), - ( - None, - None, - [PotentialKey(id="b", mult=0, associated_handler="Bonds")], - None, - None, - None, - 2, - {"Bonds"}, - {PotentialKey(id="b", mult=0, associated_handler="Bonds")}, - {"k", "length"}, - ), - ( - None, - None, - None, - [ - PotentialKey(id="b", mult=0, associated_handler="Bonds"), - ], - None, - None, - 34, - {"Bonds", "Angles"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler=handler) - for handler in ("Bonds", "Angles") - for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - if (smirks != "b" or mult != 0 or handler != "Bonds") - }, - {"k", "length", "angle"}, - ), - ( - None, - None, - None, - None, - ["length"], - None, - 9, - {"Bonds"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler="Bonds") - for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - }, - {"length"}, - ), - ( - None, - None, - None, - None, - None, - ["length"], - 27, - {"Bonds", "Angles"}, - { - PotentialKey(id=smirks, mult=mult, associated_handler=handler) - for handler in ("Bonds", "Angles") - for smirks in ("a", "b", "c") - for mult in (None, 0, 1) - }, - {"k", "angle"}, - ), - ], -) -def test_exercised_parameters( - handlers_to_include, - handlers_to_exclude, - ids_to_include, - ids_to_exclude, - attributes_to_include, - attributes_to_exclude, - n_expected, - expected_handlers, - expected_potential_keys, - expected_attributes, -): - class MockEntry(DatasetEntry): - def evaluate_loss(self, model, **kwargs): - pass - - def mock_entry(handler, patterns, mult): - - attributes = {"Bonds": ["k", "length"], "Angles": ["k", "angle"]}[handler] - - entry = MockEntry.__new__(MockEntry) - entry._model_input = { - (handler, ""): ( - None, - None, - [ - ( - PotentialKey(id=smirks, mult=mult, associated_handler=handler), - attributes, - ) - for smirks in patterns - ], - ) - } - return entry - - entries = [ - mock_entry(handler, patterns, mult) - for handler in ["Bonds", "Angles"] - for patterns in [("a", "b"), ("b", "c")] - for mult in [None, 0, 1] - ] - - parameter_keys = exercised_parameters( - entries, - handlers_to_include, - handlers_to_exclude, - ids_to_include, - ids_to_exclude, - attributes_to_include, - attributes_to_exclude, - ) - - assert len(parameter_keys) == n_expected - - actual_handlers, actual_keys, actual_attributes = zip(*parameter_keys) - - assert {*actual_handlers} == expected_handlers - assert {*actual_keys} == expected_potential_keys - assert {*actual_attributes} == expected_attributes diff --git a/descent/tests/utilities/test_utilities.py b/descent/tests/utilities/test_utilities.py deleted file mode 100644 index 2c19229..0000000 --- a/descent/tests/utilities/test_utilities.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest - -from descent.utilities import value_or_list_to_list - - -@pytest.mark.parametrize( - "function_input, expected_output", - [(None, None), (2, [2]), ("a", ["a"]), ([1, 2], [1, 2]), (["a", "b"], ["a", "b"])], -) -def test_value_or_list_to_list(function_input, expected_output): - - actual_output = value_or_list_to_list(function_input) - assert expected_output == actual_output diff --git a/descent/transforms.py b/descent/transforms.py deleted file mode 100644 index a87c235..0000000 --- a/descent/transforms.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Common and composable tensor transformations useful when computing loss metrics.""" -from typing import Callable, Iterable, List, Union - -import torch - -LossTransform = Callable[[torch.Tensor], torch.Tensor] - - 
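-# Any callable matching this signature can be used as a transform. For example, a
-# (hypothetical) scaling transform could be written as:
-#
-#     def scale(factor: float) -> LossTransform:
-#         def _scale(input_tensor: torch.Tensor) -> torch.Tensor:
-#             return factor * input_tensor
-#
-#         return _scale
-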
-def identity() -> LossTransform: - def _identity(input_tensor: torch.Tensor): - return input_tensor - - return _identity - - -def relative(index: int = 0) -> LossTransform: - def _relative(input_tensor: torch.Tensor): - return input_tensor - input_tensor[index] - - return _relative - - -def transform_tensor( - input_tensor: torch.Tensor, transforms: Union[LossTransform, List[LossTransform]] -) -> torch.Tensor: - """Applies a set of transforms to an input tensor. - - Args: - input_tensor: The tensor to transform. - transforms: The transform, or list of transforms, to apply. An empty list will leave the input tensor unchanged. - - Returns: - The transformed tensor. - """ - - if not isinstance(transforms, Iterable): - transforms = [transforms] - - for transform in transforms: - input_tensor = transform(input_tensor) - - return input_tensor diff --git a/descent/utilities/__init__.py b/descent/utilities/__init__.py deleted file mode 100644 index 0db1acf..0000000 --- a/descent/utilities/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from descent.utilities.utilities import value_or_list_to_list - -__all__ = ["value_or_list_to_list"] diff --git a/descent/utilities/smirnoff.py b/descent/utilities/smirnoff.py deleted file mode 100644 index bfe0b34..0000000 --- a/descent/utilities/smirnoff.py +++ /dev/null @@ -1,131 +0,0 @@ -import copy -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union - -import torch -from openff.interchange.components.interchange import Interchange -from openff.interchange.models import PotentialKey -from openff.toolkit.typing.engines.smirnoff import ForceField -from openff.toolkit.utils import string_to_unit -from smirnoffee.smirnoff import _DEFAULT_UNITS, vectorize_system - -from descent.utilities import value_or_list_to_list - -if TYPE_CHECKING: - from descent.data import Dataset, DatasetEntry - - -def perturb_force_field( - force_field: ForceField, - parameter_delta: torch.Tensor, - parameter_delta_ids: List[Tuple[str, PotentialKey, str]], -) -> ForceField: - """Perturbs the specified parameters in a force field by the provided delta values. - - Args: - force_field: The force field to perturb. - parameter_delta: A 1D tensor of deltas to add to the parameters referenced by - the ``parameter_delta_ids``. - parameter_delta_ids: - The unique identifiers which map each value in the ``parameter_delta`` - tensor to a SMIRNOFF force field parameter. These should be of the form: - ``(handler_name, potential_key, attribute_name)``. - - Returns: - The perturbed force field.
- """ - from simtk import unit as simtk_unit - - force_field = copy.deepcopy(force_field) - - for (handler_name, potential_key, attribute), delta in zip( - parameter_delta_ids, parameter_delta - ): - - parameter = force_field[handler_name].parameters[potential_key.id] - - delta = delta.detach().item() * string_to_unit( - f"{_DEFAULT_UNITS[handler_name][attribute]:!s}" - ) - - if potential_key.mult is not None: - attribute = f"{attribute}{potential_key.mult + 1}" - - original_value = getattr(parameter, attribute) - - if not isinstance(original_value, simtk_unit.Quantity): - delta = delta.value_in_unit(simtk_unit.dimensionless) - - setattr(parameter, attribute, original_value + delta) - - return force_field - - -def exercised_parameters( - dataset: Union["Dataset", Iterable["DatasetEntry"], Iterable[Interchange]], - handlers_to_include: Optional[Union[str, List[str]]] = None, - handlers_to_exclude: Optional[Union[str, List[str]]] = None, - ids_to_include: Optional[Union[PotentialKey, List[PotentialKey]]] = None, - ids_to_exclude: Optional[Union[PotentialKey, List[PotentialKey]]] = None, - attributes_to_include: Optional[Union[str, List[str]]] = None, - attributes_to_exclude: Optional[Union[str, List[str]]] = None, -) -> List[Tuple[str, PotentialKey, str]]: - """Returns the identifiers of each parameter that has been assigned to each molecule - in a dataset. - - Notes: - This function assumes that the dataset was created using an OpenFF interchange - object as the main input. - - Args: - dataset: The dataset, list of dataset entries, or list of interchange objects - That track a set of SMIRNOFF parameters assigned to a set of molecules. - handlers_to_include: An optional list of the parameter handlers that the returned - parameters should be associated with. - handlers_to_exclude: An optional list of the parameter handlers that the returned - parameters should **not** be associated with. - ids_to_include: An optional list of the potential keys that the parameters should - match with to be returned. - ids_to_exclude: An optional list of the potential keys that the parameters should - **not** match with to be returned. - attributes_to_include: An optional list of the attributes that the parameters - should match with to be returned. - attributes_to_exclude: An optional list of the attributes that the parameters - should **not** match with to be returned. - - Returns: - A list of tuples of the form ``(handler_type, potential_key, attribute_name)``. 
- """ - - def should_skip(value, to_include, to_exclude) -> bool: - - to_include = value_or_list_to_list(to_include) - to_exclude = value_or_list_to_list(to_exclude) - - return (to_include is not None and value not in to_include) or ( - to_exclude is not None and value in to_exclude - ) - - vectorized_systems = [ - entry.model_input - if not isinstance(entry, Interchange) - else vectorize_system(entry) - for entry in dataset - ] - - return_value = { - (handler_type, potential_key, attribute) - for vectorized_system in vectorized_systems - for (handler_type, _), (*_, potential_keys) in vectorized_system.items() - if not should_skip(handler_type, handlers_to_include, handlers_to_exclude) - for (potential_key, attributes) in potential_keys - if not should_skip(potential_key, ids_to_include, ids_to_exclude) - for attribute in attributes - if not should_skip(attribute, attributes_to_include, attributes_to_exclude) - } - - return_value = sorted( - return_value, - key=lambda x: (x[0], x[1].id, x[1].mult if x[1].mult is not None else -1, x[2]), - ) - - return return_value diff --git a/descent/utilities/utilities.py b/descent/utilities/utilities.py deleted file mode 100644 index 7ff7431..0000000 --- a/descent/utilities/utilities.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import List, TypeVar, Union, overload - -T = TypeVar("T") - - -@overload -def value_or_list_to_list(value: Union[T, List[T]]) -> List[T]: - ... - - -@overload -def value_or_list_to_list(value: None) -> None: - ... - - -def value_or_list_to_list(value): - - if value is None: - return value - - return value if isinstance(value, list) else [value] diff --git a/devtools/conda-envs/meta.yaml b/devtools/conda-envs/meta.yaml deleted file mode 100644 index 25e1d85..0000000 --- a/devtools/conda-envs/meta.yaml +++ /dev/null @@ -1,41 +0,0 @@ -name: test - -channels: - - conda-forge - - defaults - - simonboothroyd - -dependencies: - - # Core dependencies - - python - - pip - - tqdm - - - openff-toolkit-base >=0.9.2 - - openff-interchange ==0.1.0 - - - openff-utilities - - openff-units - - - openff-qcsubmit >= 0.2.2 - - - smirnoffee >=0.0.1a2 - - pytorch >= 1.9.0 - - # Optional dependencies - - rdkit - - ambertools - - # Testing - - pytest - - pytest-cov - - codecov - - - pydantic - - geometric - - # Development - - isort - - black - - flake8 diff --git a/devtools/envs/base.yaml b/devtools/envs/base.yaml new file mode 100644 index 0000000..7bc8174 --- /dev/null +++ b/devtools/envs/base.yaml @@ -0,0 +1,43 @@ +name: descent + +channels: + - conda-forge + +dependencies: + + - python >=3.10 + - pip + + # Core packages + # - smee + + - pytorch + - pydantic + + # Optional packages + + ### Optimize + - scipy + + # Examples + - jupyter + - nbconvert + + # Dev / Testing + - ambertools + - rdkit + + - versioneer + + - pre-commit + - isort + - black + - flake8 + - flake8-pyproject + - nbqa + + - pytest + - pytest-cov + - pytest-mock + + - codecov diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..62985b1 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,3 @@ +# Examples + +This directory contains a number of examples of how to use `descent`. 
They currently include: diff --git a/examples/energy-and-gradient.ipynb b/examples/energy-and-gradient.ipynb deleted file mode 100644 index f9ace63..0000000 --- a/examples/energy-and-gradient.ipynb +++ /dev/null @@ -1,551 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "source": [ - "# Training Against QM Energies and Gradients\n", - "\n", - "This notebook aims to show how the [`descent`](https://github.com/SimonBoothroyd/descent) framework in combination with\n", - "[`smirnoffee`](https://github.com/SimonBoothroyd/smirnoffee) can be used to train a set of SMIRNOFF force field bond and\n", - "angle force constant parameters against the QM computed energies and associated gradients of a small molecule in\n", - "multiple conformers.\n", - "\n", - "For the sake of clarity all warnings will be disabled:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 1, - "outputs": [], - "source": [ - "import warnings\n", - "warnings.filterwarnings('ignore')\n", - "import logging\n", - "logging.getLogger(\"openff.toolkit\").setLevel(logging.ERROR)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "### Curating a QC training set\n", - "\n", - "For this example we will be training against QM energies which have been computed by and stored within the\n", - "[QCArchive](https://qcarchive.molssi.org/), which are easily retrieved using the [OpenFF QCSubmit](https://github.com/openforcefield/openff-qcsubmit)\n", - "package.\n", - "\n", - "We begin by importing the records associated with the `OpenFF Optimization Set 1` optimization data set:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 2, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Warning: importing 'simtk.openmm' is deprecated.
Import 'openmm' instead.\n" - ] - } - ], - "source": [ - "from qcportal import FractalClient\n", - "\n", - "from openff.qcsubmit.results import OptimizationResultCollection\n", - "\n", - "result_collection = OptimizationResultCollection.from_server(\n", - " client=FractalClient(),\n", - " datasets=\"OpenFF Optimization Set 1\"\n", - ")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "which we will then filter to retain a small molecule which will be fast to train on as a demonstration:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "source": [ - "from openff.qcsubmit.results.filters import ConformerRMSDFilter, SMILESFilter\n", - "\n", - "result_collection = result_collection.filter(\n", - " SMILESFilter(smiles_to_include=[\"CC(=O)NCC1=NC=CN1C\"]),\n", - " # Only retain conformers that differ by an RMSD of more than 0.5 Å.\n", - " ConformerRMSDFilter(max_conformers=10, rmsd_tolerance=0.5)\n", - ")\n", - "\n", - "print(f\"N Molecules: {result_collection.n_molecules}\")\n", - "print(f\"N Conformers: {result_collection.n_results}\")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - }, - "execution_count": 3, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "N Molecules: 1\n", - "N Conformers: 3\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "You should see that our filtered collection contains 3 results, which correspond to 3 minimized conformers (and\n", - "their associated energy computed using the OpenFF default B3LYP-D3BJ spec) for the molecule we filtered for above.\n", - "\n", - "In order to be able to train our parameters against this data we need to wrap it in a PyTorch dataset object. This\n", - "is made trivial thanks to the built-in ``EnergyDataset`` object that ships with the framework.
The energy dataset\n", - "will extract and store any energy, gradient, and hessian data in a format ready for evaluating a loss function.\n", - "\n", - "We first load in the initial force field parameters ($\\theta$) using the [OpenFF Toolkit](https://github.com/openforcefield/openff-toolkit):" - ], - "metadata": { - "collapsed": false - } - }, - { - "cell_type": "code", - "execution_count": 4, - "outputs": [], - "source": [ - "from openff.toolkit.typing.engines.smirnoff import ForceField\n", - "initial_force_field = ForceField(\"openff_unconstrained-1.0.0.offxml\")" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "markdown", - "source": [ - "which we can then use to construct our dataset:" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%% md\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 5, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Pulling main optimisation records: 100%|██████████| 3/3 [00:00<00:00, 183.13it/s]\n", - "Pulling gradient / hessian data: 100%|██████████| 3/3 [00:00<00:00, 3483.64it/s]\n", - "Building entries.: 0%| | 0/1 [00:00 torch.Tensor: - - model = SMIRNOFFModel(parameter_delta_ids, None) - - optimizer = torch.optim.Adam([model.parameter_delta], lr=learning_rate) - - for epoch in range(n_epochs): - - loss = loss_function(model, **loss_kwargs) - loss.backward() - - optimizer.step() - optimizer.zero_grad() - - if verbose and (epoch % 20 == 0 or epoch == n_epochs - 1): - print(f"Epoch {epoch}: loss={loss.item()}") - - return model.parameter_delta - - -@pytest.mark.parametrize( - "transform", [transforms.identity(), transforms.relative(index=0)] -) -@pytest.mark.parametrize("metric", [metrics.mse()]) -def test_energies_only(transform, metric): - - conformers = torch.tensor( - [[[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], [[-1.25, 0.0, 0.0], [1.25, 0.0, 0.0]]] - ) - - # Define the expected energies assuming that k=2.5 and l=2.0 - reference_energies = torch.tensor( - [[1.25 * (distance - 2.0) ** 2] for distance in [1.0, 2.5]] - ) - - starting_system = generate_mock_hcl_system( - bond_k=5.0 * unit.kilojoule / unit.mole / unit.angstrom ** 2, - bond_length=2.0 * unit.angstrom, - ) - loss_function = EnergyEntry( - starting_system, - conformers, - reference_energies=reference_energies, - ) - - actual_parameter_delta = train_parameters( - loss_function, - dict( - energy_metric=metric, - energy_transforms=transform, - ), - [("Bonds", "[#1:1]-[#17:2]", "k")], - learning_rate=0.1, - ) - expected_parameter_delta = torch.tensor([-2.5]) - - assert actual_parameter_delta.shape == expected_parameter_delta.shape - assert torch.allclose(actual_parameter_delta, expected_parameter_delta) - - print(f"EXPECTED={expected_parameter_delta} ACTUAL={actual_parameter_delta}") - - -@pytest.mark.parametrize("transform", [transforms.identity()]) -@pytest.mark.parametrize("metric", [metrics.mse()]) -@pytest.mark.parametrize("coordinate_system", ["cartesian", "ric"]) -def test_forces_only(transform, metric, coordinate_system): - - conformers = torch.tensor( - [[[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], [[-1.25, 0.0, 0.0], [1.25, 0.0, 0.0]]] - ) - - # Define the expected gradients assuming that k=2.5 and l=2.0 - reference_gradients = torch.tensor( - [ - [[-2.5 * (distance - 2.0), 0.0, 0.0], [2.5 * (distance - 2.0), 0.0, 0.0]] - for distance in [1.0, 2.5] - ] - ) - - starting_system = generate_mock_hcl_system( - bond_k=5.0 * unit.kilojoule / unit.mole / unit.angstrom ** 2, -
bond_length=2.0 * unit.angstrom, - ) - loss_function = EnergyEntry( - starting_system, - conformers, - reference_gradients=reference_gradients, - gradient_coordinate_system=coordinate_system, - ) - - actual_parameter_delta = train_parameters( - loss_function, - dict( - gradient_metric=metric, - gradient_transforms=transform, - ), - [("Bonds", "[#1:1]-[#17:2]", "k")], - learning_rate=0.1, - ) - expected_parameter_delta = torch.tensor([-2.5]) - - assert actual_parameter_delta.shape == expected_parameter_delta.shape - assert torch.allclose(actual_parameter_delta, expected_parameter_delta) - - print(f"EXPECTED={expected_parameter_delta} ACTUAL={actual_parameter_delta}") - - -@pytest.mark.parametrize( - "energy_transform", [transforms.identity(), transforms.relative(index=0)] -) -@pytest.mark.parametrize("energy_metric", [metrics.mse()]) -@pytest.mark.parametrize( - "gradient_transform", [transforms.identity(), transforms.relative(index=0)] -) -@pytest.mark.parametrize("gradient_metric", [metrics.mse()]) -@pytest.mark.parametrize("gradient_coordinate_system", ["cartesian", "ric"]) -def test_energies_and_forces( - energy_transform, - energy_metric, - gradient_transform, - gradient_metric, - gradient_coordinate_system, -): - - conformers = torch.tensor( - [[[-0.5, 0.0, 0.0], [0.5, 0.0, 0.0]], [[-1.25, 0.0, 0.0], [1.25, 0.0, 0.0]]] - ) - - # Define the expected gradients assuming that k=2.5 and l=2.0 - reference_energies = torch.tensor( - [[1.25 * (distance - 2.0) ** 2] for distance in [1.0, 2.5]] - ) - reference_gradients = torch.tensor( - [ - [[-2.5 * (distance - 2.0), 0.0, 0.0], [2.5 * (distance - 2.0), 0.0, 0.0]] - for distance in [1.0, 2.5] - ] - ) - - starting_system = generate_mock_hcl_system( - bond_k=5.0 * unit.kilojoule / unit.mole / unit.angstrom ** 2, - bond_length=2.0 * unit.angstrom, - ) - loss_function = EnergyEntry( - starting_system, - conformers, - reference_energies=reference_energies, - reference_gradients=reference_gradients, - gradient_coordinate_system=gradient_coordinate_system, - ) - - actual_parameter_delta = train_parameters( - loss_function, - dict( - energy_transforms=energy_transform, - energy_metric=energy_metric, - gradient_metric=gradient_metric, - gradient_transforms=gradient_transform, - ), - [("Bonds", "[#1:1]-[#17:2]", "k")], - learning_rate=0.1, - ) - expected_parameter_delta = torch.tensor([-2.5]) - - assert actual_parameter_delta.shape == expected_parameter_delta.shape - assert torch.allclose(actual_parameter_delta, expected_parameter_delta) - - print(f"EXPECTED={expected_parameter_delta} ACTUAL={actual_parameter_delta}") diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ec382cf --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,56 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel", "versioneer"] +build-backend = "setuptools.build_meta" + +[project] +name = "descent" +description = "Differentiably compute energies of molecules using SMIRNOFF force fields." 
+authors = [ {name = "Simon Boothroyd"} ] +license = { text = "MIT" } +dynamic = ["version"] +readme = "README.md" +requires-python = ">=3.10" +classifiers = ["Programming Language :: Python :: 3"] + +[tool.setuptools] +zip-safe = false +include-package-data = true + +[tool.setuptools.dynamic] +version = {attr = "descent.__version__"} + +[tool.setuptools.packages.find] +namespaces = true +where = ["."] + +[tool.versioneer] +VCS = "git" +style = "pep440" +versionfile_source = "descent/_version.py" +versionfile_build = "descent/_version.py" +tag_prefix = "" +parentdir_prefix = "descent-" + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" + +[tool.flake8] +max-line-length = 88 +ignore = ["E203", "E266", "E501", "W503"] +select = ["B","C","E","F","W","T4","B9"] + +[tool.coverage.run] +omit = ["**/tests/*", "**/_version.py"] + +[tool.coverage.report] +exclude_lines = [ + "@overload", + "pragma: no cover", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", + "if typing.TYPE_CHECKING:", +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 5518a42..0000000 --- a/setup.cfg +++ /dev/null @@ -1,48 +0,0 @@ -# Helper file to handle all configs - -[coverage:run] -# .coveragerc to control coverage.py and pytest-cov -omit = - # Omit the tests - */tests/* - # Omit generated versioneer - descent/_version.py - -[coverage:report] -exclude_lines = - @overload - pragma: no cover - raise NotImplementedError - if __name__ == .__main__.: - if TYPE_CHECKING: - -[flake8] -# Flake8, PyFlakes, etc -max-line-length = 88 -ignore = E203, E266, E501, W503 -select = B,C,E,F,W,T4,B9 - -[isort] -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True -line_length=88 -known_third_party= - geometric - openff - pydantic - smirnoffee - torch - tqdm - -[versioneer] -# Automatic version numbering scheme -VCS = git -style = pep440 -versionfile_source = descent/_version.py -versionfile_build = descent/_version.py -tag_prefix = '' - -[aliases] -test = pytest diff --git a/setup.py b/setup.py deleted file mode 100644 index 4c5196c..0000000 --- a/setup.py +++ /dev/null @@ -1,36 +0,0 @@ -""" -descent - -Optimize force field parameters against QC data using `pytorch` -""" -import sys - -from setuptools import setup, find_packages -import versioneer - -short_description = __doc__.split("\n") - -# from https://github.com/pytest-dev/pytest-runner#conditional-requirement -needs_pytest = {'pytest', 'test', 'ptr'}.intersection(sys.argv) -pytest_runner = ['pytest-runner'] if needs_pytest else [] - -try: - with open("README.md", "r") as handle: - long_description = handle.read() -except IOError: - long_description = "\n".join(short_description[2:]) - - -setup( - name='descent', - author='Simon Boothroyd', - description=short_description[0], - long_description=long_description, - long_description_content_type="text/markdown", - version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), - license='MIT', - packages=find_packages(), - include_package_data=True, - setup_requires=[] + pytest_runner, -) diff --git a/versioneer.py b/versioneer.py deleted file mode 100644 index 64fea1c..0000000 --- a/versioneer.py +++ /dev/null @@ -1,1822 +0,0 @@ - -# Version: 0.18 - -"""The Versioneer - like a rocketeer, but for versions. - -The Versioneer -============== - -* like a rocketeer, but for versions!
-* https://github.com/warner/python-versioneer -* Brian Warner -* License: Public Domain -* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy -* [![Latest Version] -(https://pypip.in/version/versioneer/badge.svg?style=flat) -](https://pypi.python.org/pypi/versioneer/) -* [![Build Status] -(https://travis-ci.org/warner/python-versioneer.png?branch=master) -](https://travis-ci.org/warner/python-versioneer) - -This is a tool for managing a recorded version number in distutils-based -python projects. The goal is to remove the tedious and error-prone "update -the embedded version string" step from your release process. Making a new -release should be as easy as recording a new tag in your version-control -system, and maybe making new tarballs. - - -## Quick Install - -* `pip install versioneer` to somewhere to your $PATH -* add a `[versioneer]` section to your setup.cfg (see below) -* run `versioneer install` in your source tree, commit the results - -## Version Identifiers - -Source trees come from a variety of places: - -* a version-control system checkout (mostly used by developers) -* a nightly tarball, produced by build automation -* a snapshot tarball, produced by a web-based VCS browser, like github's - "tarball from tag" feature -* a release tarball, produced by "setup.py sdist", distributed through PyPI - -Within each source tree, the version identifier (either a string or a number, -this tool is format-agnostic) can come from a variety of places: - -* ask the VCS tool itself, e.g. "git describe" (for checkouts), which knows - about recent "tags" and an absolute revision-id -* the name of the directory into which the tarball was unpacked -* an expanded VCS keyword ($Id$, etc) -* a `_version.py` created by some earlier build step - -For released software, the version identifier is closely related to a VCS -tag. Some projects use tag names that include more than just the version -string (e.g. "myproject-1.2" instead of just "1.2"), in which case the tool -needs to strip the tag prefix to extract the version identifier. For -unreleased software (between tags), the version identifier should provide -enough information to help developers recreate the same tree, while also -giving them an idea of roughly how old the tree is (after version 1.2, before -version 1.3). Many VCS systems can report a description that captures this, -for example `git describe --tags --dirty --always` reports things like -"0.7-1-g574ab98-dirty" to indicate that the checkout is one revision past the -0.7 tag, has a unique revision id of "574ab98", and is "dirty" (it has -uncommitted changes. - -The version identifier is used for multiple purposes: - -* to allow the module to self-identify its version: `myproject.__version__` -* to choose a name and prefix for a 'setup.py sdist' tarball - -## Theory of Operation - -Versioneer works by adding a special `_version.py` file into your source -tree, where your `__init__.py` can import it. This `_version.py` knows how to -dynamically ask the VCS tool for version information at import time. - -`_version.py` also contains `$Revision$` markers, and the installation -process marks `_version.py` to have this marker rewritten with a tag name -during the `git archive` command. As a result, generated tarballs will -contain enough information to get the proper version. - -To allow `setup.py` to compute a version too, a `versioneer.py` is added to -the top level of your source tree, next to `setup.py` and the `setup.cfg` -that configures it. 
This overrides several distutils/setuptools commands to -compute the version when invoked, and changes `setup.py build` and `setup.py -sdist` to replace `_version.py` with a small static file that contains just -the generated version data. - -## Installation - -See [INSTALL.md](./INSTALL.md) for detailed installation instructions. - -## Version-String Flavors - -Code which uses Versioneer can learn about its version string at runtime by -importing `_version` from your main `__init__.py` file and running the -`get_versions()` function. From the "outside" (e.g. in `setup.py`), you can -import the top-level `versioneer.py` and run `get_versions()`. - -Both functions return a dictionary with different flavors of version -information: - -* `['version']`: A condensed version string, rendered using the selected - style. This is the most commonly used value for the project's version - string. The default "pep440" style yields strings like `0.11`, - `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the "Styles" section - below for alternative styles. - -* `['full-revisionid']`: detailed revision identifier. For Git, this is the - full SHA1 commit id, e.g. "1076c978a8d3cfc70f408fe5974aa6c092c949ac". - -* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the - commit date in ISO 8601 format. This will be None if the date is not - available. - -* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that - this is only accurate if run in a VCS checkout, otherwise it is likely to - be False or None - -* `['error']`: if the version string could not be computed, this will be set - to a string describing the problem, otherwise it will be None. It may be - useful to throw an exception in setup.py if this is set, to avoid e.g. - creating tarballs with a version string of "unknown". - -Some variants are more useful than others. Including `full-revisionid` in a -bug report should allow developers to reconstruct the exact code being tested -(or indicate the presence of local changes that should be shared with the -developers). `version` is suitable for display in an "about" box or a CLI -`--version` output: it can be easily compared against release notes and lists -of bugs fixed in various releases. - -The installer adds the following text to your `__init__.py` to place a basic -version in `YOURPROJECT.__version__`: - - from ._version import get_versions - __version__ = get_versions()['version'] - del get_versions - -## Styles - -The setup.cfg `style=` configuration controls how the VCS information is -rendered into a version string. - -The default style, "pep440", produces a PEP440-compliant string, equal to the -un-prefixed tag name for actual releases, and containing an additional "local -version" section with more detail for in-between builds. For Git, this is -TAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags ---dirty --always`. For example "0.11+2.g1076c97.dirty" indicates that the -tree is like the "1076c97" commit but has uncommitted changes (".dirty"), and -that this commit is two revisions ("+2") beyond the "0.11" tag. For released -software (exactly equal to a known tag), the identifier will only contain the -stripped tag, e.g. "0.11". - -Other styles are available. See [details.md](details.md) in the Versioneer -source tree for descriptions. - -## Debugging - -Versioneer tries to avoid fatal errors: if something goes wrong, it will tend -to return a version of "0+unknown". 
To investigate the problem, run `setup.py -version`, which will run the version-lookup code in a verbose mode, and will -display the full contents of `get_versions()` (including the `error` string, -which may help identify what went wrong). - -## Known Limitations - -Some situations are known to cause problems for Versioneer. This details the -most significant ones. More can be found on the Github -[issues page](https://github.com/warner/python-versioneer/issues). - -### Subprojects - -Versioneer has limited support for source trees in which `setup.py` is not in -the root directory (e.g. `setup.py` and `.git/` are *not* siblings). There are -two common reasons why `setup.py` might not be in the root: - -* Source trees which contain multiple subprojects, such as - [Buildbot](https://github.com/buildbot/buildbot), which contains both - "master" and "slave" subprojects, each with their own `setup.py`, - `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI - distributions (and upload multiple independently-installable tarballs). -* Source trees whose main purpose is to contain a C library, but which also - provide bindings to Python (and perhaps other languages) in subdirectories. - -Versioneer will look for `.git` in parent directories, and most operations -should get the right version string. However `pip` and `setuptools` have bugs -and implementation details which frequently cause `pip install .` from a -subproject directory to fail to find a correct version string (so it usually -defaults to `0+unknown`). - -`pip install --editable .` should work correctly. `setup.py install` might -work too. - -Pip-8.1.1 is known to have this problem, but hopefully it will get fixed in -some later version. - -[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking -this issue. The discussion in -[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the -issue from the Versioneer side in more detail. -[pip PR#3176](https://github.com/pypa/pip/pull/3176) and -[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve -pip to let Versioneer work correctly. - -Versioneer-0.16 and earlier only looked for a `.git` directory next to the -`setup.cfg`, so subprojects were completely unsupported with those releases. - -### Editable installs with setuptools <= 18.5 - -`setup.py develop` and `pip install --editable .` allow you to install a -project into a virtualenv once, then continue editing the source code (and -test) without re-installing after every change. - -"Entry-point scripts" (`setup(entry_points={"console_scripts": ..})`) are a -convenient way to specify executable scripts that should be installed along -with the python package. - -These both work as expected when using modern setuptools. When using -setuptools-18.5 or earlier, however, certain operations will cause -`pkg_resources.DistributionNotFound` errors when running the entrypoint -script, which must be resolved by re-installing the package. This happens -when the install happens with one version, then the egg_info data is -regenerated while a different version is checked out. Many setup.py commands -cause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into -a different virtualenv), so this can be surprising. - -[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes -this one, but upgrading to a newer version of setuptools should probably -resolve it.
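
In each of these limitations the failure mode is the same: `get_versions()`
degrades gracefully to a version of "0+unknown" and records what happened in
the `error` flavor described above. If shipping a mislabeled artifact would be
worse than failing the build, that flavor can be checked explicitly. The
following is a minimal sketch, not part of Versioneer itself, assuming only
the documented `get_versions()` dictionary (the `strict_version` helper name
is illustrative):

    import versioneer

    def strict_version():
        # Illustrative guard: abort rather than release an artifact
        # whose version would be recorded as "0+unknown".
        info = versioneer.get_versions()
        if info["error"] is not None:
            raise SystemExit("version lookup failed: %s" % info["error"])
        return info["version"]

A `setup.py` could then pass `version=strict_version()` to `setup()` in place
of the plain `versioneer.get_version()` call.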
- -### Unicode version strings - -While Versioneer works (and is continually tested) with both Python 2 and -Python 3, it is not entirely consistent with bytes-vs-unicode distinctions. -Newer releases probably generate unicode version strings on py2. It's not -clear that this is wrong, but it may be surprising for applications when they then -write these strings to a network connection or include them in bytes-oriented -APIs like cryptographic checksums. - -[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates -this question. - - -## Updating Versioneer - -To upgrade your project to a new release of Versioneer, do the following: - -* install the new Versioneer (`pip install -U versioneer` or equivalent) -* edit `setup.cfg`, if necessary, to include any new configuration settings - indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details. -* re-run `versioneer install` in your source tree, to replace - `SRC/_version.py` -* commit any changed files - -## Future Directions - -This tool is designed to make it easy to extend to other version-control -systems: all VCS-specific components are in separate directories like -src/git/ . The top-level `versioneer.py` script is assembled from these -components by running make-versioneer.py . In the future, make-versioneer.py -will take a VCS name as an argument, and will construct a version of -`versioneer.py` that is specific to the given VCS. It might also take the -configuration arguments that are currently provided manually during -installation by editing setup.py . Alternatively, it might go the other -direction and include code from all supported VCS systems, reducing the -number of intermediate scripts. - - -## License - -To make Versioneer easier to embed, all its code is dedicated to the public -domain. The `_version.py` that it creates is also in the public domain. -Specifically, both are released under the Creative Commons "Public Domain -Dedication" license (CC0-1.0), as described in -https://creativecommons.org/publicdomain/zero/1.0/ . - -""" - -from __future__ import print_function -try: - import configparser -except ImportError: - import ConfigParser as configparser -import errno -import json -import os -import re -import subprocess -import sys - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_root(): - """Get the project root directory. - - We require that all commands are run from the project root, i.e. the - directory that contains setup.py, setup.cfg, and versioneer.py . - """ - root = os.path.realpath(os.path.abspath(os.getcwd())) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - # allow 'python path/to/setup.py COMMAND' - root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) - setup_py = os.path.join(root, "setup.py") - versioneer_py = os.path.join(root, "versioneer.py") - if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to find the project root directory.
" - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") - raise VersioneerBadRootError(err) - try: - # Certain runtime workflows (setup.py install/develop in a setuptools - # tree) execute all dependencies in a single python process, so - # "versioneer" may be imported multiple times, and python's shared - # module-import table will cache the first one. So we can't use - # os.path.dirname(__file__), as that will find whichever - # versioneer.py was first imported, even in later projects. - me = os.path.realpath(os.path.abspath(__file__)) - me_dir = os.path.normcase(os.path.splitext(me)[0]) - vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) - if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) - except NameError: - pass - return root - - -def get_config_from_root(root): - """Read the project setup.cfg file to determine Versioneer config.""" - # This might raise EnvironmentError (if setup.cfg is missing), or - # configparser.NoSectionError (if it lacks a [versioneer] section), or - # configparser.NoOptionError (if it lacks "VCS="). See the docstring at - # the top of versioneer.py for instructions on writing your setup.cfg . - setup_cfg = os.path.join(root, "setup.cfg") - parser = configparser.SafeConfigParser() - with open(setup_cfg, "r") as f: - parser.readfp(f) - VCS = parser.get("versioneer", "VCS") # mandatory - - def get(parser, name): - if parser.has_option("versioneer", name): - return parser.get("versioneer", name) - return None - cfg = VersioneerConfig() - cfg.VCS = VCS - cfg.style = get(parser, "style") or "" - cfg.versionfile_source = get(parser, "versionfile_source") - cfg.versionfile_build = get(parser, "versionfile_build") - cfg.tag_prefix = get(parser, "tag_prefix") - if cfg.tag_prefix in ("''", '""'): - cfg.tag_prefix = "" - cfg.parentdir_prefix = get(parser, "parentdir_prefix") - cfg.verbose = get(parser, "verbose") - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -# these dictionaries contain VCS-specific tools -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %s" % dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %s" % (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %s (error)" % dispcmd) - print("stdout was %s" % 
stdout) - return None, p.returncode - return stdout, p.returncode - - -LONG_VERSION_PY['git'] = ''' -# This file helps to compute a version number in source trees obtained from -# git-archive tarball (such as those provided by githubs download-from-tag -# feature). Distribution tarballs (built by setup.py sdist) and build -# directories (produced by setup.py build) will contain a much shorter file -# that just contains the computed version number. - -# This file is released into the public domain. Generated by -# versioneer-0.18 (https://github.com/warner/python-versioneer) - -"""Git implementation of _version.py.""" - -import errno -import os -import re -import subprocess -import sys - - -def get_keywords(): - """Get the keywords needed to look up the version information.""" - # these strings will be replaced by git during git-archive. - # setup.py/versioneer.py will grep for the variable names, so they must - # each be defined on a line of their own. _version.py will just call - # get_keywords(). - git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" - git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" - git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" - keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} - return keywords - - -class VersioneerConfig: - """Container for Versioneer configuration parameters.""" - - -def get_config(): - """Create, populate and return the VersioneerConfig() object.""" - # these strings are filled in when 'setup.py versioneer' creates - # _version.py - cfg = VersioneerConfig() - cfg.VCS = "git" - cfg.style = "%(STYLE)s" - cfg.tag_prefix = "%(TAG_PREFIX)s" - cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" - cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" - cfg.verbose = False - return cfg - - -class NotThisMethod(Exception): - """Exception raised if a method is not valid for the current scenario.""" - - -LONG_VERSION_PY = {} -HANDLERS = {} - - -def register_vcs_handler(vcs, method): # decorator - """Decorator to mark a method as the handler for a particular VCS.""" - def decorate(f): - """Store f in HANDLERS[vcs][method].""" - if vcs not in HANDLERS: - HANDLERS[vcs] = {} - HANDLERS[vcs][method] = f - return f - return decorate - - -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): - """Call the given command(s).""" - assert isinstance(commands, list) - p = None - for c in commands: - try: - dispcmd = str([c] + args) - # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) - break - except EnvironmentError: - e = sys.exc_info()[1] - if e.errno == errno.ENOENT: - continue - if verbose: - print("unable to run %%s" %% dispcmd) - print(e) - return None, None - else: - if verbose: - print("unable to find command, tried %%s" %% (commands,)) - return None, None - stdout = p.communicate()[0].strip() - if sys.version_info[0] >= 3: - stdout = stdout.decode() - if p.returncode != 0: - if verbose: - print("unable to run %%s (error)" %% dispcmd) - print("stdout was %%s" %% stdout) - return None, p.returncode - return stdout, p.returncode - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. 
We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %%s but none started with prefix %%s" %% - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %%d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%%s', no digits" %% ",".join(refs - tags)) - if verbose: - print("likely tags: %%s" %% ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %%s" %% r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %%s not under git control" %% root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%%s*" %% tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%%s'" - %% describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%%s' doesn't start with prefix '%%s'" - print(fmt %% (full_tag, tag_prefix)) - pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" - %% (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%%d" %% pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%%d" %% pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%%s" %% pieces["short"] - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%%s" %% pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%%d" %% pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%%s'" %% style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -def get_versions(): - """Get version information or return default if unable to do so.""" - # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have - # __file__, we can work backwards from there to the root. Some - # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which - # case we can only use expanded keywords. - - cfg = get_config() - verbose = cfg.verbose - - try: - return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, - verbose) - except NotThisMethod: - pass - - try: - root = os.path.realpath(__file__) - # versionfile_source is the relative path from the top of the source - # tree (where the .git directory might live) to this file. Invert - # this to find the root from __file__. 
- for i in cfg.versionfile_source.split('/'): - root = os.path.dirname(root) - except NameError: - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to find root of source tree", - "date": None} - - try: - pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) - return render(pieces, cfg.style) - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - except NotThisMethod: - pass - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, - "error": "unable to compute version", "date": None} -''' - - -@register_vcs_handler("git", "get_keywords") -def git_get_keywords(versionfile_abs): - """Extract version information from the given file.""" - # the code embedded in _version.py can just fetch the value of these - # keywords. When used from setup.py, we don't want to import _version.py, - # so we do it with a regexp instead. This function is not used from - # _version.py. - keywords = {} - try: - f = open(versionfile_abs, "r") - for line in f.readlines(): - if line.strip().startswith("git_refnames ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["refnames"] = mo.group(1) - if line.strip().startswith("git_full ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["full"] = mo.group(1) - if line.strip().startswith("git_date ="): - mo = re.search(r'=\s*"(.*)"', line) - if mo: - keywords["date"] = mo.group(1) - f.close() - except EnvironmentError: - pass - return keywords - - -@register_vcs_handler("git", "keywords") -def git_versions_from_keywords(keywords, tag_prefix, verbose): - """Get version information from git keywords.""" - if not keywords: - raise NotThisMethod("no keywords at all, weird") - date = keywords.get("date") - if date is not None: - # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant - # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 - # -like" string, which we must then edit to make compliant), because - # it's been around since git-1.5.3, and it's too difficult to - # discover which version we're using, or to work around using an - # older one. - date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - refnames = keywords["refnames"].strip() - if refnames.startswith("$Format"): - if verbose: - print("keywords are unexpanded, not using") - raise NotThisMethod("unexpanded keywords, not a git-archive tarball") - refs = set([r.strip() for r in refnames.strip("()").split(",")]) - # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of - # just "foo-1.0". If we see a "tag: " prefix, prefer those. - TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) - if not tags: - # Either we're using git < 1.8.3, or there really are no tags. We use - # a heuristic: assume all version tags have a digit. The old git %d - # expansion behaves like git log --decorate=short and strips out the - # refs/heads/ and refs/tags/ prefixes that would let us distinguish - # between branches and tags. By ignoring refnames without digits, we - # filter out many common branch names like "release" and - # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) - if verbose: - print("discarding '%s', no digits" % ",".join(refs - tags)) - if verbose: - print("likely tags: %s" % ",".join(sorted(tags))) - for ref in sorted(tags): - # sorting will prefer e.g. 
"2.0" over "2.0rc1" - if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] - if verbose: - print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} - # no suitable tags, so version is "0+unknown", but full hex is still there - if verbose: - print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} - - -@register_vcs_handler("git", "pieces_from_vcs") -def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): - """Get version from 'git describe' in the root of the source tree. - - This only gets called if the git-archive 'subst' keywords were *not* - expanded, and _version.py hasn't already been rewritten with a short - version string, meaning we're inside a checked out source tree. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) - if rc != 0: - if verbose: - print("Directory %s not under git control" % root) - raise NotThisMethod("'git rev-parse --git-dir' returned error") - - # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] - # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) - # --long was added in git-1.5.5 - if describe_out is None: - raise NotThisMethod("'git describe' failed") - describe_out = describe_out.strip() - full_out, rc = run_command(GITS, ["rev-parse", "HEAD"], cwd=root) - if full_out is None: - raise NotThisMethod("'git rev-parse' failed") - full_out = full_out.strip() - - pieces = {} - pieces["long"] = full_out - pieces["short"] = full_out[:7] # maybe improved later - pieces["error"] = None - - # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] - # TAG might have hyphens. - git_describe = describe_out - - # look for -dirty suffix - dirty = git_describe.endswith("-dirty") - pieces["dirty"] = dirty - if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] - - # now we have TAG-NUM-gHEX or HEX - - if "-" in git_describe: - # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) - if not mo: - # unparseable. Maybe git-describe is misbehaving? 
- pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) - return pieces - - # tag - full_tag = mo.group(1) - if not full_tag.startswith(tag_prefix): - if verbose: - fmt = "tag '%s' doesn't start with prefix '%s'" - print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) - return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] - - # distance: number of commits since tag - pieces["distance"] = int(mo.group(2)) - - # commit: short hex revision ID - pieces["short"] = mo.group(3) - - else: - # HEX: no tags - pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) - pieces["distance"] = int(count_out) # total number of commits - - # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() - pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) - - return pieces - - -def do_vcs_install(manifest_in, versionfile_source, ipy): - """Git-specific installation logic for Versioneer. - - For Git, this means creating/changing .gitattributes to mark _version.py - for export-subst keyword substitution. - """ - GITS = ["git"] - if sys.platform == "win32": - GITS = ["git.cmd", "git.exe"] - files = [manifest_in, versionfile_source] - if ipy: - files.append(ipy) - try: - me = __file__ - if me.endswith(".pyc") or me.endswith(".pyo"): - me = os.path.splitext(me)[0] + ".py" - versioneer_file = os.path.relpath(me) - except NameError: - versioneer_file = "versioneer.py" - files.append(versioneer_file) - present = False - try: - f = open(".gitattributes", "r") - for line in f.readlines(): - if line.strip().startswith(versionfile_source): - if "export-subst" in line.strip().split()[1:]: - present = True - f.close() - except EnvironmentError: - pass - if not present: - f = open(".gitattributes", "a+") - f.write("%s export-subst\n" % versionfile_source) - f.close() - files.append(".gitattributes") - run_command(GITS, ["add", "--"] + files) - - -def versions_from_parentdir(parentdir_prefix, root, verbose): - """Try to determine the version from the parent directory name. - - Source tarballs conventionally unpack into a directory that includes both - the project name and a version string. We will also support searching up - two directory levels for an appropriately named parent directory - """ - rootdirs = [] - - for i in range(3): - dirname = os.path.basename(root) - if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} - else: - rootdirs.append(root) - root = os.path.dirname(root) # up a level - - if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) - raise NotThisMethod("rootdir doesn't start with parentdir_prefix") - - -SHORT_VERSION_PY = """ -# This file was generated by 'versioneer.py' (0.18) from -# revision-control system data, or from the parent directory name of an -# unpacked source archive. Distribution tarballs contain a pre-generated copy -# of this file. 
- -import json - -version_json = ''' -%s -''' # END VERSION_JSON - - -def get_versions(): - return json.loads(version_json) -""" - - -def versions_from_file(filename): - """Try to determine the version from _version.py if present.""" - try: - with open(filename) as f: - contents = f.read() - except EnvironmentError: - raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) - if not mo: - raise NotThisMethod("no version_json in _version.py") - return json.loads(mo.group(1)) - - -def write_to_version_file(filename, versions): - """Write the given version number to the given _version.py file.""" - os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) - with open(filename, "w") as f: - f.write(SHORT_VERSION_PY % contents) - - print("set %s to '%s'" % (filename, versions["version"])) - - -def plus_or_dot(pieces): - """Return a + if we don't already have one, else return a .""" - if "+" in pieces.get("closest-tag", ""): - return "." - return "+" - - -def render_pep440(pieces): - """Build up version string, with post-release "local version identifier". - - Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you - get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty - - Exceptions: - 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += plus_or_dot(pieces) - rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - else: - # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) - if pieces["dirty"]: - rendered += ".dirty" - return rendered - - -def render_pep440_pre(pieces): - """TAG[.post.devDISTANCE] -- No -dirty. - - Exceptions: - 1: no tags. 0.post.devDISTANCE - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += ".post.dev%d" % pieces["distance"] - else: - # exception #1 - rendered = "0.post.dev%d" % pieces["distance"] - return rendered - - -def render_pep440_post(pieces): - """TAG[.postDISTANCE[.dev0]+gHEX] . - - The ".dev0" means dirty. Note that .dev0 sorts backwards - (a dirty tree will appear "older" than the corresponding clean one), - but you shouldn't be releasing software with -dirty anyways. - - Exceptions: - 1: no tags. 0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += plus_or_dot(pieces) - rendered += "g%s" % pieces["short"] - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - rendered += "+g%s" % pieces["short"] - return rendered - - -def render_pep440_old(pieces): - """TAG[.postDISTANCE[.dev0]] . - - The ".dev0" means dirty. - - Eexceptions: - 1: no tags. 
0.postDISTANCE[.dev0] - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"] or pieces["dirty"]: - rendered += ".post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - else: - # exception #1 - rendered = "0.post%d" % pieces["distance"] - if pieces["dirty"]: - rendered += ".dev0" - return rendered - - -def render_git_describe(pieces): - """TAG[-DISTANCE-gHEX][-dirty]. - - Like 'git describe --tags --dirty --always'. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - if pieces["distance"]: - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render_git_describe_long(pieces): - """TAG-DISTANCE-gHEX[-dirty]. - - Like 'git describe --tags --dirty --always -long'. - The distance/hash is unconditional. - - Exceptions: - 1: no tags. HEX[-dirty] (note: no 'g' prefix) - """ - if pieces["closest-tag"]: - rendered = pieces["closest-tag"] - rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) - else: - # exception #1 - rendered = pieces["short"] - if pieces["dirty"]: - rendered += "-dirty" - return rendered - - -def render(pieces, style): - """Render the given version pieces into the requested style.""" - if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} - - if not style or style == "default": - style = "pep440" # the default - - if style == "pep440": - rendered = render_pep440(pieces) - elif style == "pep440-pre": - rendered = render_pep440_pre(pieces) - elif style == "pep440-post": - rendered = render_pep440_post(pieces) - elif style == "pep440-old": - rendered = render_pep440_old(pieces) - elif style == "git-describe": - rendered = render_git_describe(pieces) - elif style == "git-describe-long": - rendered = render_git_describe_long(pieces) - else: - raise ValueError("unknown style '%s'" % style) - - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} - - -class VersioneerBadRootError(Exception): - """The project root directory is unknown or missing key files.""" - - -def get_versions(verbose=False): - """Get the project version from whatever source is available. - - Returns dict with two keys: 'version' and 'full'. - """ - if "versioneer" in sys.modules: - # see the discussion in cmdclass.py:get_cmdclass() - del sys.modules["versioneer"] - - root = get_root() - cfg = get_config_from_root(root) - - assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" - handlers = HANDLERS.get(cfg.VCS) - assert handlers, "unrecognized VCS '%s'" % cfg.VCS - verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" - assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" - - versionfile_abs = os.path.join(root, cfg.versionfile_source) - - # extract version from first of: _version.py, VCS command (e.g. 'git - # describe'), parentdir. This is meant to work for developers using a - # source checkout, for users of a tarball created by 'setup.py sdist', - # and for users of a tarball/zipball created by 'git archive' or github's - # download-from-tag feature or the equivalent in other VCSes. 
- - get_keywords_f = handlers.get("get_keywords") - from_keywords_f = handlers.get("keywords") - if get_keywords_f and from_keywords_f: - try: - keywords = get_keywords_f(versionfile_abs) - ver = from_keywords_f(keywords, cfg.tag_prefix, verbose) - if verbose: - print("got version from expanded keyword %s" % ver) - return ver - except NotThisMethod: - pass - - try: - ver = versions_from_file(versionfile_abs) - if verbose: - print("got version from file %s %s" % (versionfile_abs, ver)) - return ver - except NotThisMethod: - pass - - from_vcs_f = handlers.get("pieces_from_vcs") - if from_vcs_f: - try: - pieces = from_vcs_f(cfg.tag_prefix, root, verbose) - ver = render(pieces, cfg.style) - if verbose: - print("got version from VCS %s" % ver) - return ver - except NotThisMethod: - pass - - try: - if cfg.parentdir_prefix: - ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose) - if verbose: - print("got version from parentdir %s" % ver) - return ver - except NotThisMethod: - pass - - if verbose: - print("unable to compute version") - - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} - - -def get_version(): - """Get the short version string for this project.""" - return get_versions()["version"] - - -def get_cmdclass(): - """Get the custom setuptools/distutils subclasses used by Versioneer.""" - if "versioneer" in sys.modules: - del sys.modules["versioneer"] - # this fixes the "python setup.py develop" case (also 'install' and - # 'easy_install .'), in which subdependencies of the main project are - # built (using setup.py bdist_egg) in the same python process. Assume - # a main project A and a dependency B, which use different versions - # of Versioneer. A's setup.py imports A's Versioneer, leaving it in - # sys.modules by the time B's setup.py is executed, causing B to run - # with the wrong versioneer. Setuptools wraps the sub-dep builds in a - # sandbox that restores sys.modules to it's pre-build state, so the - # parent is protected against the child's "import versioneer". By - # removing ourselves from sys.modules here, before the child build - # happens, we protect the child from the parent's versioneer too. - # Also see https://github.com/warner/python-versioneer/issues/52 - - cmds = {} - - # we add "version" to both distutils and setuptools - from distutils.core import Command - - class cmd_version(Command): - description = "report generated version string" - user_options = [] - boolean_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - vers = get_versions(verbose=True) - print("Version: %s" % vers["version"]) - print(" full-revisionid: %s" % vers.get("full-revisionid")) - print(" dirty: %s" % vers.get("dirty")) - print(" date: %s" % vers.get("date")) - if vers["error"]: - print(" error: %s" % vers["error"]) - cmds["version"] = cmd_version - - # we override "build_py" in both distutils and setuptools - # - # most invocation pathways end up running build_py: - # distutils/build -> build_py - # distutils/install -> distutils/build ->.. - # setuptools/bdist_wheel -> distutils/install ->.. - # setuptools/bdist_egg -> distutils/install_lib -> build_py - # setuptools/install -> bdist_egg ->.. - # setuptools/develop -> ? 
- # pip install: - # copies source tree to a tempdir before running egg_info/etc - # if .git isn't copied too, 'git describe' will fail - # then does setup.py bdist_wheel, or sometimes setup.py install - # setup.py egg_info -> ? - - # we override different "build_py" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.build_py import build_py as _build_py - else: - from distutils.command.build_py import build_py as _build_py - - class cmd_build_py(_build_py): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - _build_py.run(self) - # now locate _version.py in the new build/ directory and replace - # it with an updated value - if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - cmds["build_py"] = cmd_build_py - - if "cx_Freeze" in sys.modules: # cx_freeze enabled? - from cx_Freeze.dist import build_exe as _build_exe - # nczeczulin reports that py2exe won't like the pep440-style string - # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. - # setup(console=[{ - # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION - # "product_version": versioneer.get_version(), - # ... - - class cmd_build_exe(_build_exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _build_exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["build_exe"] = cmd_build_exe - del cmds["build_py"] - - if 'py2exe' in sys.modules: # py2exe enabled? 
- try: - from py2exe.distutils_buildexe import py2exe as _py2exe # py3 - except ImportError: - from py2exe.build_exe import py2exe as _py2exe # py2 - - class cmd_py2exe(_py2exe): - def run(self): - root = get_root() - cfg = get_config_from_root(root) - versions = get_versions() - target_versionfile = cfg.versionfile_source - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, versions) - - _py2exe.run(self) - os.unlink(target_versionfile) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - cmds["py2exe"] = cmd_py2exe - - # we override different "sdist" commands for both environments - if "setuptools" in sys.modules: - from setuptools.command.sdist import sdist as _sdist - else: - from distutils.command.sdist import sdist as _sdist - - class cmd_sdist(_sdist): - def run(self): - versions = get_versions() - self._versioneer_generated_versions = versions - # unless we update this, the command will keep using the old - # version - self.distribution.metadata.version = versions["version"] - return _sdist.run(self) - - def make_release_tree(self, base_dir, files): - root = get_root() - cfg = get_config_from_root(root) - _sdist.make_release_tree(self, base_dir, files) - # now locate _version.py in the new base_dir directory - # (remembering that it may be a hardlink) and replace it with an - # updated value - target_versionfile = os.path.join(base_dir, cfg.versionfile_source) - print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) - cmds["sdist"] = cmd_sdist - - return cmds - - -CONFIG_ERROR = """ -setup.cfg is missing the necessary Versioneer configuration. You need -a section like: - - [versioneer] - VCS = git - style = pep440 - versionfile_source = src/myproject/_version.py - versionfile_build = myproject/_version.py - tag_prefix = - parentdir_prefix = myproject- - -You will also need to edit your setup.py to use the results: - - import versioneer - setup(version=versioneer.get_version(), - cmdclass=versioneer.get_cmdclass(), ...) - -Please read the docstring in ./versioneer.py for configuration instructions, -edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. -""" - -SAMPLE_CONFIG = """ -# See the docstring in versioneer.py for instructions. Note that you must -# re-run 'versioneer.py setup' after changing this section, and commit the -# resulting files. 
- -[versioneer] -#VCS = git -#style = pep440 -#versionfile_source = -#versionfile_build = -#tag_prefix = -#parentdir_prefix = - -""" - -INIT_PY_SNIPPET = """ -from ._version import get_versions -__version__ = get_versions()['version'] -del get_versions -""" - - -def do_setup(): - """Main VCS-independent setup function for installing Versioneer.""" - root = get_root() - try: - cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: - if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) - with open(os.path.join(root, "setup.cfg"), "a") as f: - f.write(SAMPLE_CONFIG) - print(CONFIG_ERROR, file=sys.stderr) - return 1 - - print(" creating %s" % cfg.versionfile_source) - with open(cfg.versionfile_source, "w") as f: - LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") - if os.path.exists(ipy): - try: - with open(ipy, "r") as f: - old = f.read() - except EnvironmentError: - old = "" - if INIT_PY_SNIPPET not in old: - print(" appending to %s" % ipy) - with open(ipy, "a") as f: - f.write(INIT_PY_SNIPPET) - else: - print(" %s unmodified" % ipy) - else: - print(" %s doesn't exist, ok" % ipy) - ipy = None - - # Make sure both the top-level "versioneer.py" and versionfile_source - # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so - # they'll be copied into source distributions. Pip won't be able to - # install the package without this. - manifest_in = os.path.join(root, "MANIFEST.in") - simple_includes = set() - try: - with open(manifest_in, "r") as f: - for line in f: - if line.startswith("include "): - for include in line.split()[1:]: - simple_includes.add(include) - except EnvironmentError: - pass - # That doesn't cover everything MANIFEST.in can do - # (http://docs.python.org/2/distutils/sourcedist.html#commands), so - # it might give some false negatives. Appending redundant 'include' - # lines is safe, though. - if "versioneer.py" not in simple_includes: - print(" appending 'versioneer.py' to MANIFEST.in") - with open(manifest_in, "a") as f: - f.write("include versioneer.py\n") - else: - print(" 'versioneer.py' already in MANIFEST.in") - if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) - with open(manifest_in, "a") as f: - f.write("include %s\n" % cfg.versionfile_source) - else: - print(" versionfile_source already in MANIFEST.in") - - # Make VCS-specific changes. For git, this means creating/changing - # .gitattributes to mark _version.py for export-subst keyword - # substitution. 
- do_vcs_install(manifest_in, cfg.versionfile_source, ipy) - return 0 - - -def scan_setup_py(): - """Validate the contents of setup.py against Versioneer's expectations.""" - found = set() - setters = False - errors = 0 - with open("setup.py", "r") as f: - for line in f.readlines(): - if "import versioneer" in line: - found.add("import") - if "versioneer.get_cmdclass()" in line: - found.add("cmdclass") - if "versioneer.get_version()" in line: - found.add("get_version") - if "versioneer.VCS" in line: - setters = True - if "versioneer.versionfile_source" in line: - setters = True - if len(found) != 3: - print("") - print("Your setup.py appears to be missing some important items") - print("(but I might be wrong). Please make sure it has something") - print("roughly like the following:") - print("") - print(" import versioneer") - print(" setup( version=versioneer.get_version(),") - print(" cmdclass=versioneer.get_cmdclass(), ...)") - print("") - errors += 1 - if setters: - print("You should remove lines like 'versioneer.VCS = ' and") - print("'versioneer.versionfile_source = ' . This configuration") - print("now lives in setup.cfg, and should be removed from setup.py") - print("") - errors += 1 - return errors - - -if __name__ == "__main__": - cmd = sys.argv[1] - if cmd == "setup": - errors = do_setup() - errors += scan_setup_py() - if errors: - sys.exit(1)
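
Notes on the removed versioneer machinery, for reviewers tracing what this reset drops. Each sketch below condenses one of the deleted functions rather than reproducing it; sample tags, paths, and version numbers are hypothetical.

The keyword path (git_get_keywords / git_versions_from_keywords) is what made `git archive` tarballs versionable without a .git directory: the deleted do_vcs_install marks the version file with the export-subst attribute, and git then expands $Format:...$ placeholders inside it at archive time.

    # Placeholders as they sit in the checked-in _version.py; `git archive`
    # expands them because .gitattributes contains a line like
    # "descent/_version.py export-subst" (written by the deleted do_vcs_install).
    git_refnames = "$Format:%d$"  # e.g. " (HEAD -> main, tag: v0.1.2)" after expansion
    git_full = "$Format:%H$"      # the full commit SHA
    git_date = "$Format:%ci$"     # the committer date, ISO-8601-like

    # git_versions_from_keywords() then strips the "tag: " prefixes, ignores
    # refs without digits (branch names like "main", and "HEAD"), and returns
    # the first sorted tag that starts with tag_prefix.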
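Inside a real checkout, git_pieces_from_vcs instead shells out to `git describe --tags --dirty --always --long` and splits the result into the "pieces" dict that every render style consumes. A condensed sketch of just that parsing step, using the same regex as the deleted code (the sample string is hypothetical):

    import re

    def parse_describe(describe_out, tag_prefix="v"):
        # 'TAG-NUM-gHEX[-dirty]' -> the fields versioneer calls "pieces";
        # a bare 'HEX[-dirty]' means no tag is reachable from HEAD.
        pieces = {"dirty": describe_out.endswith("-dirty")}
        if pieces["dirty"]:
            describe_out = describe_out[: describe_out.rindex("-dirty")]
        mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", describe_out)
        if not mo:
            pieces.update({"closest-tag": None, "short": describe_out})
            return pieces
        pieces["closest-tag"] = mo.group(1)[len(tag_prefix):]  # strip the "v"
        pieces["distance"] = int(mo.group(2))                  # commits since that tag
        pieces["short"] = mo.group(3)                          # abbreviated SHA
        return pieces

    print(parse_describe("v0.1.2-3-gabc1234-dirty"))
    # {'dirty': True, 'closest-tag': '0.1.2', 'distance': 3, 'short': 'abc1234'}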
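The render_* family then formats those pieces, with cfg.style (from setup.cfg) selecting one via render(). The default "pep440" rule, condensed, plus what the other deleted styles would produce for the same hypothetical pieces:

    def render_pep440(pieces):
        # TAG[+DISTANCE.gHEX[.dirty]]; with no reachable tag,
        # 0+untagged.DISTANCE.gHEX[.dirty]
        if not pieces["closest-tag"]:
            rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
            return rendered + ".dirty" if pieces["dirty"] else rendered
        rendered = pieces["closest-tag"]
        if pieces["distance"] or pieces["dirty"]:
            sep = "." if "+" in pieces["closest-tag"] else "+"  # plus_or_dot()
            rendered += "%s%d.g%s" % (sep, pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
        return rendered

    pieces = {"closest-tag": "0.1.2", "distance": 3, "short": "abc1234", "dirty": True}
    print(render_pep440(pieces))  # 0.1.2+3.gabc1234.dirty
    # pep440-post:  0.1.2.post3.dev0+gabc1234
    # pep440-pre:   0.1.2.post.dev3         (ignores dirty)
    # pep440-old:   0.1.2.post3.dev0        (no local version part)
    # git-describe: 0.1.2-3-gabc1234-dirty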
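write_to_version_file() and the SHORT_VERSION_PY template freeze the computed dict into a generated _version.py (what sdists and builds ship), and versions_from_file() recovers it later with a regex. A round-trip sketch using the same markers as the deleted template:

    import json
    import re

    TEMPLATE = """\
    import json

    version_json = '''
    %s
    ''' # END VERSION_JSON


    def get_versions():
        return json.loads(version_json)
    """

    versions = {"dirty": False, "error": None, "version": "0.1.2"}
    contents = TEMPLATE % json.dumps(versions, sort_keys=True, indent=1)

    # versions_from_file() locates the JSON payload the same way:
    mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON",
                   contents, re.M | re.S)
    print(json.loads(mo.group(1))["version"])  # 0.1.2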
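For plain source tarballs with neither git metadata nor a generated file, versions_from_parentdir() falls back to the unpack directory's name, searching up to three levels for "<parentdir_prefix><version>". Condensed (the prefix and path here are hypothetical):

    import os

    def version_from_parentdir(root, parentdir_prefix="descent-"):
        for _ in range(3):  # same three-level walk as the deleted helper
            dirname = os.path.basename(root)
            if dirname.startswith(parentdir_prefix):
                return dirname[len(parentdir_prefix):]
            root = os.path.dirname(root)  # up one level
        return None

    print(version_from_parentdir("/tmp/descent-0.1.2/build"))  # 0.1.2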
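The top-level get_versions() simply tries those sources in a fixed order, each one either answering or raising NotThisMethod: expanded keywords first, then the generated _version.py, then `git describe`, then the parent directory, and finally the "0+unknown" sentinel. A condensed sketch of that fallback chain:

    class NotThisMethod(Exception):
        """A version source that does not apply to this checkout."""

    def first_version(sources):
        # mirrors the deleted get_versions() fallback order
        for source in sources:
            try:
                return source()
            except NotThisMethod:
                continue
        return {"version": "0+unknown", "error": "unable to compute version"}

    def from_keywords():
        raise NotThisMethod("keywords still unexpanded")  # not a git-archive tarball

    def from_vcs():
        return {"version": "0.1.2+3.gabc1234.dirty", "error": None}

    print(first_version([from_keywords, from_vcs])["version"])
    # 0.1.2+3.gabc1234.dirty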
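Finally, get_cmdclass() is how all of this hooked into setuptools: it wraps build_py, sdist, and (when present) cx_Freeze/py2exe commands so every artifact carries a frozen _version.py instead of needing git at install time. The build_py pattern, reduced to its core (the target path assumes this project's descent/ layout, and the inline write stands in for the deleted write_to_version_file):

    import json
    import os

    from setuptools.command.build_py import build_py as _build_py


    class cmd_build_py(_build_py):
        def run(self):
            versions = {"version": "0.1.2"}  # versioneer computes this first
            _build_py.run(self)              # normal copy of sources into build_lib
            # then overwrite the built _version.py with a static dict
            target = os.path.join(self.build_lib, "descent", "_version.py")
            with open(target, "w") as f:
                f.write("import json\n\ndef get_versions():\n"
                        "    return json.loads(%r)\n" % json.dumps(versions))

In the deleted code the same freeze-then-restore dance repeats for build_exe and py2exe, except those rewrite the source tree's _version.py around the build and restore the long template afterwards.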