Skip to content

Commit

Permalink
Merge 34899ac into a937f62
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner authored Sep 17, 2024
2 parents a937f62 + 34899ac commit 0042c03
Show file tree
Hide file tree
Showing 18 changed files with 77 additions and 42 deletions.
4 changes: 2 additions & 2 deletions docs/src/tutorials/basic_julia_workflow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ result = acefit!(train_data, model; solver=solver, prior = P, weights=weights);
# We can display an error table as follows:

@info("Training Error Table")
err_train = ACEpotentials.linear_errors(train_data, model; weights=weights);
err_train = ACEpotentials.compute_errors(train_data, model; weights=weights);

# We should of course also look at test errors, which can be done as follows. Depending on the choice of solver, and solver parameters, the test errors might be very poor. Exploring different parameters in different applications can lead to significantly improved predictions.

@info("Test Error Table")
err_test = ACEpotentials.linear_errors(test_data, model; weights=weights);
err_test = ACEpotentials.compute_errors(test_data, model; weights=weights);

# If we want to save the fitted potentials to disk to later use we can simply save the hyperparameters and the parameters. At the moment this must be done manually but a more complete and convenient interface for this will be provided, also adding various sanity checks.

Expand Down
2 changes: 1 addition & 1 deletion docs/src/tutorials/smoothness_priors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ for (prior_name, P) in priors
set_parameters!(_modl, P \ c̃)

## compute errors and store them for later use (don't print them here)
errs = linear_errors(rawdata, model; verbose=false, datakeys...)
errs = compute_errors(rawdata, model; verbose=false, datakeys...)
rmse[prior_name] = errs["rmse"]["set"]["F"]
pots[prior_name] = _modl
println(" force=rmse = ", rmse[prior_name])
Expand Down
2 changes: 1 addition & 1 deletion docs/src_outdated/AtomsBase_interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ P = smoothness_prior(model; p = 4)
acefit!(model, data_train; solver=solver, weights=weights, prior = P);

@info("Training Error Table")
ACEpotentials.linear_errors(data_train, model; weights=weights);
ACEpotentials.compute_errors(data_train, model; weights=weights);
```

### Training data in AtomsBase structures
Expand Down
4 changes: 2 additions & 2 deletions docs/src_outdated/TiAl_basis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ test = [ACEpotentials.AtomsData(t; weights=weights, v_ref=Vref, datakeys...) for

@info("Test Error Tables")
@info("First Potential: ")
ACEpotentials.linear_errors(test, pot_1);
ACEpotentials.compute_errors(test, pot_1);

@info("Second Potential: ")
ACEpotentials.linear_errors(test, pot_2);
ACEpotentials.compute_errors(test, pot_2);

# If we want to save the fitted potentials to disk to later use we can use one of the following commands: the first saves the potential as an `ACE1.jl` compatible potential, while the second line exports it to a format that can be ready by the `pacemaker` code to be used within LAMMPS. This functionality is currently disabled.
#
Expand Down
4 changes: 2 additions & 2 deletions docs/src_outdated/first_example_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ acefit!(model, train; solver=solver, data_keys...);
# To see the training errors we can use

@info("Training Errors")
ACEpotentials.linear_errors(train, model; data_keys...);
ACEpotentials.compute_errors(train, model; data_keys...);

# ### Step 4: Run some tests
#
# At a minimum one should have a test set, check the errors on that test set, and confirm that they are similar as the training errors.

@info("Test Errors")
test = [gen_dat() for _=1:20]
ACEpotentials.linear_errors(test, model; data_keys...);
ACEpotentials.compute_errors(test, model; data_keys...);

# If we wanted to perform such a test ``manually'' it might look like this:

Expand Down
22 changes: 12 additions & 10 deletions examples/Tutorial/ACEpotentials-Tutorial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ Pkg.add(["LaTeXStrings", "MultivariateStats", "Plots", "PrettyTables",
"Suppressor", "ExtXYZ", "Unitful", "Distributed", "AtomsCalculators",
])

## ACEpotentials installation
using Pkg
Pkg.activate(".")
## ACEpotentials installation:
## If ACEpotentials has not been installed yet, uncomment the following lines
## using Pkg; Pkg.activate(".")
## Add the ACE registry, which stores the ACEpotential package information
Pkg.Registry.add(RegistrySpec(url="https://github.com/ACEsuit/ACEregistry"))
Pkg.add("ACEpotentials")
## Pkg.Registry.add(RegistrySpec(url="https://github.com/ACEsuit/ACEregistry"))
## Pkg.add("ACEpotentials")

# We can check the status of the installed packages.

Expand Down Expand Up @@ -247,10 +247,10 @@ acefit!(Si_tiny_dataset, model;
solver=solver, data_keys...);

@info("Training Errors")
linear_errors(Si_tiny_dataset, model; data_keys...);
compute_errors(Si_tiny_dataset, model; data_keys...);

@info("Test Error")
linear_errors(Si_dataset, model; data_keys...);
compute_errors(Si_dataset, model; data_keys...);

# Export to LAMMPS is currently not supported. Earlier versions of
# `ACEpotentials` supported this via
Expand Down Expand Up @@ -376,6 +376,7 @@ assess_model(new_model, new_dataset)
# structures in total.

for i in 1:4
global new_dataset, new_model # declare these are global variables
@show i
new_dataset, new_model = augment(new_dataset, new_model; num=5);
end
Expand Down Expand Up @@ -424,7 +425,7 @@ model = ace1_model(elements = [:Ti, :Al],
# and it is fit in the same manner.

acefit!(tial_data[1:5:end], model);
linear_errors(tial_data[1:5:end], model);
compute_errors(tial_data[1:5:end], model);

# ## Part 6: Recreate data from the ACEpotentials.jl paper
#
Expand Down Expand Up @@ -455,6 +456,7 @@ totaldegree = [ 20, 16, 12 ] # small model: ~ 300 basis functions
errors = Dict("E" => Dict(), "F" => Dict())

for element in elements
local model # treat `model` as a variable local to the scope of `for`
## load the dataset
@info("---------- loading $(element) dataset ----------")
train, test, _ = ACEpotentials.example_dataset("Zuo20_$element")
Expand All @@ -464,7 +466,7 @@ for element in elements
## train the model
acefit!(train, model, solver = ACEfit.BLR(; factorization = :svd))
## compute and store errors
err = linear_errors(test, model)
err = compute_errors(test, model)
errors["E"][element] = err["mae"]["set"]["E"] * 1000
errors["F"][element] = err["mae"]["set"]["F"]
end
Expand Down Expand Up @@ -523,4 +525,4 @@ pretty_table(f_table; header = header)
# - Use an `ACEpotentials.jl` potential with ASE:
# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/python_ase/
# - Install LAMMPS with `ACEpotentials` patch:
# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/lammps/
# https://acesuit.github.io/ACEpotentials.jl/dev/tutorials/lammps/
4 changes: 2 additions & 2 deletions examples/zuobench/error_table.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ for sym in syms
acefit!(train, model_lge; solver=solver); GC.gc()

# compute and store errors for later visualisation
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err_sm = ACEpotentials.compute_errors(test, model_sm)
err_lge = ACEpotentials.compute_errors(test, model_lge)
err["sm" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm" ]["F"][sym] = err_sm["mae"]["set"]["F"]
err["lge"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
Expand Down
8 changes: 4 additions & 4 deletions examples/zuobench/error_table_svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ for sym in syms
acefit!(train, model_lge; solver=solver); GC.gc()

# compute and store errors for later visualisation
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err_sm = ACEpotentials.compute_errors(test, model_sm)
err_lge = ACEpotentials.compute_errors(test, model_lge)
err["sm_blr" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm_blr" ]["F"][sym] = err_sm["mae"]["set"]["F"]
err["lge_blr"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
Expand All @@ -50,8 +50,8 @@ for sym in syms
solver = ACEfit.TruncatedSVD() # truncation will be determined from validation set
acefit!(train1, model_sm; validation_set = val1, solver=solver); GC.gc()
acefit!(train1, model_lge; validation_set = val1, solver=solver); GC.gc()
err_sm = ACEpotentials.linear_errors(test, model_sm)
err_lge = ACEpotentials.linear_errors(test, model_lge)
err_sm = ACEpotentials.compute_errors(test, model_sm)
err_lge = ACEpotentials.compute_errors(test, model_lge)
err["sm_svd" ]["E"][sym] = err_sm["mae"]["set"]["E"] * 1000
err["sm_svd" ]["F"][sym] = err_sm["mae"]["set"]["F"]
err["lge_svd"]["E"][sym] = err_lge["mae"]["set"]["E"] * 1000
Expand Down
6 changes: 3 additions & 3 deletions examples/zuobench/zuo_asp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ pot_300 = fast_evaluator(model_300; aa_static = true)
pot_100 = fast_evaluator(model_100; aa_static = true)

@info("Evaluate errors on the test set")
err_100 = ACEpotentials.linear_errors(test_data, pot_100)
err_300 = ACEpotentials.linear_errors(test_data, pot_300)
err_1000 = ACEpotentials.linear_errors(test_data, pot_1000)
err_100 = ACEpotentials.compute_errors(test_data, pot_100)
err_300 = ACEpotentials.compute_errors(test_data, pot_300)
err_1000 = ACEpotentials.compute_errors(test_data, pot_1000)

##

Expand Down
Empty file added ext_tests/tutorial/README.md
Empty file.
29 changes: 29 additions & 0 deletions ext_tests/tutorial/run_tutorial.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Run this test using
# julia --project=. run_tutorial.jl
# if the command fails, then clean the folder using
# rm ACEpotentials-Tutorial.jl ACEpotentials-Tutorial.ipynb Project.toml Manifest.toml Si_dataset.xyz Si_tiny_tutorial.json

julia_cmd = Base.julia_cmd()
appath = abspath(joinpath(@__DIR__(), "..", ".."))
setuptutorial = """
begin
using Pkg;
Pkg.develop(; path = \"$appath\");
using ACEpotentials;
ACEpotentials.copy_tutorial();
end
"""

run(`$julia_cmd --project=. -e $setuptutorial`)

if !isfile("ACEpotentials-Tutorial.ipynb")
error("Tutorial notebook not installed.")
end

tutorial_file = joinpath(appath, "examples", "Tutorial", "ACEpotentials-Tutorial.jl")
cp(tutorial_file, joinpath(pwd(), "ACEpotentials-Tutorial.jl"); force=true)

run(`$julia_cmd --project=. ACEpotentials-Tutorial.jl`)

@info("Cleaning up")
run(`rm ACEpotentials-Tutorial.jl ACEpotentials-Tutorial.ipynb Project.toml Manifest.toml Si_dataset.xyz Si_tiny_tutorial.json`)
4 changes: 2 additions & 2 deletions scripts/runfit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ OD = args_dict["output"]
if OD["error_table"] || OD["scatter"]
@info("evaluating errors")
# training errors
err_train, train_evf = ACEpotentials.linear_errors(train, model; data_keys..., weights=weights, return_efv = true)
err_train, train_evf = ACEpotentials.compute_errors(train, model; data_keys..., weights=weights, return_efv = true)
err = Dict("train" => err_train)
if OD["scatter"]
D["train_evf"] = train_evf
Expand All @@ -77,7 +77,7 @@ if OD["error_table"] || OD["scatter"]
# test errors (if a test dataset exists)
if haskey(args_dict["data"], "test_file")
test = ExtXYZ.load(args_dict["data"]["test_file"])
err_test, test_evf = ACEpotentials.linear_errors(test, model; data_keys..., weights=weights, return_efv = true)
err_test, test_evf = ACEpotentials.compute_errors(test, model; data_keys..., weights=weights, return_efv = true)
err["test"] = err_test
if OD["scatter"]
D["test_evf"] = test_evf
Expand Down
7 changes: 6 additions & 1 deletion src/ACEpotentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ import ACEpotentials.Models: algebraic_smoothness_prior,
set_committee!
import JSON

@deprecate linear_errors compute_errors

export ace1_model,
length_basis,
algebraic_smoothness_prior,
Expand All @@ -54,7 +56,10 @@ export ace1_model,
set_parameters!,
fast_evaluator,
@committee,
set_committee!
set_committee!,
compute_errors,
linear_errors


include("json_interface.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/atoms_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ function group_type(d::AtomsData; group_key="config_type")
end


function linear_errors(data::AbstractArray{AtomsData}, model;
function compute_errors(data::AbstractArray{AtomsData}, model;
group_key="config_type", verbose=true,
return_efv = false
)
Expand Down
7 changes: 3 additions & 4 deletions src/fit_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ACEpotentials.Models: ACEPotential

import ACEfit: assemble

export acefit!, assemble, linear_errors
export acefit!, assemble, compute_errors


# ---------------- some utilities and defaults
Expand Down Expand Up @@ -185,8 +185,7 @@ end




function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
function compute_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
energy_key = "energy",
force_key = "force",
virial_key = "virial",
Expand All @@ -198,7 +197,7 @@ function linear_errors(raw_data::AbstractArray{<: AbstractSystem}, model;
virial_key = virial_key, weights = weights,
v_ref = nothing)
for at in raw_data ]
return linear_errors(data, model; verbose=verbose, return_efv = return_efv)
return compute_errors(data, model; verbose=verbose, return_efv = return_efv)
end


Expand Down
6 changes: 3 additions & 3 deletions src/outdated/atoms_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ _has_forces(data; force_key=:force, kwargs...) = hasatomkey(data, Symbol(force
_has_virial(data; virial_key=:virial, kwargs...) = haskey(data, Symbol(virial_key))


function linear_errors(data, model::ACE1x.ACE1Model; kwargs...)
return linear_errors(data, ACEmd.ACEpotential(model.potential.components); kwargs...)
function compute_errors(data, model::ACE1x.ACE1Model; kwargs...)
return compute_errors(data, ACEmd.ACEpotential(model.potential.components); kwargs...)
end


Expand All @@ -19,7 +19,7 @@ function ACEmd.ACEpotential(model::ACE1x.ACE1Model; kwargs...)
end


function linear_errors(
function compute_errors(
data,
model::ACEmd.ACEpotential;
group_key="config_type",
Expand Down
2 changes: 1 addition & 1 deletion test/atomsbase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ using Test
P = smoothness_prior(model; p = 4)

acefit!(model, data_train; solver=solver, weights=weights, prior = P, repulsion_restraint=true);
ce, err = ACEpotentials.linear_errors(data, model; weights=weights);
ce, err = ACEpotentials.compute_errors(data, model; weights=weights);
@test err["mae"]["F"] < 0.6
end
6 changes: 3 additions & 3 deletions test/test_silicon.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ acefit!(data, model;
weights = weights,
solver=ACEfit.QR())

err = ACEpotentials.linear_errors(data, model; data_keys..., weights=weights)
err = ACEpotentials.compute_errors(data, model; data_keys..., weights=weights)

test_rmse(err["rmse"], rmse_qr)

Expand All @@ -72,7 +72,7 @@ acefit!(data, model;

rmprocs(workers())

err_dist = ACEpotentials.linear_errors(data, model; data_keys..., weights=weights)
err_dist = ACEpotentials.compute_errors(data, model; data_keys..., weights=weights)
test_rmse(err_dist["rmse"], rmse_qr)

##
Expand All @@ -90,7 +90,7 @@ acefit!(data, model;
weights = weights,
solver = ACEfit.BLR())

err_blr = ACEpotentials.linear_errors(data, model; data_keys..., weights=weights)
err_blr = ACEpotentials.compute_errors(data, model; data_keys..., weights=weights)

test_rmse(err_blr["rmse"], rmse_blr)

Expand Down

0 comments on commit 0042c03

Please sign in to comment.