Skip to content

Commit

Permalink
a
Browse files Browse the repository at this point in the history
  • Loading branch information
vxfung committed Mar 2, 2024
1 parent eb6c48c commit 2eb4335
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 42 deletions.
26 changes: 9 additions & 17 deletions configs/tensor_net_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,19 @@ task:

model:
name: tensor_net
load_model: False
save_model: True
model_path: "my_model.pth"
#model attributes
hidden_channels: 128
hidden_channels: 256
num_layers: 2
num_rbf: 32
num_rbf: 50
rbf_type: "expnorm"
trainable_rbf: False
trainable_rbf: True
activation: "silu"
cutoff_lower: 0
cutoff_upper: 4.5
max_num_neighbors: 64
max_z: 128
max_z: 100
equivariance_invariance_group: "O(3)"
static_shapes: True
check_errors: True
dtype: torch.float32
box_vecs: None
num_post_layers: 1
post_hidden_channels: 64
num_post_layers: 2
post_hidden_channels: 128
pool: "global_mean_pool"
aggr: "add"
pool_order: "early"
Expand Down Expand Up @@ -89,15 +81,15 @@ dataset:
# Whether the data has already been processed and a data.pt file is present from a previous run
processed: False
# Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path}
src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/2D_data_npj/raw/"
src: data/test_data/data_graph_scalar.json
#src: "/project/Rithwik/2D_data_npj/raw/"
#src: "/project/Rithwik/QM9/data.json"
# Path to target file within data_path
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/2D_data_npj/targets.csv"
target_path:
#target_path: "/project/Rithwik/2D_data_npj/targets.csv"
#target_path:
# Path to save processed data.pt file
pt_path: "/global/cfs/projectdirs/m3641/Rithwik/datasets/2D_data_npj/"
pt_path: data/
prediction_level: graph

transforms:
Expand Down
18 changes: 6 additions & 12 deletions configs/torchmd_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ task:

model:
name: torchmd_etEarly
load_model: False
save_model: True
model_path: "my_model.pth"
#model attributes
# model attributes
hidden_channels: 256
num_filters: 128
num_layers: 8
Expand All @@ -42,12 +39,9 @@ model:
num_heads: 8
distance_influence: "both"
neighbor_embedding: True
cutoff_lower: 0.0
cutoff_upper: 8.0
max_z: 100
max_num_neighbors: 32
aggr: "add"
num_post_layers: 3
num_post_layers: 1
post_hidden_channels: 64
pool: "global_mean_pool"
pool_order: "early"
Expand All @@ -63,7 +57,7 @@ model:
gradient: False

optim:
max_epochs: 300
max_epochs: 200
max_checkpoint_epochs: 0
lr: 0.001
# Either custom or from torch.nn.functional library. If from torch, loss_type is TorchLossWrapper
Expand All @@ -89,15 +83,15 @@ dataset:
# Whether the data has already been processed and a data.pt file is present from a previous run
processed: False
# Path to data files - this can either be in the form of a string denoting a single path or a dictionary of {train: train_path, val: val_path, test: test_path, predict: predict_path}
src: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/2D_data_npj/raw/"
src: data/test_data/data_graph_scalar.json
#src: "/project/Rithwik/2D_data_npj/raw/"
#src: "/project/Rithwik/QM9/data.json"
# Path to target file within data_path
target_path: "/global/cfs/projectdirs/m3641/Shared/Materials_datasets/2D_data_npj/targets.csv"
target_path:
#target_path: "/project/Rithwik/2D_data_npj/targets.csv"
#target_path:
# Path to save processed data.pt file
pt_path: "/global/cfs/projectdirs/m3641/Rithwik/datasets/2D_data_npj/"
pt_path: data/
prediction_level: graph

transforms:
Expand Down
21 changes: 8 additions & 13 deletions matdeeplearn/models/tensor_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,19 +128,14 @@ def __init__(
output_dim,
hidden_channels=128,
num_layers=2,
num_rbf=32,
num_rbf=50,
rbf_type="expnorm",
trainable_rbf=False,
trainable_rbf=True,
activation="silu",
cutoff_lower=0,
cutoff_upper=4.5,
max_num_neighbors=64,
max_z=128,
equivariance_invariance_group="O(3)",
static_shapes=True,
check_errors=True,
dtype=torch.float32,
box_vecs=None,
num_post_layers=1,
post_hidden_channels=64,
pool="global_mean_pool",
Expand Down Expand Up @@ -168,19 +163,19 @@ def __init__(
self.num_layers = num_layers
self.num_rbf = num_rbf
self.rbf_type = rbf_type
self.activation = activation
self.cutoff_lower = cutoff_lower
self.cutoff_upper = cutoff_upper
self.activation = activation
cutoff_lower = 0

act_class = act_class_mapping[activation]
self.distance_expansion = rbf_class_mapping[rbf_type](
cutoff_lower, cutoff_upper, num_rbf, trainable_rbf
cutoff_lower, self.cutoff_radius, num_rbf, trainable_rbf
)
self.tensor_embedding = TensorEmbedding(
hidden_channels,
num_rbf,
act_class,
cutoff_lower,
cutoff_upper,
self.cutoff_radius,
trainable_rbf,
max_z,
dtype,
Expand All @@ -195,7 +190,7 @@ def __init__(
hidden_channels,
act_class,
cutoff_lower,
cutoff_upper,
self.cutoff_radius,
equivariance_invariance_group,
dtype,
)
Expand Down

0 comments on commit 2eb4335

Please sign in to comment.