add some codes used in Pasha's project
ChenQianA committed Dec 2, 2024
1 parent a97eee4 commit e1ee2c5
Showing 60 changed files with 15,781 additions and 0 deletions.
46 changes: 46 additions & 0 deletions examples/H2O/Inspect.jl
@@ -0,0 +1,46 @@
# Quick, interactive inspection of serialized evaluation results
# (error dictionaries and predicted / ground-truth matrix dictionaries).
using Serialization
using JuLIP
using LinearAlgebra
using Statistics

# Error dictionary from the H_H2O_2_500K_rcut_10 run
data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/data_dict_test.jls"
data_dict = open(data_file, "r") do file
    deserialize(file)
end

# Convert the MAEs from Hartree to meV (1 Ha ≈ 27211.4 meV)
data_dict["MAE"]["all"] * 27211.4
data_dict["MAE"]["on"]["all"] * 27211.4
data_dict["MAE"]["off"]["all"] * 27211.4



# Matrix dictionary (predicted vs. ground truth) from the dm_H2O_2_500K_rcut_10 run
matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_500K_rcut_10/matrix_dict_test.jls"
matrix_dict = open(matrix_file, "r") do file
    deserialize(file)
end

# Element-wise differences and the mean Frobenius-norm error
matrix_dict["predicted"] .- matrix_dict["gt"]

mean(norm.(matrix_dict["predicted"] .- matrix_dict["gt"], 2))
# ≈ 0.13157953160181063



# 300 K matrix dictionary (path only; load it as above when needed)
matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_300K_rcut_10/matrix_dict_test.jls"


# Error dictionary from the d_max / r_cutoff hyperparameter search
data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/hyperparameter_searching_dm_r_cutoff/dyn-wd-500K_3/data_dict_d_max_14_r_cutoff.jls"
data_dict = open(data_file, "r") do file
    deserialize(file)
end





# Matrix dictionary from the dm_H2O_2_500K_rcut_14_n_512 run
matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_500K_rcut_14_n_512/matrix_dict_test.jls"
matrix_dict = open(matrix_file, "r") do file
    deserialize(file)
end
mean(norm.(matrix_dict["predicted"] .- matrix_dict["gt"], 2))
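
The factor 27211.4 used above is the Hartree-to-meV conversion, so the reported MAEs are in meV. A minimal sketch of the same inspection wrapped in a reusable helper, assuming only the data_dict layout used above (the helper name and the constant are illustrative, not part of ACEhamiltonians):

using Serialization

const HARTREE_TO_MEV = 27211.4  # 1 Hartree ≈ 27211.4 meV

# Load an error dictionary and report the total / on-site / off-site MAEs in meV.
function report_mae(data_file::String)
    data_dict = open(deserialize, data_file)
    println("total MAE    : ", data_dict["MAE"]["all"] * HARTREE_TO_MEV, " meV")
    println("on-site MAE  : ", data_dict["MAE"]["on"]["all"] * HARTREE_TO_MEV, " meV")
    println("off-site MAE : ", data_dict["MAE"]["off"]["all"] * HARTREE_TO_MEV, " meV")
end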

66 changes: 66 additions & 0 deletions examples/H2O/evaluation_Pasha_H.jl
@@ -0,0 +1,66 @@
# Evaluate a fitted model on several datasets and serialize the
# predicted / ground-truth matrices for later inspection.
using Distributed, SlurmClusterManager, SparseArrays
addprocs(28)
# addprocs(SlurmManager())
@everywhere begin
    using ACEhamiltonians, HDF5, Serialization
    using Statistics
    using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma
end

model_path = "./Result/H_12_light/smpl10_md_w12.bin"

w6_path = "./Data/Datasets/smpl10_md_w6.h5"
w12_path = "./Data/Datasets/smpl10_md_w12.h5"
w21_path = "./Data/Datasets/smpl10_md_w21.h5"
w101_path = "./Data/Datasets/smpl10_md_w101.h5"


function get_matrices_dict(model_path::String, data_path::String)

    model = deserialize(model_path)

    get_matrix = Dict( # Select an appropriate function to load the target matrix
        "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix,
        "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label]

    target_systems = h5open(data_path) do database keys(database) end

    atoms = h5open(data_path) do database
        [load_atoms(database[system]) for system in target_systems]
    end

    images = [cell_translations(a, model) for a in atoms]

    # Predict each system, drop the singleton third dimension and store sparsely
    predicted = [sparse(dropdims(pred, dims=3)) for pred in predict.(Ref(model), atoms, images)]

    ground_truth = h5open(data_path) do database
        [sparse(dropdims(get_matrix(database[system]), dims=3)) for system in target_systems]
    end

    data_dict = Dict{String, Dict}()

    for (system, pred, gt) in zip(target_systems, predicted, ground_truth)
        data_dict[system] = Dict("gt"=>gt, "pred"=>pred)
    end

    return data_dict

end


function evaluate_on_data(model_path::String, data_path::String)

    matrices_dict = get_matrices_dict(model_path, data_path)
    # e.g. ".../smpl10_md_w12.h5" is saved next to the model as "w12_dict.jls"
    dict_path = joinpath(dirname(model_path), split(split(basename(data_path), ".")[1], "_")[end]*"_dict.jls")
    open(dict_path, "w") do file
        serialize(file, matrices_dict)
    end

end


for data_path in [w6_path, w12_path, w21_path]
    evaluate_on_data(model_path, data_path)
end
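
A minimal sketch of reading one of the serialized dictionaries back and computing a mean per-system Frobenius-norm error, in the spirit of Inspect.jl; the file name below is the one evaluate_on_data would produce for w12_path, but treat it as illustrative:

using Serialization, LinearAlgebra, Statistics

d = open(deserialize, "./Result/H_12_light/w12_dict.jls")
errors = [norm(entry["pred"] - entry["gt"]) for entry in values(d)]
println("mean per-system ||pred - gt||_F: ", mean(errors))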
152 changes: 152 additions & 0 deletions examples/H2O/hyperparameter_searching_H_H2O_1.jl
@@ -0,0 +1,152 @@
# Cross-validated hyperparameter search over d_max for the H2O Hamiltonian model.
using BenchmarkTools, Serialization, Random
using Distributed, SlurmClusterManager
# addprocs(SlurmManager())
addprocs(28)
@everywhere begin
    using ACEhamiltonians, HDF5, Serialization
    using Statistics
    using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma
    using JuLIP: Atoms
    using PeriodicTable
    include("./utils.jl")
end

# -----------------------
# |***general setting***|
# -----------------------

database_path = "./Data/Datasets/H2O_H_aims.h5"
data_name = split(basename(database_path), ".")[1]
output_path = joinpath("./Result/hyperparameter_searching_H_H2O_1", data_name)
nsamples = 512
mkpath(output_path)

# ---------------------------
# |function to get the model|
# ---------------------------

nfolds = 5

model_type = "H"

# Azimuthal quantum numbers of the basis shells, keyed by atomic number (1 = H, 8 = O)
basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2])

function get_model(basis_definition::BasisDef, d_max::Int64=14, r_cutoff::Float64=6.,
    label::String="H", meta_data::Union{Dict, Nothing}=nothing)

    on_site_parameters = OnSiteParaSet(
        # Maximum correlation order
        GlobalParams(2),
        # Maximum polynomial degree
        GlobalParams(d_max),
        # Environmental cutoff radius
        GlobalParams(r_cutoff),
        # Scaling factor "r₀"
        GlobalParams(0.9)
    )

    # Off-site parameter declaration
    off_site_parameters = OffSiteParaSet(
        # Maximum correlation order
        GlobalParams(1),
        # Maximum polynomial degree
        GlobalParams(d_max),
        # Bond cutoff radius
        GlobalParams(r_cutoff),
        # Environmental cutoff radius
        GlobalParams(r_cutoff/2.0),
    )

    model = Model(basis_definition, on_site_parameters, off_site_parameters, label, meta_data)

    @info "finished building a model with d_max=$(d_max), r_cutoff=$(r_cutoff)"

    return model

end

# ---------------------------------
# |***cross validation function***|
# ---------------------------------
# single model evaluation
function evaluate_single(model::Model, database_path::String, train_systems::Vector{String}, test_systems::Vector{String})

    model = deepcopy(model)

    # Model fitting
    h5open(database_path) do database
        # Load the target system(s) for fitting
        systems = [database[system] for system in train_systems]

        # Perform the fitting operation
        fit!(model, systems; recentre=true)
    end

    # Prediction
    atoms = h5open(database_path) do database
        [load_atoms(database[system]) for system in test_systems]
    end
    images = cell_translations.(atoms, Ref(model))
    predicted = predict.(Ref(model), atoms, images)

    # Ground-truth data
    get_matrix = Dict( # Select an appropriate function to load the target matrix
        "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix,
        "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label]
    gt = h5open(database_path) do database
        [get_matrix(database[system]) for system in test_systems]
    end

    error = predicted - gt

    return error, atoms

end


function cross_validation(model::Model, database_path::String, nfolds::Int=5)

    # Shuffle the systems reproducibly, keep nsamples of them and split into nfolds folds
    target_systems = h5open(database_path) do database keys(database) end
    rng = MersenneTwister(1234)
    target_systems = shuffle(rng, target_systems)[begin:nsamples]
    target_systems = [target_systems[i:nfolds:end] for i in 1:nfolds]

    errors = []
    atomsv = []
    for fold in 1:nfolds
        train_systems = vcat(target_systems[1:fold-1]..., target_systems[fold+1:end]...)
        test_systems = target_systems[fold]
        error, atoms = evaluate_single(model, database_path, train_systems, test_systems)
        push!(errors, error)
        push!(atomsv, atoms)
    end
    errors = vcat(errors...)
    atomsv = vcat(atomsv...)

    data_dict = get_error_dict(errors, atomsv, model)

    return data_dict

end


# ---------------------
# |initialize the test|
# ---------------------
data_dict = Dict{Tuple, Dict}()
output_path_figs = joinpath(output_path, "figures")

for d_max in 6:14
    r_cutoff = 6.
    dict_name = (d_max, r_cutoff)
    model = get_model(basis_definition, d_max, r_cutoff, model_type)
    data_dict_sub = cross_validation(model, database_path, nfolds)
    data_dict[dict_name] = data_dict_sub
    # Checkpoint the accumulated results after every d_max value
    open(joinpath(output_path, "data_dict_d_max_r_6.jls"), "w") do file
        serialize(file, data_dict)
    end
    plot_hyperparams(data_dict, "d_max", output_path_figs)
    @info "finished testing a model with d_max=$(d_max), r_cutoff=$(r_cutoff)"
end
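
A minimal sketch of post-processing the accumulated dictionary, assuming each per-(d_max, r_cutoff) entry produced by get_error_dict exposes the same "MAE" layout inspected in Inspect.jl above (this layout is an assumption, not something the script guarantees):

using Serialization

results = open(deserialize, joinpath(output_path, "data_dict_d_max_r_6.jls"))
for (params, entry) in sort(collect(results), by=first)
    d_max, r_cutoff = params
    # Convert from Hartree to meV, as in Inspect.jl
    println("d_max=", d_max, ", r_cutoff=", r_cutoff,
            " => MAE = ", entry["MAE"]["all"] * 27211.4, " meV")
end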