This repository contains the code for the paper:
Amin, I., Raja, S., & Krishnapriyan, A. S. (2025). Towards Fast, Specialized Machine Learning Force Fields: Distilling Foundation Models via Energy Hessians.
Accepted to ICLR 2025. arXiv:2501.09009.
We built our implementation of Hessian distillation on top of the Fairchem repository.
The environment and NERSC training instructions were adapted from the EScAIP repository.
If you have any questions about the repo, feel free to email ishanthewizard@berkeley.edu.

(The environment setup below is taken from the EScAIP repository.)
Step 1: Install the mamba solver for conda (optional):

```bash
conda install mamba -n base -c conda-forge
```
Step 2: Check that CUDA is in `PATH` and `LD_LIBRARY_PATH`:

```bash
$ echo $PATH | tr ':' '\n' | grep cuda
/usr/local/cuda/bin
$ echo $LD_LIBRARY_PATH | tr ':' '\n' | grep cuda
/usr/local/cuda/lib64
```

If not, add something like the following (the exact paths depend on your CUDA installation) to your `.bashrc` or `.zshrc`:

```bash
export PATH="/usr/local/cuda/bin:$PATH"
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64
```
Step 3: Install the dependencies:

```bash
mamba env create -f env.yml
conda activate escaip
```
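Once the environment is active, an optional sanity check confirms that PyTorch can see CUDA (this only assumes the `escaip` environment provides PyTorch, which `env.yml` installs):

```python
import torch

# Quick check that the environment picked up CUDA correctly.
print("torch:", torch.__version__, "| built against CUDA:", torch.version.cuda)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))
```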
All of the specialized datasets used in the paper, as well as the Hessian and baseline labels used for distillation, can be found at this link: https://zenodo.org/records/14759305
Below we provide links to the repositories where the foundation model weights were obtained, as well as to the original, unspecialized training datasets that the foundation models were trained on.
Scripts to process the original data and generate the specialized subsets can be found in the `scripts` folder.
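If you want to peek at the downloaded data, the sketch below iterates over a Fairchem-style LMDB file. This assumes the Zenodo archives use the same single-file LMDB layout as Fairchem datasets (an assumption; adjust to whatever you extracted), and the path is a placeholder. Unpickling the samples requires the training environment above so the graph classes are importable.

```python
import lmdb
import pickle

# Placeholder path; point it at a file from the extracted Zenodo archive.
DATASET_PATH = "data/spice/solvated_amino_acids/train.lmdb"

env = lmdb.open(
    DATASET_PATH, subdir=False, readonly=True,
    lock=False, readahead=False, meminit=False,
)
with env.begin() as txn:
    for key, value in txn.cursor():
        sample = pickle.loads(value)   # typically a pickled graph/Data object
        print(key, type(sample))
        break
env.close()
```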
To perform a standard (non-distilled) training run of GemNet-dT on the Solvated Amino Acids subset of SPICE, execute:

```bash
python main.py --mode train --config-yml configs/SPICE/solvated_amino_acids/gemnet-dT-small.yml
```

For more information and options relating to the training command inputs, please see the Fairchem repository.
For a Hessian distillation version of the above training run, use the corresponding config in the `hessian` folder, located in the same directory as the original configuration. The distillation configuration specifies the attributes unique to distillation training and links back to the original (non-distilled) configuration for all remaining attributes:

```bash
python main.py --mode train --config-yml configs/SPICE/solvated_amino_acids/hessian/gemnet-dT-small.yml
```
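Conceptually, Hessian distillation adds a loss term that matches second derivatives of the student's energy, i.e. rows of its force Jacobian, against precomputed teacher labels. The sketch below is purely illustrative: a toy MLP stands in for the student and a random vector stands in for a loaded teacher label; it is not the repository's trainer code.

```python
import torch

# Illustrative only: one subsampled Hessian (force-Jacobian) row of a toy
# "student" enters an MSE loss against a placeholder "teacher" row.
class ToyStudent(torch.nn.Module):              # stand-in for a GNN force field
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
        )

    def forward(self, pos):                     # pos: (N, 3) -> scalar energy
        return self.net(pos).sum()

student = ToyStudent()
pos = torch.randn(4, 3, requires_grad=True)     # N = 4 atoms

energy = student(pos)
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0].reshape(-1)  # (3N,)

# Differentiate one force component w.r.t. all positions: one row of dF/dx,
# which is (up to sign) one row of the energy Hessian.
row_idx = int(torch.randint(0, forces.numel(), (1,)))
student_row = torch.autograd.grad(forces[row_idx], pos, create_graph=True)[0].reshape(-1)

teacher_row = torch.randn_like(student_row)     # placeholder for a loaded teacher label
loss = torch.nn.functional.mse_loss(student_row, teacher_row)
loss.backward()                                 # gradients reach the student's parameters
```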
Similarly, you can run some of the baselines from the paper by selecting a config from the `baselines` folder, which is located in the same directory as the undistilled configuration:

```bash
python main.py --mode train --config-yml configs/SPICE/solvated_amino_acids/baselines/gemnet-dT-small-n2n.yml
```
To generate the labels for MACE-OFF, use the script `scripts/spice_scripts/get_maceOFF_labels.py`. Set the `dataset_path` and `labels_folder` variables in the `main` function to the correct source and destination.
To generate the labels for MACE-MP-0, use the script `scripts/mptraj_scripts/get_maceMP0_labels.py`. Similarly, set the `dataset_path` and `labels_folder` variables in the `main` function appropriately.
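For a quick standalone check of the teachers themselves, both models are exposed as ASE calculators by the `mace-torch` package (`mace_off` and `mace_mp`). The snippet below only evaluates a toy molecule and is not a substitute for the label-generation scripts above, which handle the dataset and label formats the trainer expects.

```python
from ase.build import molecule
from mace.calculators import mace_off  # mace_mp gives the MACE-MP-0 teacher analogously

# Single-structure sanity check with the MACE-OFF foundation model.
atoms = molecule("H2O")                        # toy structure, not from the datasets
atoms.calc = mace_off(model="medium", device="cpu")

print("energy (eV):", atoms.get_potential_energy())
print("forces (eV/Angstrom):", atoms.get_forces())
```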
If you have a teacher checkpoint that is runnable in the Fairchem repository, you can generate Hessian labels with the following steps:

- Go to a Hessian configuration file.
- Set the `trainer` to `src.distill_trainer.LabelsTrainer`.
- Ensure your Hessian configuration includes the following structure:
```yaml
dataset:
  train:
    teacher_checkpoint_path: data/teacher_checkpoints/nanotube_jmp-l.ckpt
    teacher_labels_folder: data/labels/md22_labels/jmp-large_double-walled_nanotube/
    label_force_batch_size: 32
    label_jac_batch_size: 64
    vectorize_teach_jacs: False
```
- `teacher_checkpoint_path`: Path to your teacher checkpoint.
- `teacher_labels_folder`: Destination path for the generated labels.
- `label_force_batch_size`: Batch size for generating force labels.
- `label_jac_batch_size`: Batch size for generating Hessians (note: setting this too high may cause memory overflow).
- `vectorize_teach_jacs`: If set to `True`, Hessian generation is faster because it uses `vmap`, but there is a greater risk of memory overflow (see the sketch after this list).
Please see this file for an example of label generation; the label-generation settings are the ones that are commented out.
- Important: Ensure the dataset specified in your base configuration matches the dataset you plan to distill with (i.e., the dataset you want to generate labels for). Also ensure that the linked config with the model attributes is the teacher's config, not the student's.
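For intuition on why `label_jac_batch_size` and `vectorize_teach_jacs` are memory-sensitive: every row of the (3N, 3N) Hessian requires an extra backward pass through the teacher. The toy example below (a quadratic potential standing in for the teacher; nothing here is repository code) shows the basic autograd pattern, which vectorization with `vmap` would batch at a higher memory cost.

```python
import torch

# Toy illustration of Hessian-label generation: a quadratic potential stands
# in for the teacher model, and each Hessian row costs one backward pass.
def energy(pos):                                   # pos: (N, 3)
    return (pos ** 2).sum()

def hessian(pos):
    pos = pos.detach().requires_grad_(True)
    forces = -torch.autograd.grad(energy(pos), pos, create_graph=True)[0].reshape(-1)
    rows = []
    for i in range(forces.numel()):                # row-by-row: slow but memory-light
        (row,) = torch.autograd.grad(forces[i], pos, retain_graph=True)
        rows.append(row.reshape(-1))
    return -torch.stack(rows)                      # d2E/dx2 = -dF/dx, shape (3N, 3N)

H = hessian(torch.randn(4, 3))
print(H.shape)                                     # torch.Size([12, 12])
```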
We created the JMP labels from the JMP repository, essentially by copying over our `src/labels_trainer.py`.
For distributed training on NERSC, please see the NERSC Distributed Training README, taken from the EScAIP repository.
If you find this work useful, please consider citing the following:
```bibtex
@article{amin2025distilling,
  title={Towards Fast, Specialized Machine Learning Force Fields: Distilling Foundation Models via Energy Hessians},
  author={Amin, Ishan and Raja, Sanjeev and Krishnapriyan, A. S.},
  journal={International Conference on Learning Representations (ICLR)},
  year={2025},
  archivePrefix={arXiv},
  eprint={2501.09009},
}
```