Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Setting the default directory for storing MGLDataset in the current working directory. #503

Merged
merged 80 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
fc94e79
improve TensorNet model coverage
kenko911 Jun 22, 2024
53357d7
Update pyproject.toml
kenko911 Jun 22, 2024
f65cdba
Improve the unit test for SO(3) equivarance in TensorNet class
kenko911 Jun 22, 2024
93574e0
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Jun 22, 2024
16abc38
improve SO3Net model class coverage and simplify TensorNet implementa…
kenko911 Jun 23, 2024
6aaad51
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jun 23, 2024
2798176
improve the coverage in MLP_norm class
kenko911 Jun 24, 2024
e24bbdd
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Jun 24, 2024
2101bb0
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jun 24, 2024
147d4f5
Better documentation for M3GNet potential training with stresses
kenko911 Jun 25, 2024
e144ae2
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jun 28, 2024
9b17963
Improve the implementation of three-body interactions
kenko911 Jul 3, 2024
dc6ed59
fixed black
kenko911 Jul 3, 2024
59ab134
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jul 5, 2024
522809c
Optimize the speed of _compute_3body class
kenko911 Jul 5, 2024
e290538
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jul 8, 2024
27c728f
type checking is added for scheduler
kenko911 Jul 8, 2024
355ee6b
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jul 14, 2024
3ac4f39
update M3GNet Potential training notebook for the demonstration of ob…
kenko911 Jul 14, 2024
992c17e
Downgrade sympy to avoid crash of SO3 operations
kenko911 Jul 14, 2024
193aaff
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jul 17, 2024
38df100
Smooth l1 loss function is added and united tests are improved
kenko911 Jul 17, 2024
ea64ed9
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Jul 20, 2024
40133c0
merge the method predict_structure and featurize_structure into a fun…
kenko911 Jul 21, 2024
6ebdb57
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jul 28, 2024
c469ef7
remove unnecessary else statement for training magmoms
kenko911 Jul 28, 2024
6a3e736
Merge branch 'materialsvirtuallab:main' into main
kenko911 Aug 1, 2024
a4de37b
Merge branch 'materialsvirtuallab:main' into main
kenko911 Aug 7, 2024
bdb5ee0
Merge branch 'materialsvirtuallab:main' into main
kenko911 Aug 14, 2024
e5764b3
modify so3 operation implementation to make united tests pass due to …
kenko911 Aug 14, 2024
07d9bdd
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Aug 14, 2024
b57f0be
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Aug 14, 2024
ef0ce51
skip test_load_all_models for MacOS pytest now
kenko911 Aug 14, 2024
ae43929
Merge branch 'materialsvirtuallab:main' into main
kenko911 Aug 16, 2024
c355500
Merge branch 'materialsvirtuallab:main' into main
kenko911 Sep 4, 2024
7f34ed9
Reference for CHGNet is added
kenko911 Sep 4, 2024
fdae7a1
Update README.md and index.md for including CHGNet
kenko911 Sep 4, 2024
6438649
add more description for using CHGNet pretrained models in Relaxation…
kenko911 Sep 4, 2024
981b9c1
Merge branch 'materialsvirtuallab:main' into main
kenko911 Sep 4, 2024
984f9fa
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Sep 4, 2024
c00594b
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Sep 18, 2024
b4e34cd
A command-line interface for performing ASE MD simulations is added
kenko911 Sep 18, 2024
efe4118
added back py.typed
kenko911 Sep 18, 2024
a32bddc
Merge branch 'materialsvirtuallab:main' into main
kenko911 Sep 22, 2024
d2f1c43
ExpNormal Smearing for radial basis functions is added
kenko911 Sep 22, 2024
916a77c
Changed deprecated torch.scalar_tensor into torch.Tensor
kenko911 Sep 23, 2024
49ba5d8
Converted the float number into tensor
kenko911 Sep 23, 2024
cba6fd6
Upgrade torch to 2.4.0 in pyproject.toml
kenko911 Oct 6, 2024
8b63fdc
fix the united test in test_bond.py
kenko911 Oct 6, 2024
cf9fc6a
Merge branch 'main' into main
shyuep Oct 10, 2024
a20fa10
Bump boto3 from 1.35.38 to 1.35.39
dependabot[bot] Oct 14, 2024
2c1a096
Merge branch 'materialsvirtuallab:main' into main
kenko911 Oct 14, 2024
7209f48
fix the error from the upgrade of boto3
kenko911 Oct 14, 2024
10932c0
Downgrade DGL to 2.2.1
kenko911 Oct 14, 2024
104b53e
Downgrade pytorch
kenko911 Oct 14, 2024
ae09d16
fix mypy by adding self.norm_layers is not None
kenko911 Oct 14, 2024
9f9f43c
Downgrade Pytorch and DGL version (#392)
kenko911 Oct 14, 2024
cf79de7
correct the downgrade for torch
kenko911 Oct 14, 2024
3aa5ef9
Merge branch 'dependabot/pip/boto3-1.35.39' into main
kenko911 Oct 14, 2024
9c185ed
fix the version of torchdata
kenko911 Oct 14, 2024
e6502e8
merged the conflicts
kenko911 Oct 14, 2024
a997119
downgrade torch to 2.2.1
kenko911 Oct 14, 2024
a2c6b80
included symbreak for chgnet training to improve code coverage
kenko911 Oct 22, 2024
f0112a0
Merge branch 'materialsvirtuallab:main' into main
kenko911 Nov 8, 2024
9ece41b
Merge branch 'materialsvirtuallab:main' into main
kenko911 Nov 22, 2024
b30f5b9
fix ruff by sorting '__all__' in _so3.py
kenko911 Nov 22, 2024
296a834
remove Bussi for MD now
kenko911 Nov 22, 2024
189a330
remove Bussi for now
kenko911 Nov 22, 2024
33c156d
Remove trailing whitespace
kenko911 Nov 22, 2024
3f528e2
merge the changes
kenko911 Dec 17, 2024
cdac22b
Merge branch 'main' of https://github.com/kenko911/matgl
kenko911 Dec 17, 2024
20290e1
adjust the proper version of torch and numpy for MatGL
kenko911 Dec 17, 2024
f17cd22
merge the changes
kenko911 Dec 30, 2024
ae75797
NVT bussi ensemble is added
kenko911 Dec 30, 2024
e0fa76a
black again
kenko911 Dec 30, 2024
0cf486c
fix ruff
kenko911 Dec 30, 2024
c1d6100
Merge branch 'materialsvirtuallab:main' into main
kenko911 Jan 23, 2025
52d8015
Set the default directory of storing MGLDataset to current working di…
kenko911 Jan 23, 2025
91156cb
docstring for use_voigt is added
kenko911 Jan 23, 2025
86c17fb
Fix ruff
kenko911 Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@
"element_types = DEFAULT_ELEMENTS\n",
"converter = Structure2Graph(element_types=element_types, cutoff=5.0)\n",
"dataset = MGLDataset(\n",
" threebody_cutoff=4.0, structures=structures, converter=converter, labels=labels, include_line_graph=True\n",
" threebody_cutoff=4.0, \n",
" structures=structures, \n",
" converter=converter, \n",
" labels=labels, \n",
" include_line_graph=True,\n",
")\n",
"train_data, val_data, test_data = split_dataset(\n",
" dataset,\n",
Expand Down
11 changes: 10 additions & 1 deletion src/matgl/ext/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ase.md.nvtberendsen import NVTBerendsen
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase.stress import full_3x3_to_voigt_6_stress
from pymatgen.core.structure import Molecule, Structure
from pymatgen.io.ase import AseAtomsAdaptor
from pymatgen.optimization.neighbors import find_points_in_spheres
Expand Down Expand Up @@ -133,6 +134,7 @@ def __init__(
potential: Potential,
state_attr: torch.Tensor | None = None,
stress_weight: float = 1.0,
use_voigt: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add documentation for the use_voigt parameter.

The docstring is missing the description for the use_voigt parameter.

         stress_weight (float): conversion factor from GPa to eV/A^3, if it is set to 1.0, the unit is in GPa
+        use_voigt (bool): Whether to convert stress tensor to Voigt notation format.
+            Default: False
         **kwargs: Kwargs pass through to super().__init__().

Committable suggestion skipped: line range outside the PR's diff.

**kwargs,
):
"""
Expand All @@ -143,6 +145,7 @@ def __init__(
state_attr (tensor): State attribute
compute_stress (bool): whether to calculate the stress
stress_weight (float): conversion factor from GPa to eV/A^3, if it is set to 1.0, the unit is in GPa
use_voigt (bool): whether the voigt notation is used for stress output
**kwargs: Kwargs pass through to super().__init__().
"""
super().__init__(**kwargs)
Expand All @@ -154,6 +157,7 @@ def __init__(
self.state_attr = state_attr
self.element_types = potential.model.element_types # type: ignore
self.cutoff = potential.model.cutoff
self.use_voigt = use_voigt

def calculate(
self,
Expand Down Expand Up @@ -186,7 +190,12 @@ def calculate(
forces=calc_result[1].detach().cpu().numpy(),
)
if self.compute_stress:
self.results.update(stress=calc_result[2].detach().cpu().numpy() * self.stress_weight)
stresses_np = (
full_3x3_to_voigt_6_stress(calc_result[2].detach().cpu().numpy())
if self.use_voigt
else calc_result[2].detach().cpu().numpy()
)
self.results.update(stress=stresses_np * self.stress_weight)
if self.compute_hessian:
self.results.update(hessian=calc_result[3].detach().cpu().numpy())
if self.compute_magmom:
Expand Down
12 changes: 6 additions & 6 deletions src/matgl/graph/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ def __init__(
directed_line_graph: bool = False,
structures: list | None = None,
labels: dict[str, list] | None = None,
name: str = "MGLDataset",
directory_name: str = "MGLDataset",
graph_labels: list[int | float] | None = None,
clear_processed: bool = False,
save_cache: bool = True,
raw_dir: str | None = None,
save_dir: str | None = None,
raw_dir: str = "./",
save_dir: str = "./",
):
"""
Args:
Expand All @@ -152,7 +152,7 @@ def __init__(
Default: False (for M3GNet)
structures: Pymatgen structure.
labels: targets, as a dict of {name: list of values}.
name: name of dataset.
directory_name: name of the generated directory that stores the dataset.
graph_labels: state attributes.
clear_processed: Whether to clear the stored structures after processing into graphs. Structures
are not really needed after the conversion to DGL graphs and can take a significant amount of memory.
Expand All @@ -161,7 +161,7 @@ def __init__(
Default: True
raw_dir : str specifying the directory that will store the downloaded data or the directory that already
stores the input data.
Default: ~/.dgl/
Default: current working directory
save_dir : directory to save the processed dataset. Default: same as raw_dir.
"""
self.filename = filename
Expand All @@ -180,7 +180,7 @@ def __init__(
self.graph_labels = graph_labels
self.clear_processed = clear_processed
self.save_cache = save_cache
super().__init__(name=name, raw_dir=raw_dir, save_dir=save_dir)
super().__init__(name=directory_name, raw_dir=raw_dir, save_dir=save_dir)

def has_cache(self) -> bool:
"""Check if the dgl_graph.bin exists or not."""
Expand Down
8 changes: 4 additions & 4 deletions tests/graph/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_megnet_dataset_with_graph_label_float(self, LiFePO4, BaNiO3):
labels={"label": label},
clear_processed=True,
graph_labels=graph_label,
name="MGLDataset_megnet",
directory_name="MGLDataset_megnet",
)
g1, lat1, state1, label1 = dataset_with_graph_label[0]
g2, lat2, state2, label2 = dataset_with_graph_label[1]
Expand All @@ -85,7 +85,7 @@ def test_load_megenet_dataset(self, LiFePO4, BaNiO3):
label = [-1.0, 2.0]
element_types = get_element_list(structures)
cry_graph = Structure2Graph(element_types=element_types, cutoff=4.0)
dataset = MGLDataset(name="MGLDataset_megnet")
dataset = MGLDataset(directory_name="MGLDataset_megnet")
g1, lat1, state1, label1 = dataset[0]
assert label1["label"] == label[0]
assert g1.num_edges() == cry_graph.get_graph(LiFePO4)[0].num_edges()
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_mgl_dataset(self, LiFePO4, BaNiO3):
include_line_graph=True,
labels={"energies": energies, "forces": forces, "stresses": stresses},
clear_processed=True,
name="MGLDataset_pes",
directory_name="MGLDataset_pes",
)
g1, lat1, l_g1, state1, pes1 = dataset[0]
g2, lat2, l_g2, state2, pes2 = dataset[1]
Expand All @@ -148,7 +148,7 @@ def test_load_mgl_dataset(self, LiFePO4, BaNiO3):
element_types = get_element_list(structures)
cry_graph = Structure2Graph(element_types=element_types, cutoff=4.0)
dataset = MGLDataset(
name="MGLDataset_pes",
directory_name="MGLDataset_pes",
include_line_graph=True,
)
dataset.load()
Expand Down
Loading