Skip to content

Commit

Permalink
Precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
FNTwin committed Sep 25, 2024
1 parent e923e49 commit 33bf207
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 18 deletions.
53 changes: 35 additions & 18 deletions docs/optim.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
Expand All @@ -40,7 +40,16 @@
"hide-cell"
]
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/bin/bash: -c: line 1: unexpected EOF while looking for matching `\"'\n",
"/bin/bash: -c: line 2: syntax error: unexpected end of file\n"
]
}
],
"source": [
"import sys\n",
"\n",
Expand All @@ -52,14 +61,22 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"id": "j7HdZdtPdEXU",
"tags": [
"hide-cell"
]
},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
Expand All @@ -84,7 +101,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
Expand All @@ -96,10 +113,10 @@
"outputs": [
{
"data": {
"application/3dmoljs_load.v0": "<div id=\"3dmolviewer_17146571511503549\" style=\"position: relative; width: 640px; height: 480px;\">\n <p id=\"3dmolwarning_17146571511503549\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n </div>\n<script>\n\nvar loadScriptAsync = function(uri){\n return new Promise((resolve, reject) => {\n //this is to ignore the existence of requirejs amd\n var savedexports, savedmodule;\n if (typeof exports !== 'undefined') savedexports = exports;\n else exports = {}\n if (typeof module !== 'undefined') savedmodule = module;\n else module = {}\n\n var tag = document.createElement('script');\n tag.src = uri;\n tag.async = true;\n tag.onload = () => {\n exports = savedexports;\n module = savedmodule;\n resolve();\n };\n var firstScriptTag = document.getElementsByTagName('script')[0];\n firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n});\n};\n\nif(typeof $3Dmolpromise === 'undefined') {\n$3Dmolpromise = null;\n $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.1.0/3Dmol-min.js');\n}\n\nvar viewer_17146571511503549 = null;\nvar warn = document.getElementById(\"3dmolwarning_17146571511503549\");\nif(warn) {\n warn.parentNode.removeChild(warn);\n}\n$3Dmolpromise.then(function() {\nviewer_17146571511503549 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17146571511503549\"),{backgroundColor:\"white\"});\nviewer_17146571511503549.zoomTo();\n\tviewer_17146571511503549.addModel(\"5\\n\\nC\\t0.\\t0.\\t0.\\nH\\t 0.62558327\\t-0.62558327\\t 0.62558327\\nH\\t-0.62558327\\t 0.62558327\\t 0.62558327\\nH\\t 0.62558327\\t 0.62558327\\t-0.62558327\\nH\\t-0.62558327\\t-0.62558327\\t-0.62558327\\n\");\n\tviewer_17146571511503549.setStyle({\"stick\": {\"radius\": 0.1}});\nviewer_17146571511503549.render();\n});\n</script>",
"application/3dmoljs_load.v0": "<div id=\"3dmolviewer_17272701051676118\" style=\"position: relative; width: 640px; height: 480px;\">\n <p id=\"3dmolwarning_17272701051676118\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n </div>\n<script>\n\nvar loadScriptAsync = function(uri){\n return new Promise((resolve, reject) => {\n //this is to ignore the existence of requirejs amd\n var savedexports, savedmodule;\n if (typeof exports !== 'undefined') savedexports = exports;\n else exports = {}\n if (typeof module !== 'undefined') savedmodule = module;\n else module = {}\n\n var tag = document.createElement('script');\n tag.src = uri;\n tag.async = true;\n tag.onload = () => {\n exports = savedexports;\n module = savedmodule;\n resolve();\n };\n var firstScriptTag = document.getElementsByTagName('script')[0];\n firstScriptTag.parentNode.insertBefore(tag, firstScriptTag);\n});\n};\n\nif(typeof $3Dmolpromise === 'undefined') {\n$3Dmolpromise = null;\n $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.0/3Dmol-min.js');\n}\n\nvar viewer_17272701051676118 = null;\nvar warn = document.getElementById(\"3dmolwarning_17272701051676118\");\nif(warn) {\n warn.parentNode.removeChild(warn);\n}\n$3Dmolpromise.then(function() {\nviewer_17272701051676118 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17272701051676118\"),{backgroundColor:\"white\"});\nviewer_17272701051676118.zoomTo();\n\tviewer_17272701051676118.addModel(\"5\\n\\nC\\t0.\\t0.\\t0.\\nH\\t 0.62558327\\t-0.62558327\\t 0.62558327\\nH\\t-0.62558327\\t 0.62558327\\t 0.62558327\\nH\\t 0.62558327\\t 0.62558327\\t-0.62558327\\nH\\t-0.62558327\\t-0.62558327\\t-0.62558327\\n\");\n\tviewer_17272701051676118.setStyle({\"stick\": {\"radius\": 0.1}});\nviewer_17272701051676118.render();\n});\n</script>",
"text/html": [
"<div id=\"3dmolviewer_17146571511503549\" style=\"position: relative; width: 640px; height: 480px;\">\n",
" <p id=\"3dmolwarning_17146571511503549\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n",
"<div id=\"3dmolviewer_17272701051676118\" style=\"position: relative; width: 640px; height: 480px;\">\n",
" <p id=\"3dmolwarning_17272701051676118\" style=\"background-color:#ffcccc;color:black\">3Dmol.js failed to load for some reason. Please check your browser console for error messages.<br></p>\n",
" </div>\n",
"<script>\n",
"\n",
Expand Down Expand Up @@ -127,20 +144,20 @@
"\n",
"if(typeof $3Dmolpromise === 'undefined') {\n",
"$3Dmolpromise = null;\n",
" $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.1.0/3Dmol-min.js');\n",
" $3Dmolpromise = loadScriptAsync('https://cdnjs.cloudflare.com/ajax/libs/3Dmol/2.4.0/3Dmol-min.js');\n",
"}\n",
"\n",
"var viewer_17146571511503549 = null;\n",
"var warn = document.getElementById(\"3dmolwarning_17146571511503549\");\n",
"var viewer_17272701051676118 = null;\n",
"var warn = document.getElementById(\"3dmolwarning_17272701051676118\");\n",
"if(warn) {\n",
" warn.parentNode.removeChild(warn);\n",
"}\n",
"$3Dmolpromise.then(function() {\n",
"viewer_17146571511503549 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17146571511503549\"),{backgroundColor:\"white\"});\n",
"viewer_17146571511503549.zoomTo();\n",
"\tviewer_17146571511503549.addModel(\"5\\n\\nC\\t0.\\t0.\\t0.\\nH\\t 0.62558327\\t-0.62558327\\t 0.62558327\\nH\\t-0.62558327\\t 0.62558327\\t 0.62558327\\nH\\t 0.62558327\\t 0.62558327\\t-0.62558327\\nH\\t-0.62558327\\t-0.62558327\\t-0.62558327\\n\");\n",
"\tviewer_17146571511503549.setStyle({\"stick\": {\"radius\": 0.1}});\n",
"viewer_17146571511503549.render();\n",
"viewer_17272701051676118 = $3Dmol.createViewer(document.getElementById(\"3dmolviewer_17272701051676118\"),{backgroundColor:\"white\"});\n",
"viewer_17272701051676118.zoomTo();\n",
"\tviewer_17272701051676118.addModel(\"5\\n\\nC\\t0.\\t0.\\t0.\\nH\\t 0.62558327\\t-0.62558327\\t 0.62558327\\nH\\t-0.62558327\\t 0.62558327\\t 0.62558327\\nH\\t 0.62558327\\t 0.62558327\\t-0.62558327\\nH\\t-0.62558327\\t-0.62558327\\t-0.62558327\\n\");\n",
"\tviewer_17272701051676118.setStyle({\"stick\": {\"radius\": 0.1}});\n",
"viewer_17272701051676118.render();\n",
"});\n",
"</script>"
]
Expand All @@ -154,7 +171,7 @@
"Structure(atomic_number=i64[5](numpy), position=f64[5,3](numpy))"
]
},
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -347,7 +364,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
20 changes: 20 additions & 0 deletions mess/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mess.basis import Basis, basisset
from mess.structure import Structure
from mess.units import to_bohr
from mess.package_utils import requires_package


def to_pyscf(structure: Structure, basis_name: str = "sto-3g") -> "gto.Mole":
Expand All @@ -34,3 +35,22 @@ def from_pyscf(mol: "gto.Mole") -> Tuple[Structure, Basis]:
basis = basisset(structure, basis_name=mol.basis)

return structure, basis


@requires_package("pyquante2")
def from_pyquante(name: str) -> Structure:
"""Load molecular structure from pyquante2.geo.samples module
Args:
name (str): Possible names include ch4, c6h6, aspirin, caffeine, hmx, petn,
prozan, rdx, taxol, tylenol, viagara, zoloft
Returns:
Structure
"""
from pyquante2.geo import samples

pqmol = getattr(samples, name)
atomic_number, position = zip(*[(a.Z, a.r) for a in pqmol])
atomic_number, position = [np.asarray(x) for x in (atomic_number, position)]
return Structure(atomic_number, position)
103 changes: 103 additions & 0 deletions mess/package_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import importlib
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


class MissingOptionalDependencyError(BaseException):
"""
An exception raised when an optional dependency is required
but cannot be found.
Attributes
----------
library_name
The name of the missing library.
"""

def __init__(self, library_name: str):
"""
Parameters
----------
library_name
The name of the missing library.
license_issue
Whether the library was importable but was unusable due
to a missing license.
"""

message = f"The required {library_name} module could not be imported."

super(MissingOptionalDependencyError, self).__init__(message)

self.library_name = library_name


def has_package(package_name: str) -> bool:
"""
Helper function to generically check if a Python package is installed.
Intended to be used to check for optional dependencies.
Parameters
----------
package_name : str
The name of the Python package to check the availability of
Returns
-------
package_available : bool
Boolean indicator if the package is available or not
Examples
--------
>>> has_numpy = has_package('numpy')
>>> has_numpy
True
>>> has_foo = has_package('other_non_installed_package')
>>> has_foo
False
"""
try:
importlib.import_module(package_name)
except ModuleNotFoundError:
return False
return True


def requires_package(package_name: str) -> Callable[..., Any]:
"""
Helper function to denote that a funciton requires some optional
dependency. A function decorated with this decorator will raise
`MissingOptionalDependencyError` if the package is not found by
`importlib.import_module()`.
Parameters
----------
package_name : str
The name of the module to be imported.
Raises
------
MissingOptionalDependencyError
"""

def inner_decorator(function: F) -> F:
@wraps(function)
def wrapper(*args, **kwargs):
import importlib

try:
importlib.import_module(package_name)
except ImportError:
raise MissingOptionalDependencyError(library_name=package_name)
except Exception as e:
raise e

return function(*args, **kwargs)

return wrapper

return inner_decorator
18 changes: 18 additions & 0 deletions test/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,34 @@

from mess.basis import basisset
from mess.hamiltonian import Hamiltonian, minimise
from mess.zeropad_integrals import overlap_basis_zeropad
from mess.integrals import (
eri_basis_sparse,
kinetic_basis,
nuclear_basis,
overlap_basis,
)
from mess.structure import molecule
from mess.interop import from_pyquante
from mess.package_utils import has_package
from conftest import is_mem_limited


@pytest.mark.parametrize("func", [overlap_basis, overlap_basis_zeropad, kinetic_basis])
@pytest.mark.skipif(
not has_package("pyquante2"), reason="Missing Optional Dependency: pyquante2"
)
def test_benzene(func, benchmark):
mol = from_pyquante("c6h6")
basis = basisset(mol, "def2-TZVPPD")
basis = jax.device_put(basis)

def harness():
return func(basis).block_until_ready()

benchmark(harness)


@pytest.mark.parametrize("mol_name", ["h2", "water"])
@pytest.mark.parametrize(
"func", [overlap_basis, kinetic_basis, nuclear_basis, eri_basis_sparse]
Expand Down

0 comments on commit 33bf207

Please sign in to comment.