diff --git a/examples/H2O/Inspect.jl b/examples/H2O/Inspect.jl new file mode 100644 index 0000000..d115ab6 --- /dev/null +++ b/examples/H2O/Inspect.jl @@ -0,0 +1,46 @@ +using Serialization +using JuLIP +using LinearAlgebra +using Statistics + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/data_dict_test.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + +data_dict["MAE"]["all"]*27211.4 +data_dict["MAE"]["on"]["all"]*27211.4 +data_dict["MAE"]["off"]["all"]*27211.4 + + + +matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_500K_rcut_10/matrix_dict_test.jls" +data_dict = open(matrix_file, "r") do file + deserialize(file) +end + +data_dict["predicted"].-data_dict["gt"] + +mean(norm.(data_dict["predicted"].-data_dict["gt"], 2)) +0.13157953160181063 + + + +matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_300K_rcut_10/matrix_dict_test.jls" + + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/hyperparameter_searching_dm_r_cutoff/dyn-wd-500K_3/data_dict_d_max_14_r_cutoff.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + + + + + +matrix_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/dm_H2O_2_500K_rcut_14_n_512/matrix_dict_test.jls" +matrix_dict = open(matrix_file, "r") do file + deserialize(file) +end +mean(norm.(matrix_dict["predicted"].-matrix_dict["gt"], 2)) + diff --git a/examples/H2O/evaluation_Pasha_H.jl b/examples/H2O/evaluation_Pasha_H.jl new file mode 100644 index 0000000..cc1b34f --- /dev/null +++ b/examples/H2O/evaluation_Pasha_H.jl @@ -0,0 +1,66 @@ +using Distributed, SlurmClusterManager, SparseArrays +addprocs(28) +#addprocs(SlurmManager()) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma +end + +model_path = "./Result/H_12_light/smpl10_md_w12.bin" + +w6_path = "./Data/Datasets/smpl10_md_w6.h5" +w12_path = "./Data/Datasets/smpl10_md_w12.h5" +w21_path = "./Data/Datasets/smpl10_md_w21.h5" +w101_path = "./Data/Datasets/smpl10_md_w101.h5" + + +function get_matrices_dict(model_path::String, data_path::String) + + model = deserialize(model_path) + + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] + + target_systems = h5open(data_path) do database keys(database) end + + atoms = h5open(data_path) do database + [load_atoms(database[system]) for system in target_systems] + end + + images = [cell_translations(atoms, model) for atoms in atoms] + + predicted = [sparse(dropdims(pred, dims=3)) for pred in predict.(Ref(model), atoms, images)] + + groud_truth = h5open(data_path) do database + [sparse(dropdims(get_matrix(database[system]), dims=3)) for system in target_systems] + end + + data_dict = Dict{String, Dict}() + + for (system, pred, gt) in zip(target_systems, predicted, groud_truth) + + data_dict[system] = Dict("gt"=>gt, "pred"=>pred) + + end + + return data_dict + +end + + +function evaluate_on_data(model_path::String, data_path::String) + + matrices_dict = get_matrices_dict(model_path, data_path) + dict_path = joinpath(dirname(model_path), split(split(basename(data_path), ".")[1], "_")[end]*"_dict.jls") + open(dict_path, "w") do file + serialize(file, matrices_dict) + end + +end + + +for data_path in [w6_path, w12_path, w21_path] + evaluate_on_data(model_path, data_path) +end \ No newline at end of file diff --git a/examples/H2O/hyperparameter_searching_H_H2O_1.jl b/examples/H2O/hyperparameter_searching_H_H2O_1.jl new file mode 100644 index 0000000..d34ddc5 --- /dev/null +++ b/examples/H2O/hyperparameter_searching_H_H2O_1.jl @@ -0,0 +1,152 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +# addprocs(SlurmManager()) +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/H2O_H_aims.h5" +data_name = split(basename(database_path), ".")[1] +output_path = joinpath("./Result/hyperparameter_searching_H_H2O_1", data_name) +nsamples = 512 +mkpath(output_path) + +# --------------------------- +# |function to get the model| +# --------------------------- + +nfolds = 5 + +model_type = "H" + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +function get_model(basis_definition::BasisDef, d_max::Int64=14, r_cutoff::Float64=6., + label::String="H", meta_data::Union{Dict, Nothing}=nothing) + + on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(d_max), + # Environmental cutoff radius + GlobalParams(r_cutoff), + # Scaling factor "r₀" + GlobalParams(0.9) + ) + + # Off-site parameter deceleration + off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(d_max), + # Bond cutoff radius + GlobalParams(r_cutoff), + # Environmental cutoff radius + GlobalParams(r_cutoff/2.0), + ) + + model = Model(basis_definition, on_site_parameters, off_site_parameters, label, meta_data) + + return model + + @info "finished buiding a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" + +end + +# --------------------------------- +# |***cross validation function***| +# --------------------------------- +# single model evaluation +function evaluate_single(model::Model, database_path::String, train_systems::Vector{String}, test_systems::Vector{String}) + + model = deepcopy(model) + + #model fitting + h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) + end + + #prediction + atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end + images = cell_translations.(atoms, Ref(model)) + predicted = predict.(Ref(model), atoms, images) + + #groud truth data + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] + gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] + end + + error = predicted-gt + + return error, atoms + +end + + +function cross_validation(model::Model, database_path::String, nfolds::Int=5) + + + target_systems = h5open(database_path) do database keys(database) end + rng = MersenneTwister(1234) + target_systems = shuffle(rng, target_systems)[begin:nsamples] + target_systems = [target_systems[i:nfolds:end] for i in 1:nfolds] + + errors = [] + atomsv = [] + for fold in 1:nfolds # nfolds:nfolds # 1:nfolds + train_systems = vcat(target_systems[1:fold-1]..., target_systems[fold+1:end]...) + test_systems = target_systems[fold] + error, atoms = evaluate_single(model, database_path, train_systems, test_systems) + push!(errors, error) + push!(atomsv, atoms) + end + errors = vcat(errors...) + atomsv = vcat(atomsv...) + + data_dict = get_error_dict(errors, atomsv, model) + + return data_dict + +end + + +# --------------------- +# |initialize the test| +# --------------------- +data_dict = Dict{Tuple, Dict}() +output_path_figs = joinpath(output_path, "figures") + +for d_max in 6:14 + r_cutoff = 6. + dict_name = (d_max, r_cutoff) + model = get_model(basis_definition, d_max, r_cutoff, model_type) + data_dict_sub = cross_validation(model, database_path, nfolds) + data_dict[dict_name] = data_dict_sub + open(joinpath(output_path, "data_dict_d_max_r_6.jls"), "w") do file + serialize(file, data_dict) + end + plot_hyperparams(data_dict, "d_max", output_path_figs) + @info "finished testing a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" +end \ No newline at end of file diff --git a/examples/H2O/hyperparameter_searching_S_H2O_1.jl b/examples/H2O/hyperparameter_searching_S_H2O_1.jl new file mode 100644 index 0000000..52b7c06 --- /dev/null +++ b/examples/H2O/hyperparameter_searching_S_H2O_1.jl @@ -0,0 +1,152 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +# addprocs(SlurmManager()) +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/H2O_H_aims.h5" +data_name = split(basename(database_path), ".")[1] +output_path = joinpath("./Result/hyperparameter_searching_S_H2O_1", data_name*"_crossval") +nsamples = 512 +mkpath(output_path) + +# --------------------------- +# |function to get the model| +# --------------------------- + +nfolds = 5 + +model_type = "S" + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +function get_model(basis_definition::BasisDef, d_max::Int64=14, r_cutoff::Float64=6., + label::String="H", meta_data::Union{Dict, Nothing}=nothing) + + on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(1), # (2), + # Maximum polynomial degree + GlobalParams(d_max), + # Environmental cutoff radius + GlobalParams(r_cutoff), + # Scaling factor "r₀" + GlobalParams(0.9) + ) + + # Off-site parameter deceleration + off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(d_max), + # Bond cutoff radius + GlobalParams(r_cutoff), + # Environmental cutoff radius + GlobalParams(r_cutoff/2.0), + ) + + model = Model(basis_definition, on_site_parameters, off_site_parameters, label, meta_data) + + return model + + @info "finished buiding a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" + +end + +# --------------------------------- +# |***cross validation function***| +# --------------------------------- +# single model evaluation +function evaluate_single(model::Model, database_path::String, train_systems::Vector{String}, test_systems::Vector{String}) + + model = deepcopy(model) + + #model fitting + h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) + end + + #prediction + atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end + images = cell_translations.(atoms, Ref(model)) + predicted = predict.(Ref(model), atoms, images) + + #groud truth data + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] + gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] + end + + error = predicted-gt + + return error, atoms + +end + + +function cross_validation(model::Model, database_path::String, nfolds::Int=5) + + + target_systems = h5open(database_path) do database keys(database) end + rng = MersenneTwister(1234) + target_systems = shuffle(rng, target_systems)[begin:nsamples] + target_systems = [target_systems[i:nfolds:end] for i in 1:nfolds] + + errors = [] + atomsv = [] + for fold in 1:nfolds # nfolds:nfolds # 1:nfolds + train_systems = vcat(target_systems[1:fold-1]..., target_systems[fold+1:end]...) + test_systems = target_systems[fold] + error, atoms = evaluate_single(model, database_path, train_systems, test_systems) + push!(errors, error) + push!(atomsv, atoms) + end + errors = vcat(errors...) + atomsv = vcat(atomsv...) + + data_dict = get_error_dict(errors, atomsv, model) + + return data_dict + +end + + +# --------------------- +# |initialize the test| +# --------------------- +data_dict = Dict{Tuple, Dict}() +output_path_figs = joinpath(output_path, "figures") + +for d_max in 6:14 + r_cutoff = 6. + dict_name = (d_max, r_cutoff) + model = get_model(basis_definition, d_max, r_cutoff, model_type) + data_dict_sub = cross_validation(model, database_path, nfolds) + data_dict[dict_name] = data_dict_sub + open(joinpath(output_path, "data_dict_co_1_d_max_r_6.jls"), "w") do file + serialize(file, data_dict) + end + plot_hyperparams(data_dict, "d_max", output_path_figs) + @info "finished testing a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" +end \ No newline at end of file diff --git a/examples/H2O/hyperparameter_searching_dm_r0.jl b/examples/H2O/hyperparameter_searching_dm_r0.jl new file mode 100644 index 0000000..a3dace0 --- /dev/null +++ b/examples/H2O/hyperparameter_searching_dm_r0.jl @@ -0,0 +1,153 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +# addprocs(SlurmManager()) +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +data_name = split(basename(database_path), ".")[1] +output_path = joinpath("./Result/hyperparameter_searching_dm_r0", data_name) +nsamples = 128 +mkpath(output_path) + +# --------------------------- +# |function to get the model| +# --------------------------- + +nfolds = 5 + +model_type = "dm" + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +function get_model(basis_definition::BasisDef, d_max::Int64=14, r_cutoff::Float64=6., r₀::Float64=0.9, + label::String="H", meta_data::Union{Dict, Nothing}=nothing) + + on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(d_max), + # Environmental cutoff radius + GlobalParams(r_cutoff), + # Scaling factor "r₀" + GlobalParams(r₀) + ) + + # Off-site parameter deceleration + off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(d_max), + # Bond cutoff radius + GlobalParams(r_cutoff), + # Environmental cutoff radius + GlobalParams(r_cutoff/2.0), + ) + + model = Model(basis_definition, on_site_parameters, off_site_parameters, label, meta_data) + + return model + + @info "finished buiding a model with d_max=$(d_max), r_cutoff=$(r_cutoff), r₀=$(r₀)" + +end + +# --------------------------------- +# |***cross validation function***| +# --------------------------------- +# single model evaluation +function evaluate_single(model::Model, database_path::String, train_systems::Vector{String}, test_systems::Vector{String}) + + model = deepcopy(model) + + #model fitting + h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) + end + + #prediction + atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end + images = cell_translations.(atoms, Ref(model)) + predicted = predict.(Ref(model), atoms, images) + + #groud truth data + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] + gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] + end + + error = predicted-gt + + return error, atoms + +end + + +function cross_validation(model::Model, database_path::String, nfolds::Int=5) + + + target_systems = h5open(database_path) do database keys(database) end + rng = MersenneTwister(1234) + target_systems = shuffle(rng, target_systems)[begin:nsamples] + target_systems = [target_systems[i:nfolds:end] for i in 1:nfolds] + + errors = [] + atomsv = [] + for fold in nfolds:nfolds # 1:nfolds + train_systems = vcat(target_systems[1:fold-1]..., target_systems[fold+1:end]...) + test_systems = target_systems[fold] + error, atoms = evaluate_single(model, database_path, train_systems, test_systems) + push!(errors, error) + push!(atomsv, atoms) + end + errors = vcat(errors...) + atomsv = vcat(atomsv...) + + data_dict = get_error_dict(errors, atomsv, model) + + return data_dict + +end + + +# --------------------- +# |initialize the test| +# --------------------- +data_dict = Dict{Tuple, Dict}() +output_path_figs = joinpath(output_path, "figures") + +for r₀ in 0.9:0.3:3.1 + d_max = 14 + r_cutoff = 10. + dict_name = (d_max, r₀) + model = get_model(basis_definition, d_max, r_cutoff, r₀, model_type) + data_dict_sub = cross_validation(model, database_path, nfolds) + data_dict[dict_name] = data_dict_sub + open(joinpath(output_path, "data_dict_d_max_14_r_cutoff.jls"), "w") do file + serialize(file, data_dict) + end + plot_hyperparams(data_dict, "r_cut", output_path_figs) + @info "finished testing a model with d_max=$(d_max), r_cutoff=$(r_cutoff), r₀=$(r₀)" +end \ No newline at end of file diff --git a/examples/H2O/hyperparameter_searching_dm_r_cutoff.jl b/examples/H2O/hyperparameter_searching_dm_r_cutoff.jl new file mode 100644 index 0000000..87ce4d4 --- /dev/null +++ b/examples/H2O/hyperparameter_searching_dm_r_cutoff.jl @@ -0,0 +1,152 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +# addprocs(SlurmManager()) +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +data_name = split(basename(database_path), ".")[1] +output_path = joinpath("./Result/hyperparameter_searching_dm_r_cutoff", data_name) +nsamples = 128 +mkpath(output_path) + +# --------------------------- +# |function to get the model| +# --------------------------- + +nfolds = 5 + +model_type = "dm" + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +function get_model(basis_definition::BasisDef, d_max::Int64=14, r_cutoff::Float64=6., + label::String="H", meta_data::Union{Dict, Nothing}=nothing) + + on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(d_max), + # Environmental cutoff radius + GlobalParams(r_cutoff), + # Scaling factor "r₀" + GlobalParams(0.9) + ) + + # Off-site parameter deceleration + off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(d_max), + # Bond cutoff radius + GlobalParams(r_cutoff), + # Environmental cutoff radius + GlobalParams(r_cutoff/2.0), + ) + + model = Model(basis_definition, on_site_parameters, off_site_parameters, label, meta_data) + + return model + + @info "finished buiding a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" + +end + +# --------------------------------- +# |***cross validation function***| +# --------------------------------- +# single model evaluation +function evaluate_single(model::Model, database_path::String, train_systems::Vector{String}, test_systems::Vector{String}) + + model = deepcopy(model) + + #model fitting + h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) + end + + #prediction + atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end + images = cell_translations.(atoms, Ref(model)) + predicted = predict.(Ref(model), atoms, images) + + #groud truth data + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] + gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] + end + + error = predicted-gt + + return error, atoms + +end + + +function cross_validation(model::Model, database_path::String, nfolds::Int=5) + + + target_systems = h5open(database_path) do database keys(database) end + rng = MersenneTwister(1234) + target_systems = shuffle(rng, target_systems)[begin:nsamples] + target_systems = [target_systems[i:nfolds:end] for i in 1:nfolds] + + errors = [] + atomsv = [] + for fold in nfolds:nfolds # 1:nfolds + train_systems = vcat(target_systems[1:fold-1]..., target_systems[fold+1:end]...) + test_systems = target_systems[fold] + error, atoms = evaluate_single(model, database_path, train_systems, test_systems) + push!(errors, error) + push!(atomsv, atoms) + end + errors = vcat(errors...) + atomsv = vcat(atomsv...) + + data_dict = get_error_dict(errors, atomsv, model) + + return data_dict + +end + + +# --------------------- +# |initialize the test| +# --------------------- +data_dict = Dict{Tuple, Dict}() +output_path_figs = joinpath(output_path, "figures") + +for r_cutoff in 6.:2.:14. + d_max = 14 + dict_name = (d_max, r_cutoff) + model = get_model(basis_definition, d_max, r_cutoff, model_type) + data_dict_sub = cross_validation(model, database_path, nfolds) + data_dict[dict_name] = data_dict_sub + open(joinpath(output_path, "data_dict_d_max_14_r_cutoff.jls"), "w") do file + serialize(file, data_dict) + end + plot_hyperparams(data_dict, "r_cut", output_path_figs) + @info "finished testing a model with d_max=$(d_max), r_cutoff=$(r_cutoff)" +end \ No newline at end of file diff --git a/examples/H2O/intra_vs_inter.jl b/examples/H2O/intra_vs_inter.jl new file mode 100644 index 0000000..bfc9515 --- /dev/null +++ b/examples/H2O/intra_vs_inter.jl @@ -0,0 +1,150 @@ +using Serialization +using JuLIP +using LinearAlgebra: norm, pinv +using StatsBase +using Statistics +using Plots + +# model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/dyn-wd-500K_3.bin" + +# bond_cutoff = 1. +num_bond = 2 + +Hartree2meV=27211.4 + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) +n_1 = sum([2*i+1 for i in basis_definition[1]]) +n_8 = sum([2*i+1 for i in basis_definition[8]]) +basis_num = Dict(1=>n_1, 8=>n_8) +mol_basis_num = n_1*2 + n_8 + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_4_300K/matrix_dict.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + +predicts = data_dict["predicted"].*Hartree2meV +gts = data_dict["gt"].*Hartree2meV +atoms_list = data_dict["atoms"] + +# model = deserialize(model_path) +# images_list = cell_translations.(atoms_list, Ref(model)) +# pairs = [ for image in images for atom in atoms ] + +errors_list = [abs.(predict-gt) for (predict, gt) in zip(predicts, gts)] +error_intra = cat([cat(errors[1:mol_basis_num, 1:mol_basis_num], errors[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +error_inter = cat([cat(errors[1+mol_basis_num:end, 1:mol_basis_num], errors[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +mae_intra = mean(error_intra) +mae_inter = mean(error_inter) + +gts_intra = cat([cat(gt[1:mol_basis_num, 1:mol_basis_num], gt[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +gts_inter = cat([cat(gt[1+mol_basis_num:end, 1:mol_basis_num], gt[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +gt_intra = std(gts_intra) +gt_inter = std(gts_inter) + +mae_norm_intra = mae_intra/gt_intra +mae_norm_inter = mae_inter/gt_inter + +mae = mean(cat(errors_list..., dims=3)) +println("mae, mae_intra, mae_inter, mae_norm_intra, mae_norm_inter: $mae, $mae_intra, $mae_inter, $mae_norm_intra, $mae_norm_inter") + +mae_intra_plot = dropdims(mean(error_intra, dims=3), dims=3) +mae_inter_plot = dropdims(mean(error_inter, dims=3), dims=3) + +p=heatmap(mae_intra_plot, size=(800, 750), color=:jet, title="Intra molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae_intra_plot.png")) +display(p) + +p=heatmap(mae_inter_plot, size=(800, 750), color=:jet, title="Inter molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae_inter_plot.png")) +display(p) + + + +mae_norm_intra = dropdims(mean(error_intra, dims=3)./std(gts_intra, dims=3), dims=3) +mae_norm_inter = dropdims(mean(error_inter, dims=3)./std(gts_inter, dims=3), dims=3) + +p=heatmap(mae_norm_intra, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=13Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm_intra_plot.png")) +display(p) + +p=heatmap(mae_norm_inter, size=(800, 750), color=:jet, title="Normalized inter molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold), clims=(0, 1)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=10Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm_inter_plot.png")) +display(p) + + + + +# for (atoms, images, errors) in zip(atoms_list, images_list, errors_list) +# shift_vectors = collect(eachrow(images' * atoms.cell)) +# pos = hcat(collect.(atoms.X)...) +# dis = reshape(pos, size(pos, 1), size(pos, 2), 1) .- reshape(pos, size(pos, 1), 1, size(pos, 2)) +# dis = reshape(dis, (1, size(dis)...)) +# dis = dis .- reshape(hcat(collect.(shift_vectors)...)', length(shift_vectors), 3, 1, 1) +# dis = dropdims(mapslices(norm, dis; dims=2); dims=2) +# dis = permutedims(dis, (3,2,1)) +# for i_idx, (X₁, Z₁) in enumerate(atoms.X, atoms.Z) +# if Z₁ == 8 +# H_OO = extract_block(errors, i_idx, i_idx, findall([images[:,i]==[0, 0, 0] for i in range(1,size(images,2))])[1], basis_num, atoms) +# idx_line = partialsortperm(dis[i_idx, :, :][:], 1:3; rev=false)[2:end] +# j_idxes, image_idxes = rem.(idx_line, Ref(length(atoms.Z))), Int.(floor.(idx_line./Ref(length(atoms.Z)))) +# O_indices = i_idx +# H_indices = j_idxes +# i_idxes = [i_idx for i in 1: num_bond] +# for j_idx in j_idxes +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) +# push!(j_idx, i_idxes) +# push!(i_idx, j_idxes) +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) + + + + +# for (X₂, Z₂) in zip(atoms.X, atoms.Z) +# if Z₂==1 and + +# norm.( X₁ - X₂ + shift_vectors) + + + + + +# mask = norm.(atoms.X - atoms.X + shift_vectors[block_idxs[3, :]]) .<= distance + + + + + + + +# function extract_block(matrix::Array{Float64, 3}, i_idx::Int, j_idx::Int, image_idx::Int, basis_num::Dict, atoms::Atoms) +# idx_begin = vcat([1],cumsum([basis_num[i] for i in atoms.Z])[1:end-1].+1) +# idx_end = cumsum([basis_num[i] for i in atoms.Z]) +# return matrix[idx_begin[i_idx]: idx_end[i_idx], idx_begin[j_idx]: idx_end[j_idx], image_idx] +# end + + + + diff --git a/examples/H2O/monomer.jl b/examples/H2O/monomer.jl new file mode 100644 index 0000000..5e86cb3 --- /dev/null +++ b/examples/H2O/monomer.jl @@ -0,0 +1,180 @@ +using Serialization +using JuLIP +using LinearAlgebra: norm, pinv +using StatsBase +using Statistics +using Plots + +# model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/dyn-wd-500K_3.bin" + +# bond_cutoff = 1. +num_bond = 2 + +Hartree2meV=27211.4 + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) +n_1 = sum([2*i+1 for i in basis_definition[1]]) +n_8 = sum([2*i+1 for i in basis_definition[8]]) +basis_num = Dict(1=>n_1, 8=>n_8) +mol_basis_num = n_1*2 + n_8 + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_1_rcut_6/matrix_dict_test.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + +predicts = data_dict["predicted"].*Hartree2meV +gts = data_dict["gt"].*Hartree2meV +atoms_list = data_dict["atoms"] + +# model = deserialize(model_path) +# images_list = cell_translations.(atoms_list, Ref(model)) +# pairs = [ for image in images for atom in atoms ] + +# errors_list = [abs.(predict-gt) for (predict, gt) in zip(predicts, gts)] +# error_intra = cat([cat(errors[1:mol_basis_num, 1:mol_basis_num], errors[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +# error_inter = cat([cat(errors[1+mol_basis_num:end, 1:mol_basis_num], errors[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +# mae_intra = mean(error_intra) +# mae_inter = mean(error_inter) + +# gts_intra = cat([cat(gt[1:mol_basis_num, 1:mol_basis_num], gt[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +# gts_inter = cat([cat(gt[1+mol_basis_num:end, 1:mol_basis_num], gt[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +# gt_intra = std(gts_intra) +# gt_inter = std(gts_inter) + +errors = abs.(cat((gts.-predicts)..., dims=3)) # cat([abs.(predict-gt) for (predict, gt) in zip(predicts, gts)]..., dims=3) +gts = cat(gts..., dims=3) +mae = mean(errors) +gt = std(gts) +mae_norm = mae/gt + +mae = mean(errors) +println("mae, mae_norm: $mae, $mae_norm") + +# mae_norm_intra = mae_intra/gt_intra +# mae_norm_inter = mae_inter/gt_inter + +mae_plot = dropdims(mean(errors, dims=3), dims=3) +mae_norm_plot = dropdims(mean(errors, dims=3)./std(gts, dims=3), dims=3) + +p=heatmap(mae_plot, size=(800, 750), color=:jet, title="Monomer MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae.png")) +display(p) + +p=heatmap(mae_norm_plot, size=(800, 750), color=:jet, title="Normalized monomer MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=13Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm.png")) +display(p) + + + +# mae_intra_plot = dropdims(mean(error_intra, dims=3), dims=3) +# mae_inter_plot = dropdims(mean(error_inter, dims=3), dims=3) + +# p=heatmap(mae_intra_plot, size=(800, 750), color=:jet, title="Intra molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# savefig(joinpath(dirname(data_file), "mae_intra_plot.png")) +# display(p) + +# p=heatmap(mae_inter_plot, size=(800, 750), color=:jet, title="Inter molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# savefig(joinpath(dirname(data_file), "mae_inter_plot.png")) +# display(p) + + + +# mae_norm_intra = dropdims(mean(error_intra, dims=3)./std(gts_intra, dims=3), dims=3) +# mae_norm_inter = dropdims(mean(error_inter, dims=3)./std(gts_inter, dims=3), dims=3) + +# p=heatmap(mae_norm_intra, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# p1=plot(p, right_margin=13Plots.mm) +# savefig(p1, joinpath(dirname(data_file), "mae_norm_intra_plot.png")) +# display(p) + +# p=heatmap(mae_norm_inter, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold), clims=(0, 1)) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# p1=plot(p, right_margin=10Plots.mm) +# savefig(p1, joinpath(dirname(data_file), "mae_norm_inter_plot.png")) +# display(p) + + + + +# for (atoms, images, errors) in zip(atoms_list, images_list, errors_list) +# shift_vectors = collect(eachrow(images' * atoms.cell)) +# pos = hcat(collect.(atoms.X)...) +# dis = reshape(pos, size(pos, 1), size(pos, 2), 1) .- reshape(pos, size(pos, 1), 1, size(pos, 2)) +# dis = reshape(dis, (1, size(dis)...)) +# dis = dis .- reshape(hcat(collect.(shift_vectors)...)', length(shift_vectors), 3, 1, 1) +# dis = dropdims(mapslices(norm, dis; dims=2); dims=2) +# dis = permutedims(dis, (3,2,1)) +# for i_idx, (X₁, Z₁) in enumerate(atoms.X, atoms.Z) +# if Z₁ == 8 +# H_OO = extract_block(errors, i_idx, i_idx, findall([images[:,i]==[0, 0, 0] for i in range(1,size(images,2))])[1], basis_num, atoms) +# idx_line = partialsortperm(dis[i_idx, :, :][:], 1:3; rev=false)[2:end] +# j_idxes, image_idxes = rem.(idx_line, Ref(length(atoms.Z))), Int.(floor.(idx_line./Ref(length(atoms.Z)))) +# O_indices = i_idx +# H_indices = j_idxes +# i_idxes = [i_idx for i in 1: num_bond] +# for j_idx in j_idxes +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) +# push!(j_idx, i_idxes) +# push!(i_idx, j_idxes) +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) + + + + +# for (X₂, Z₂) in zip(atoms.X, atoms.Z) +# if Z₂==1 and + +# norm.( X₁ - X₂ + shift_vectors) + + + + + +# mask = norm.(atoms.X - atoms.X + shift_vectors[block_idxs[3, :]]) .<= distance + + + + + + + +# function extract_block(matrix::Array{Float64, 3}, i_idx::Int, j_idx::Int, image_idx::Int, basis_num::Dict, atoms::Atoms) +# idx_begin = vcat([1],cumsum([basis_num[i] for i in atoms.Z])[1:end-1].+1) +# idx_end = cumsum([basis_num[i] for i in atoms.Z]) +# return matrix[idx_begin[i_idx]: idx_end[i_idx], idx_begin[j_idx]: idx_end[j_idx], image_idx] +# end + + + + diff --git a/examples/H2O/output_indices.jl b/examples/H2O/output_indices.jl new file mode 100644 index 0000000..2092f90 --- /dev/null +++ b/examples/H2O/output_indices.jl @@ -0,0 +1,26 @@ +using HDF5, Random +using Serialization + +output_path = "./Result/output_indices" +indices_dict = Dict() + +for database_path in ["./Data/Datasets/H2O_H_aims.h5", "./Data/Datasets/dyn-wd-300K_3.h5", "./Data/Datasets/dyn-wd-500K_3.h5"] + + nsamples = 512 #5200 + # Names of the systems to which the model should be fitted + target_systems = h5open(database_path) do database keys(database) end + rng = MersenneTwister(1234) + @assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" + target_systems = shuffle(rng, target_systems)[begin:nsamples] + target_systems = [target_systems[i:5:end] for i in 1:5] + train_systems = vcat(target_systems[1:end-1]...) + test_systems = target_systems[end] + + data_name = split(basename(database_path), ".")[1] + indices_dict[data_name] = Dict("train"=>parse.(Ref(Int), train_systems).+1, "test"=>parse.(Ref(Int), test_systems).+1) + +end + +open(joinpath(output_path, "indices_dict.jls"), "w") do file + serialize(file, indices_dict) +end \ No newline at end of file diff --git a/examples/H2O/python_interface/large/script/function.jl b/examples/H2O/python_interface/large/script/function.jl new file mode 100644 index 0000000..1df963b --- /dev/null +++ b/examples/H2O/python_interface/large/script/function.jl @@ -0,0 +1,29 @@ +module J2P + +using Serialization +using Distributed +using JuLIP: Atoms +using LinearAlgebra +using Statistics +using PyCall + +addprocs(28) +@everywhere begin + using ACEhamiltonians + import ACEhamiltonians: predict +end + + +function predict(atoms::Vector{PyObject}, model::Model) + atoms = atoms_p2j(atoms) + images = cell_translations.(atoms, Ref(model)) + predicted = predict.(Ref(model), atoms, images) + return predicted +end + + +function atoms_p2j(atoms::Vector{PyObject}) + return [Atoms(Z=atom.get_atomic_numbers(), X=transpose(atom.positions), cell=collect(Float64.(I(3) * 100)), pbc=true) for atom in atoms] +end + +end diff --git a/examples/H2O/python_interface/large/script/py4ACE.py b/examples/H2O/python_interface/large/script/py4ACE.py new file mode 100644 index 0000000..c294a1f --- /dev/null +++ b/examples/H2O/python_interface/large/script/py4ACE.py @@ -0,0 +1,46 @@ + +############################configure the julia project and load the interface###################### +import os +import julia + +# Specify the path to your Julia project or environment +os.environ["JULIA_PROJECT"] = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA" +julia.install() +from julia.api import Julia +jl = Julia(compiled_modules=False) + +from julia import Main, Serialization, Base +Main.include("/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/data_interface/function.jl") + + + +##########################################load the model############################################# +model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_1_rcut_6/H2O_H_aims.bin" +model = Serialization.deserialize(model_path) + + +################################define a python format ase.atoms object############################### +from ase import Atoms +import numpy as np +O_H_distance = 0.96 +angle = 104.5 + +angle_rad = np.radians(angle / 2) + +x = O_H_distance * np.sin(angle_rad) +y = O_H_distance * np.cos(angle_rad) +z = 0.0 + +positions = [ + (0, 0, 0), # Oxygen + (x, y, 0), # Hydrogen 1 + (x, -y, 0) # Hydrogen 2 +] +water = Atoms('OH2', positions=positions) + + + +####################################################################################################### +predicted = Main.J2P.predict([water]*64, model) +predicted = [h for h in predicted] +np.save("/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/data_interface/predicted.npy", predicted) diff --git a/examples/H2O/python_interface/large/src/ACEhamiltonians.jl b/examples/H2O/python_interface/large/src/ACEhamiltonians.jl new file mode 100644 index 0000000..c4c3ec0 --- /dev/null +++ b/examples/H2O/python_interface/large/src/ACEhamiltonians.jl @@ -0,0 +1,95 @@ +module ACEhamiltonians + +using JuLIP, JSON, HDF5, Reexport, LinearAlgebra + +using ACE.SphericalHarmonics: SphericalCoords +import ACE.SphericalHarmonics: cart2spher + +export BasisDef, DUAL_BASIS_MODEL, BOND_ORIGIN_AT_MIDPOINT, SYMMETRY_FIX_ENABLED + +# Enabling this will activate the dual basis mode +const DUAL_BASIS_MODEL = true + +# If `true` then bond origin will be set to the bond's mid-point +const BOND_ORIGIN_AT_MIDPOINT = false + +# This can be used to enable/disable the symmetry fix code +const SYMMETRY_FIX_ENABLED = false + +if SYMMETRY_FIX_ENABLED && !BOND_ORIGIN_AT_MIDPOINT + @warn "Symmetry fix code is only viable when the bond origin is set to the midpoint " +end + + +if SYMMETRY_FIX_ENABLED && DUAL_BASIS_MODEL + @warn "It is ill advised to enable the symmetry fix when in dual basis mode" +end + +""" + BasisDef(atomic_number => [ℓ₁, ..., ℓᵢ], ...) + +Provides information about the basis set by specifying the azimuthal quantum numbers (ℓ) +of each shell on each species. Dictionary is keyed by atomic numbers & valued by vectors +of ℓs i.e. `Dict{atomic_number, [ℓ₁, ..., ℓᵢ]}`. + +A minimal basis set for hydrocarbon systems would be `BasisDef(1=>[0], 6=>[0, 0, 1])`. +This declares hydrogen atoms as having only a single s-shell and carbon atoms as having +two s-shells and one p-shell. +""" +BasisDef = Dict{I, Vector{I}} where I<:Integer + +@static if BOND_ORIGIN_AT_MIDPOINT + function cart2spher(r⃗::AbstractVector{T}) where T + @assert length(r⃗) == 3 + # When the length of the vector `r⃗` is zero then the signs can have a + # destabalising effect on the results; i.e. + # cart2spher([0., 0., 0.]) ≠ cart2spher([0. 0., -0.]) + # Hense the following catch: + if norm(r⃗) ≠ 0.0 + φ = atan(r⃗[2], r⃗[1]) + θ = atan(hypot(r⃗[1], r⃗[2]), r⃗[3]) + sinφ, cosφ = sincos(φ) + sinθ, cosθ = sincos(θ) + return SphericalCoords{T}(norm(r⃗), cosφ, sinφ, cosθ, sinθ) + else + return SphericalCoords{T}(0.0, 1.0, 0.0, 1.0, 0.0) + end + end +end + +include("common.jl") +@reexport using ACEhamiltonians.Common + +include("io.jl") +@reexport using ACEhamiltonians.DatabaseIO + +include("parameters.jl") +@reexport using ACEhamiltonians.Parameters + +include("data.jl") +@reexport using ACEhamiltonians.MatrixManipulation + +include("states.jl") +@reexport using ACEhamiltonians.States + +include("basis.jl") +@reexport using ACEhamiltonians.Bases + +include("models.jl") +@reexport using ACEhamiltonians.Models + +include("datastructs.jl") +@reexport using ACEhamiltonians.DataSets + +include("fitting.jl") +@reexport using ACEhamiltonians.Fitting + +include("predicting.jl") +@reexport using ACEhamiltonians.Predicting + +include("properties.jl") +@reexport using ACEhamiltonians.Properties + +include("api/dftbp_api.jl") + +end diff --git a/examples/H2O/python_interface/large/src/api/dftbp_api.jl b/examples/H2O/python_interface/large/src/api/dftbp_api.jl new file mode 100644 index 0000000..efbbb93 --- /dev/null +++ b/examples/H2O/python_interface/large/src/api/dftbp_api.jl @@ -0,0 +1,357 @@ +module DftbpApi +using ACEhamiltonians +using BlockArrays, StaticArrays, Serialization +using LinearAlgebra: norm, diagind +using ACEbase: read_dict, load_json + +using ACEhamiltonians.States: _inner_evaluate + + +export load_model, n_orbs_per_atom, offers_species, offers_species, species_name_to_id, max_interaction_cutoff, + max_environment_cutoff, shells_on_species!, n_shells_on_species, shell_occupancies!, build_on_site_atom_block!, + build_off_site_atom_block! + + + +# WARNING; THIS CODE IS NOT STABLE UNTIL THE SYMMETRY ISSUE HAS BEEN RESOLVED, DO NOT USE. + + +# Todo: +# - resolve unit mess. + +_Bohr2Angstrom = 1/0.188972598857892E+01 +_F64SV = SVector{3, Float64} +_FORCE_SHOW_ERROR = true + + +macro FSE(func) + # Force show error + if _FORCE_SHOW_ERROR + func.args[2] = quote + try + $(func.args[2].args...) + catch e + println("\nError encountered in Julia-DFTB+ API") + for (exc, bt) in current_exceptions() + showerror(stdout, exc, bt) + println(stdout) + println("Terminating....") + # Ensure streams are flushed prior to the `exit` call below + flush(stdout) + flush(stderr) + end + # The Julia thread must be explicitly terminated, otherwise the DFTB+ + # calculation will continue. + exit() + end + end + end + return func +end + + +_sub_block_sizes(species, basis_def) = 2basis_def[species] .+ 1 + +function _reshape_to_block(array, species_1, species_2, basis_definition) + return PseudoBlockArray( + reshape( + array, + number_of_orbitals(species_1, basis_definition), + number_of_orbitals(species_2, basis_definition) + ), + _sub_block_sizes(species_1, basis_definition), + _sub_block_sizes(species_2, basis_definition) + ) +end + +# Setup related function + +_s2n = Dict( + "H"=>1, "He"=>2, "Li"=>3, "Be"=>4, "B"=>5, "C"=>6, "N"=>7, "O"=>8, "F"=>9,"Ne"=>10, + "Na"=>11, "Mg"=>12, "Al"=>13, "Si"=>14, "P"=>15, "S"=>16, "Cl"=>17, "Ar"=>18, + "K"=>19, "Ca"=>20, "Sc"=>21, "Ti"=>22, "V"=>23, "Cr"=>24, "Mn"=>25, "Fe"=>26, + "Co"=>27, "Ni"=>28, "Cu"=>29, "Zn"=>30, "Ga"=>31, "Ge"=>32, "As"=>33, "Se"=>34, + "Br"=>35, "Kr"=>36, "Rb"=>37, "Sr"=>38, "Y"=>39, "Zr"=>40, "Nb"=>41, "Mo"=>42, + "Tc"=>43, "Ru"=>44, "Rh"=>45, "Pd"=>46, "Ag"=>47, "Cd"=>48, "In"=>49, "Sn"=>50, + "Sb"=>51, "Te"=>52, "I"=>53, "Xe"=>54, "Cs"=>55, "Ba"=>56, "La"=>57, "Ce"=>58, + "Pr"=>59, "Nd"=>60, "Pm"=>61, "Sm"=>62, "Eu"=>63, "Gd"=>64, "Tb"=>65, "Dy"=>66, + "Ho"=>67, "Er"=>68, "Tm"=>69, "Yb"=>70, "Lu"=>71, "Hf"=>72, "Ta"=>73, "W"=>74, + "Re"=>75, "Os"=>76, "Ir"=>77, "Pt"=>78, "Au"=>79, "Hg"=>80, "Tl"=>81, "Pb"=>82, + "Bi"=>83, "Po"=>84, "At"=>85, "Rn"=>86, "Fr"=>87, "Ra"=>88, "Ac"=>89, "Th"=>90, + "Pa"=>91, "U"=>92, "Np"=>93, "Pu"=>94, "Am"=>95, "Cm"=>96, "Bk"=>97, "Cf"=>98, + "Es"=>99, "Fm"=>10, "Md"=>10, "No"=>10, "Lr"=>10, "Rf"=>10, "Db"=>10, "Sg"=>10, + "Bh"=>10, "Hs"=>10, "Mt"=>10, "Ds"=>11, "Rg"=>111) + + +"""Return the unique id number associated with species of a given name.""" +@FSE function species_name_to_id(name, model) + # The current implementation of this function is only a temporary measure. + # Once atom names have been replaced with enums and the atomic number in "Basis.id" + # is replaced with a species id then this function will be able interrogate the + # model and its bases for information on what species it supports. Furthermore + # this will allow for multiple species of the same atomic number to be used. + return Int32(_s2n[name]) + +end + +"""Returns true of the supplied model supports the provided species""" +@FSE function offers_species(name, model) + if haskey(_s2n, name) + return _s2n[name] ∈ Set([id[1] for id in keys(model.on_site_bases)]) + else + return false + end +end + +@FSE function n_orbs_per_atom(species_id, model) + return Int32(sum(2model.basis_definition[species_id].+1)) +end + +"""Maximum environmental cutoff distance""" +@FSE function max_environment_cutoff(model) + max_on = maximum(values(model.on_site_parameters.e_cut_out)) + max_off = maximum(values(model.off_site_parameters.e_cut_out)) + distance = max(max_on, max_off) + return distance / _Bohr2Angstrom +end + + +"""Maximum interaction cutoff distance""" +@FSE function max_interaction_cutoff(model) + distance = maximum([env.r0cut for env in envelope.(values(model.off_site_bases))]) + return distance / _Bohr2Angstrom +end + + +@FSE function n_shells_on_species(species, model) + return Int32(length(model.basis_definition[species])) +end + +@FSE function shells_on_species!(array, species, model) + shells = model.basis_definition[species] + if length(array) ≠ length(shells) + println("shells_on_species!: Provided array is of incorrect length.") + throw(BoundsError("shells_on_species!: Provided array is of incorrect length.")) + end + array[:] = shells + nothing +end + +@FSE function shell_occupancies!(array, species, model) + if !haskey(model.meta_data, "occupancy") + throw(KeyError( + "shell_occupancies!: an \"occupancy\" key must be present in the model's + `meta_data` which provides the occupancies for each shell of each species." + )) + end + + occupancies = model.meta_data["occupancy"][species] + + if length(array) ≠ length(occupancies) + throw(BoundsError("shell_occupancies!: Provided array is of incorrect length.")) + end + + array[:] = occupancies + +end + + +@FSE function load_model(path::String) + if endswith(lowercase(path), ".json") + return read_dict(load_json(path)) + elseif endswith(lowercase(path), ".bin") + return deserialize(path) + else + error("Unknown file extension used; only \"json\" & \"bin\" are supported") + end +end + + +@FSE function _build_atom_state(coordinates, cutoff) + + # Todo: + # - The distance filter can likely be removed as atoms beyond the cutoff will be + # ignored by ACE. Tests will need to be performed to identify which is more + # performant; culling here or letting ACE handle the culling. + + # Build a list of static coordinate vectors, excluding the origin. + positions = map(_F64SV, eachcol(coordinates[:, 2:end])) + + # Exclude positions that lie outside of the cutoff allowed by the model. + positions_in_range = positions[norm.(positions) .≤ cutoff] + + # Construct the associated state object vector + return map(AtomState, positions_in_range) +end + +# Method for when bond origin is the midpoint +# function _build_bond_state(coordinates, envelope) +# # Build a list of static coordinate vectors, excluding the two bonding +# # atoms. +# positions = map(_F64SV, eachcol(coordinates[:, 3:end])) + +# # Coordinates must be rounded to prevent stability issues associated with +# # noise. This mostly only effects situations where atoms lie near the mid- +# # point of a bond. +# positions = [round.(i, digits=8) for i in positions] + + +# # The rest of this function copies code directly from `states.get_state`. +# # Here, rr0 is multiplied by two as vectors provided by DFTB+ point to +# # the midpoint of the bond. ACEhamiltonians expects the bond vector and +# # to be inverted, hence the second position is taken. +# rr0 = _F64SV(round.(coordinates[:, 2], digits=8) * 2) +# states = Vector{BondState{_F64SV, Bool}}(undef, length(positions) + 1) +# states[1] = BondState(_F64SV(round.(coordinates[:, 2], digits=8)), rr0, true) + +# for k=1:length(positions) +# states[k+1] = BondState{_F64SV, Bool}(positions[k], rr0, false) +# end + +# @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 +# @views n = sum(mask) + 1 +# @views states[2:n] = states[2:end][mask] +# return states[1:n] +# end + +# Method for when bond origin is the first atoms position +function _build_bond_state(coordinates, envelope) + # Build a list of static coordinate vectors, excluding the two bonding + # atoms. + positions = map(_F64SV, eachcol(coordinates[:, 3:end])) + + # The rest of this function copies code directly from `states.get_state`. + # Here, rr0 is multiplied by two as vectors provided by DFTB+ point to + # the midpoint of the bond. ACEhamiltonians expects the bond vector and + # to be inverted, hence the second position is taken. + rr0 = _F64SV(coordinates[:, 2] * 2) + offset = rr0 / 2 + states = Vector{BondState{_F64SV, Bool}}(undef, length(positions) + 1) + states[1] = BondState(_F64SV(coordinates[:, 2] * 2), rr0, true) + + for k=1:length(positions) + states[k+1] = BondState{_F64SV, Bool}(positions[k] + offset, rr0, false) + end + + @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 + @views n = sum(mask) + 1 + @views states[2:n] = states[2:end][mask] + return states[1:n] +end + + +function build_on_site_atom_block!(block::Vector{Float64}, coordinates::Vector{Float64}, species, model) + basis_def = model.basis_definition + n_shells = length(basis_def[species[1]]) + + # Unflatten the coordinates array + coordinates = reshape(coordinates, 3, :) * _Bohr2Angstrom + + # Unflatten the atom-block array and convert it into a PseudoBlockMatrix + block = _reshape_to_block(block, species[1], species[1], basis_def) + + # On-site atom block of the overlap matrix are just an identify matrix + if model.label == "S" + block .= 0.0 + block[diagind(block)] .= 1.0 + return nothing + end + + # Loop over all shell pairs + for i_shell=1:n_shells + for j_shell=i_shell:n_shells + + # Pull out the associated sub-block as a view + @views sub_block = block[Block(i_shell, j_shell)] + + # Select the appropriate model + basis = model.on_site_bases[(species[1], i_shell, j_shell)] + + # Construct the on-site state taking into account the required cutoff + state = _build_atom_state(coordinates, radial(basis).R.ru) + + # Make the prediction + predict!(sub_block, basis, state) + + # Set the symmetrically equivalent block when appropriate + if i_shell ≠ j_shell + @views block[Block(j_shell, i_shell)] = sub_block' + end + end + end + +end + + +@FSE function build_off_site_atom_block!(block::Vector{Float64}, coordinates::Vector{Float64}, species, model) + # Need to deal with situation where Z₁ > Z₂ + basis_def = model.basis_definition + species_i, species_j = species[1:2] + n_shells_i = number_of_shells(species_i, basis_def) + n_shells_j = number_of_shells(species_j, basis_def) + + # Unflatten the coordinates array + coordinates = reshape(coordinates, 3, :) * _Bohr2Angstrom + + # Unflatten the atom-block array and convert it into a PseudoBlockMatrix + block = _reshape_to_block(block, species_i, species_j, basis_def) + + # By default only interactions where species-i ≥ species-j are defined as + # adding interactions for species-i < species-j would be redundant. + if species_i > species_j + block = block' + n_shells_i, n_shells_j = n_shells_j, n_shells_i + reflect_state = true + else + reflect_state = false + end + + # Loop over all shell pairs + for i_shell=1:n_shells_i + for j_shell=1:n_shells_j + + # Skip over i_shell > j_shell homo-atomic interactions + species_i ≡ species_j && i_shell > j_shell && continue + + # Pull out the associated sub-block as a view + @views sub_block = block[Block(i_shell, j_shell)] + + # Select the appropriate model + basis = model.off_site_bases[(species[1], species[2], i_shell, j_shell)] + + # Construct the on-site state taking into account the required cutoff + state = _build_bond_state(coordinates, envelope(basis)) + + if reflect_state + state = reflect.(state) + end + + # Make the prediction + predict!(sub_block, basis, state) + + if species_i ≡ species_j + @views predict!(block[Block(j_shell, i_shell)]', basis, reflect.(state)) + end + + end + end + +end + + +# if model.label == "H" +# basis = model.off_site_bases[(species[1], species[2], 1, 1)] +# state = _build_bond_state(coordinates, envelope(basis)) +# dump("states.bin", state) +# dump("atom_blocks.bin", block) +# end + +# function dump(path, data::T) where T +# data_set = isfile(path) ? deserialize(path) : T[] +# append!(data_set, (data,)) +# serialize(path, data_set) +# nothing +# end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/large/src/basis.jl b/examples/H2O/python_interface/large/src/basis.jl new file mode 100644 index 0000000..553e859 --- /dev/null +++ b/examples/H2O/python_interface/large/src/basis.jl @@ -0,0 +1,579 @@ +module Bases + +using ACEhamiltonians, ACE, ACEbase, SparseArrays, LinearAlgebra, ACEatoms, JuLIP + +using ACEhamiltonians.Parameters: OnSiteParaSet, OffSiteParaSet +using ACE: SymmetricBasis, SphericalMatrix, Utils.RnYlm_1pbasis, SimpleSparseBasis, + CylindricalBondEnvelope, Categorical1pBasis, cutoff_radialbasis, cutoff_env, + get_spec, coco_dot + +using ACEbase.ObjectPools: VectorPool +using ACEhamiltonians: BOND_ORIGIN_AT_MIDPOINT, SYMMETRY_FIX_ENABLED + + +import ACEbase: read_dict, write_dict +import LinearAlgebra.adjoint, LinearAlgebra.transpose, Base./ +import Base +import ACE: SphericalMatrix + + +export AHSubModel, radial, angular, categorical, envelope, on_site_ace_basis, off_site_ace_basis, filter_offsite_be, is_fitted, SubModel, AnisoSubModel +""" +TODO: + - A warning should perhaps be given if no filter function is given when one is + expected; such as off-site functions. If no-filter function is desired than + a dummy filter should be required. + - Improve typing for the Model structure. + - Replace e_cutₒᵤₜ, e_cutᵢₙ, etc. with more "typeable" names. +""" +###################### +# SubModel Structure # +###################### +# Todo: +# - Document +# - Give type information +# - Serialization routines + + + +# ╔══════════╗ +# ║ SubModel ║ +# ╚══════════╝ + +abstract type AHSubModel end + + + +""" + + +A Linear ACE model for modelling symmetry invariant interactions. +In the context of Hamiltonian, this is just a (sub-)model for some specific +blocks and is hence called SubModel + +# Fields +- `basis::SymmetricBasis`: +- `id::Tuple`: +- `coefficients::Vector`: +- `mean::Matrix`: + +""" +struct SubModel{T₁<:SymmetricBasis, T₂, T₃, T₄} <: AHSubModel + basis::T₁ + id::T₂ + coefficients::T₃ + mean::T₄ + + function SubModel(basis, id) + t = ACE.valtype(basis) + F = real(t.parameters[5]) + SubModel(basis, id, zeros(F, length(basis)), zeros(F, size(zero(t)))) + end + + function SubModel(basis::T₁, id::T₂, coefficients::T₃, mean::T₄) where {T₁, T₂, T₃, T₄} + new{T₁, T₂, T₃, T₄}(basis, id, coefficients, mean) + end + +end + +""" + +Another linear ACE model for modelling symmetry variant interactions. + + +- `basis::SymmetricBasis`: +- `basis_i::SymmetricBasis`: +- `id::Tuple`: +- `coefficients::Vector`: +- `coefficients_i::Vector`: +- `mean::Matrix`: +- `mean_i::Matrix`: + +""" +struct AnisoSubModel{T₁<:SymmetricBasis, T₂<:SymmetricBasis, T₃, T₄, T₅, T₆, T₇} <: AHSubModel + basis::T₁ + basis_i::T₂ + id::T₃ + coefficients::T₄ + coefficients_i::T₅ + mean::T₆ + mean_i::T₇ + + function AnisoSubModel(basis, basis_i, id) + t₁, t₂ = ACE.valtype(basis), ACE.valtype(basis_i) + F = real(t₁.parameters[5]) + AnisoSubModel( + basis, basis_i, id, zeros(F, length(basis)), zeros(F, length(basis_i)), + zeros(F, size(zero(t₁))), zeros(F, size(zero(t₂)))) + end + + function AnisoSubModel(basis::T₁, basis_i::T₂, id::T₃, coefficients::T₄, coefficients_i::T₅, mean::T₆, mean_i::T₇) where {T₁, T₂, T₃, T₄, T₅, T₆, T₇} + new{T₁, T₂, T₃, T₄, T₅, T₆, T₇}(basis, basis_i, id, coefficients, coefficients_i, mean, mean_i) + end +end + +AHSubModel(basis, id) = SubModel(basis, id) +AHSubModel(basis, basis_i, id) = AnisoSubModel(basis, basis_i, id) + + +# ╭──────────┬───────────────────────╮ +# │ SubModel │ General Functionality │ +# ╰──────────┴───────────────────────╯ +"""Boolean indicating whether a `SubModel` instance is fitted; i.e. has non-zero coefficients""" +is_fitted(submodel::AHSubModel) = !all(submodel.coefficients .≈ 0.0) || !all(submodel.mean .≈ 0.0) + + +"""Check if two `SubModel` instances are equivalent""" +function Base.:(==)(x::T₁, y::T₂) where {T₁<:AHSubModel, T₂<:AHSubModel} + + # Check that the ID's, coefficients and means match up first + check = x.id == y.id && size(x.mean) == size(y.mean) && x.mean == y.mean + + # If they don't then return false. Otherwise perform a check of the basis object + # itself. A try/catch block must be used when comparing the bases as this can + # result in a DimensionMismatch. + if !check + return check + else + try + return x.basis == y.basis + catch y + if isa(y, DimensionMismatch) + return false + else + rethrow(y) + end + + end + end +end + + +"""Expected shape of the sub-block associated with the `SubModel`; 3×3 for a pp basis etc.""" +Base.size(submodel::AHSubModel) = (ACE.valtype(submodel.basis).parameters[3:4]...,) + +# """Expected type of resulting sub-blocks.""" +# Base.valtype(::Basis{T}) where T = T + +"""Expected type of resulting sub-blocks.""" +function Base.valtype(::AHSubModel) + throw("AHSubModel structure type has been changed this function must be updated.") +end + + +"""Azimuthal quantum numbers associated with the `SubModel`.""" +azimuthals(submodel::AHSubModel) = (ACE.valtype(submodel.basis).parameters[1:2]...,) + +"""Returns a boolean indicating if the submodel instance represents an on-site interaction.""" +Parameters.ison(x::AHSubModel) = length(x.id) ≡ 3 + + +""" + _filter_bases(submodel, type) + +Helper function to retrieve specific submodel function information out of a `AHSubModel` instance. +This is an internal function which is not expected to be used outside of this module. + +Arguments: +- `submodel::AHSubModel`: submodel instance from which function is to be extracted. +- `type::DataType`: type of the submodel functions to extract; e.g. `CylindricalBondEnvelope`. +""" +function _filter_bases(submodel::AHSubModel, T) + functions = filter(i->i isa T, submodel.basis.pibasis.basis1p.bases) + if length(functions) == 0 + error("Could not locate submodel function matching the supplied type") + elseif length(functions) ≥ 2 + @warn "Multiple matching submodel functions found, only the first will be returned" + end + return functions[1] +end + +"""Extract and return the radial component of a `AHSubModel` instance.""" +radial(submodel::AHSubModel) = _filter_bases(submodel, ACE.Rn1pBasis) + +"""Extract and return the angular component of a `AHSubModel` instance.""" +angular(submodel::AHSubModel) = _filter_bases(submodel, ACE.Ylm1pBasis) + +"""Extract and return the categorical component of a `AHSubModel` instance.""" +categorical(submodel::AHSubModel) = _filter_bases(submodel, ACE.Categorical1pBasis) + +"""Extract and return the bond envelope component of a `AHSubModel` instance.""" +envelope(submodel::AHSubModel) = _filter_bases(submodel, ACE.BondEnvelope) + +#TODO: add extract function for species1p basis, if needed + + +# ╭──────────┬──────────────────╮ +# │ SubModel │ IO Functionality │ +# ╰──────────┴──────────────────╯ +""" + write_dict(submodel[,hash_basis]) + +Convert a `SubModel` structure instance into a representative dictionary. + +# Arguments +- `submodel::SubModel`: the `SubModel` instance to parsed. +- `hash_basis::Bool`: ff `true` then hash values will be stored in place of + the `SymmetricBasis` objects. +""" +function write_dict(submodel::T, hash_basis=false) where T<:SubModel + return Dict( + "__id__"=>"SubModel", + "basis"=>hash_basis ? string(hash(submodel.basis)) : write_dict(submodel.basis), + "id"=>submodel.id, + "coefficients"=>write_dict(submodel.coefficients), + "mean"=>write_dict(submodel.mean)) + +end + + +"""Instantiate a `SubModel` instance from a representative dictionary.""" +function ACEbase.read_dict(::Val{:SubModel}, dict::Dict) + return SubModel( + read_dict(dict["basis"]), + Tuple(dict["id"]), + read_dict(dict["coefficients"]), + read_dict(dict["mean"])) +end + + +function Base.show(io::IO, submodel::T) where T<:AHSubModel + print(io, "$(nameof(T))(id: $(submodel.id), fitted: $(is_fitted(submodel)))") +end + + +# ╔════════════════════════╗ +# ║ ACE Basis Constructors ║ +# ╚════════════════════════╝ + +# Codes hacked to proceed to enable multispecies basis construction +ACE.get_spec(basis::Species1PBasis, i::Integer) = (μ = basis.zlist.list[i],) +ACE.get_spec(basis::Species1PBasis) = ACE.get_spec.(Ref(basis), 1:length(basis)) +# Base.length(a::Nothing) = 0 + +@doc raw""" + + on_site_ace_basis(ℓ₁, ℓ₂, ν, deg, e_cutₒᵤₜ[, r0]) + +Initialise a simple on-site `SymmetricBasis` instance with sensible default parameters. + +The on-site `SymmetricBasis` entities are produced by applying a `SimpleSparseBasis` +selector to a `Rn1pBasis` instance. The latter of which is initialised via the `Rn_basis` +method, using all the defaults associated therein except `e_cutₒᵤₜ` and `e_cutᵢₙ` which are +provided by this function. This facilitates quick construction of simple on-site `Bases` +instances; if more fine-grain control over over the initialisation process is required +then bases must be instantiated manually. + +# Arguments +- `(ℓ₁,ℓ₂)::Integer`: azimuthal numbers of the basis function. +- `ν::Integer`: maximum correlation order = body order - 1. +- `deg::Integer`: maximum polynomial degree. +- `e_cutₒᵤₜ::AbstractFloat`: only atoms within the specified cutoff radius will contribute + to the local environment. +- `r0::AbstractFloat`: scaling parameter (typically set to the nearest neighbour distances). +- `species::Union{Nothing, Vector{AtomicNumber}}`: A set of species of a system + +# Returns +- `basis::SymmetricBasis`: ACE basis entity for modelling the specified interaction. + +""" +function on_site_ace_basis(ℓ₁::I, ℓ₂::I, ν::I, deg::I, e_cutₒᵤₜ::F, r0::F=2.5; species::Union{Nothing, Vector{AtomicNumber}}=nothing + ) where {I<:Integer, F<:AbstractFloat} + # Build i) a matrix indicating the desired sub-block shape, ii) the one + # particle Rₙ·Yₗᵐ basis describing the environment, & iii) the basis selector. + # Then instantiate the SymmetricBasis required by the Basis structure. + if !isnothing(species) + return SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + Species1PBasis(species) * RnYlm_1pbasis(maxdeg=deg, r0=r0, rcut=e_cutₒᵤₜ), + SimpleSparseBasis(ν, deg)) + else + return SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + RnYlm_1pbasis(maxdeg=deg, r0=r0, rcut=e_cutₒᵤₜ), + SimpleSparseBasis(ν, deg)) + end +end + + + +function _off_site_ace_basis_no_sym(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + + + # Bond envelope which controls which atoms are seen by the bond. + @static if BOND_ORIGIN_AT_MIDPOINT + env = CylindricalBondEnvelope(b_cut, e_cutₒᵤₜ, e_cutₒᵤₜ, floppy=false, λ=0.0) + else + env = CylindricalBondEnvelope(b_cut, e_cutₒᵤₜ, e_cutₒᵤₜ, floppy=false, λ=0.5) + end + + # Categorical1pBasis is applied to the basis to allow atoms which are part of the + # bond to be treated differently to those that are just part of the environment. + discriminator = Categorical1pBasis([true, false]; varsym=:bond, idxsym=:bond) + + # The basis upon which the above entities act. Note that the internal cutoff "rin" must + # be set to 0.0 to allow for atoms a the bond's midpoint to be observed. + RnYlm = RnYlm_1pbasis( + maxdeg=deg, rcut=cutoff_env(env), + trans=IdTransform(), rin=0.0) + + B1p = RnYlm * env * discriminator + if !isnothing(species) + B1p = Species1PBasis(species) * B1p + end + + # Finally, construct and return the SymmetricBasis entity + basis = SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + B1p, SimpleSparseBasis(ν + 1, deg), + filterfun = indices -> _filter_offsite_be(indices, deg, λₙ, λₗ)) + + + return basis +end + + +function _off_site_ace_basis_sym(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + + # TODO: For now the symmetrised basis only works for single species case but it can be easily extended + @assert species == nothing || length(species) <= 1 + + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + + if ℓ₁ == ℓ₂ + Uᵢ = let A = get_spec(basis.pibasis) + U = adjoint.(basis.A2Bmap) * (_perm(A) * _dpar(A)) + (basis.A2Bmap + U) .* 0.5 + end + + # Purge system of linear dependence + svdC = svd(_build_G(Uᵢ)) + rk = rank(Diagonal(svdC.S), rtol = 1e-7) + Uⱼ = sparse(Diagonal(sqrt.(svdC.S[1:rk])) * svdC.U[:, 1:rk]') + U_new = Uⱼ * Uᵢ + + # construct symmetric offsite basis + basis = SymmetricBasis(basis.pibasis, U_new, basis.symgrp, basis.real) + + elseif ℓ₁ > ℓ₂ + U_new = let A = get_spec(basis.pibasis) + adjoint.(_off_site_ace_basis_no_sym(ℓ₂, ℓ₁, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ).A2Bmap) * _perm(A) * _dpar(A) + end + + basis = SymmetricBasis(basis.pibasis, U_new, basis.symgrp, basis.real) + end + + + return basis +end + +@doc raw""" + + off_site_ace_basis(ℓ₁, ℓ₂, ν, deg, b_cut[,e_cutₒᵤₜ, λₙ, λₗ]) + + +Initialise a simple off-site `SymmetricBasis` instance with sensible default parameters. + +Operates similarly to [`on_site_ace_basis`](@ref) but applies a `CylindricalBondEnvelope` to +the `Rn1pBasis` basis instance. The length and radius of the cylinder are defined as +maths: ``b_{cut}+2e_{cut\_out}`` and maths: ``e_{cut\_out}`` respectively; all other +parameters resolve to their defaults as defined by their constructors. Again, instances +must be manually instantiated if more fine-grained control is desired. + +# Arguments +- `(ℓ₁,ℓ₂)::Integer`: azimuthal numbers of the basis function. +- `ν::Integer`: maximum correlation order. +- `deg::Integer`: maximum polynomial degree. +- `b_cut::AbstractFloat`: cutoff distance for bonded interactions. +- `e_cutₒᵤₜ::AbstractFloat`: radius and axial-padding of the cylindrical bond envelope that + is used to determine which atoms impact to the bond's environment. +- `λₙ::AbstractFloat`: +- `λₗ::AbstractFloat`: +- `species::Union{Nothing, Vector{AtomicNumber}}`: A set of species of a system + +# Returns +- `basis::SymmetricBasis`: ACE basis entity for modelling the specified interaction. + +""" +function off_site_ace_basis(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, symfix=true, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + # WARNING symfix might cause issues when applied to interactions between different species. + # It is still not clear how appropriate this non homo-shell interactions. + @static if SYMMETRY_FIX_ENABLED + if symfix && ( isnothing(species) || length(species) <= 1) + basis = _off_site_ace_basis_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + else + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + end + else + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + end + + return basis + +end + + + + +""" + _filter_offsite_be(states, max_degree[, λ_n=0.5, λ_l=0.5]) + +Cap the the maximum polynomial components. + +This filter function should be passed, via the keyword `filterfun`, to `SymmetricBasis` +when instantiating. + + +# Arguments +- `indices::Tuple`: set of index that defines 1pbasis, e.g., (n,l,m,:be). +- `max_degree::Integer`: maximum polynomial degree. +- `λ_n::AbstractFloat`: +- `λ_l::AbstractFloat`: + +# Developers Notes +This function and its doc-string will be rewritten once its function and arguments have +been identified satisfactorily. + +# Examples +This is primarily intended to act as a filter function for off site bases like so: +``` +julia> off_site_sym_basis = SymmetricBasis( + φ, basis, selector, + filterfun = states -> filter_offsite_be(states, max_degree) +``` + +# Todo + - This should be inspected and documented. + - Refactoring may improve performance. + +""" +function _filter_offsite_be(indices, max_degree, λ_n=.5, λ_l=.5) + if length(indices) == 0; return false; end + deg_n, deg_l = ceil(Int, max_degree * λ_n), ceil(Int, max_degree * λ_l) + for idx in indices + if !idx.bond && (idx.n>deg_n || idx.l>deg_l) + return false + end + end + return sum(idx.bond for idx in indices) == 1 +end + + +adjoint(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val', Val{ℓ₂}(), Val{ℓ₁}()) +transpose(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(transpose(A.val), Val{ℓ₂}(), Val{ℓ₁}()) +/(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}, b::Number) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val / b, Val{ℓ₁}(), Val{ℓ₂}()) +Base.:(*)(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}, b::Number) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val * b, Val{ℓ₁}(), Val{ℓ₂}()) + + +function _dpar(A) + parity = spzeros(Int, length(A), length(A)) + + for i=1:length(A) + for j=1:length(A[i]) + if A[i][j].bond + parity[i,i] = (-1)^A[i][j].l + break + end + end + end + + return parity + end + +""" +This function that takes an arbitrary named tuple and returns an identical copy with the +value of its "m" field inverted, if present. +""" +@generated function _invert_m(i) + + filed_names = Meta.parse( + join( + [field_name ≠ :m ? "i.$field_name" : "-i.$field_name" + for field_name in fieldnames(i)], + ", ") + ) + + return quote + $i(($(filed_names))) + end +end + +function _perm(A::Vector{Vector{T}}) where T + # This function could benefit from some optimisation. However it is stable for now. + + # Dictionary mapping groups of A to their column index. + D = Dict(sort(j)=>i for (i,j) in enumerate(A)) + + # Ensure that there is no double mapping going on; as there is not clear way + # to handel such an occurrence. + @assert length(D) ≡ length(A) "Double mapping ambiguity present in \"A\"" + + # Sparse zero matrix to hold the results + P = spzeros(length(A), length(A)) + + # Track which systems have been evaluated ahead of their designated loop via + # symmetric equivalence. This done purely for the sake of speed. + done_by_symmetry = zeros(Int, length(A)) + + for (i, A_group) in enumerate(A) + # Skip over groups that have already been assigned during evaluation of their + # symmetric equivalent. + i in done_by_symmetry && continue + + # Construct the "m" inverted named tuple + U = sort([_invert_m(t) for t in A_group]) + + # Identify which column the conjugate group `U` is associated. + idx = D[U] + + # Update done_by_symmetry checklist to prevent evaluating the symmetrically + # equivalent group. (if it exists) + done_by_symmetry[i] = idx + + # Compute and assign the cumulative "m" parity term for the current group and + # its symmetric equivalent (if present.) + P[i, idx] = P[idx, i] = (-1)^sum(o->o.m, A_group) + + end + + return P +end + +function _build_G(U) + # This function is highly inefficient and would benefit from a rewrite. It should be + # possible to make use of sparsity information present in U ahead of time to Identify + # when, where, and how many non-sparse points there are. + n_rows = size(U, 1) + # A sparse matrix would be more appropriate here given the high degree of sparsity + # present in `G`. However, this matrix is to be passed into the LinearAlgebra.svd + # method which is currently not able to operate upon sparse arrays. Hence a dense + # array is used here. + # G = spzeros(valtype(U[1].val), n_rows, n_rows) + G = zeros(valtype(U[1].val), n_rows, n_rows) + for row=1:n_rows, col=row:n_rows + result = sum(coco_dot.(U[row, :], U[col, :])) + if !iszero(result) + G[row, col] = G[col, row] = result + end + end + return G +end + + + +""" +At the time of writing there is an oversite present in the SparseArrays module which +prevents it from correctly identifying if an operation is sparsity preserving. That +is to say, in many cases a sparse matrix will be converted into its dense form which +can have profound impacts on performance. This function exists correct this behaviour, +and should be removed once the fixes percolate through to the stable branch. +""" +function Base.:(==)(x::SphericalMatrix, y::Integer) + return !(any(real(x.val) .!= y) || any(imag(x.val) .!= y)) +end + + +end diff --git a/examples/H2O/python_interface/large/src/common.jl b/examples/H2O/python_interface/large/src/common.jl new file mode 100644 index 0000000..7bbe630 --- /dev/null +++ b/examples/H2O/python_interface/large/src/common.jl @@ -0,0 +1,107 @@ +module Common +using ACEhamiltonians +using JuLIP: Atoms +export parse_key, with_cache, number_of_orbitals, species_pairs, shell_pairs, number_of_shells, test_function + +# Converts strings into tuples of integers or integers as appropriate. This function +# should be refactored and moved to a more appropriate location. It is mostly a hack +# at the moment. +function parse_key(key) + if key isa Integer || key isa Tuple + return key + elseif '(' in key + return Tuple(parse.(Int, split(strip(key, ['(', ')', ' ']), ", "))) + else + return parse(Int, key) + end +end + + +"""Nᵒ of orbitals present in system `atoms` based on a given basis definition.""" +function number_of_orbitals(atoms::Atoms, basis_definition::BasisDef) + # Work out the number of orbitals on each species + n_orbs = Dict(k=>sum(v * 2 .+ 1) for (k, v) in basis_definition) + # Use this to get the sum of the number of orbitals on each atom + return sum(getindex.(Ref(n_orbs), getfield.(atoms.Z, :z))) +end + +"""Nᵒ of orbitals present on a specific element, `z`, based on a given basis definition.""" +number_of_orbitals(z::I, basis_definition::BasisDef) where I<:Integer = sum(basis_definition[z] * 2 .+ 1) + +"""Nᵒ of shells present on a specific element, `z`, based on a given basis definition.""" +number_of_shells(z::I, basis_definition::BasisDef) where I<:Integer = length(basis_definition[z]) + +""" +Todo: + - Document this function correctly. + +Returns a cached guarded version of a function that stores known argument-result pairs. +This reduces the overhead associated with making repeated calls to expensive functions. +It is important to note that results for identical inputs will be the same object. + +# Warnings +Do no use this function, it only supports a very specific use-case. + +Although similar to the Memoize packages this function's use case is quite different and +is not supported by Memoize; hence the reimplementation here. This is mostly a stop-gap +measure and will be refactored at a later data. +""" +function with_cache(func::Function)::Function + cache = Dict() + function cached_function(args...; kwargs...) + k = (args..., kwargs...) + if !haskey(cache, k) + cache[k] = func(args...; kwargs...) + end + return cache[k] + end + return cached_function +end + + +_triangular_number(n::I) where I<:Integer = n*(n + 1)÷2 + +function species_pairs(atoms::Atoms) + species = sort(unique(getfield.(atoms.Z, :z))) + n = length(species) + + pairs = Vector{NTuple{2, valtype(species)}}(undef, _triangular_number(n)) + + c = 0 + for i=1:n, j=i:n + c += 1 + pairs[c] = (species[i], species[j]) + end + + return pairs +end + + +function shell_pairs(species_1, species_2, basis_def) + n₁, n₂ = length(basis_def[species_1]), length(basis_def[species_2]) + c = 0 + + if species_1 ≡ species_2 + pairs = Vector{NTuple{2, Int}}(undef, _triangular_number(n₁)) + + for i=1:n₁, j=i:n₁ + c += 1 + pairs[c] = (i, j) + end + + return pairs + else + pairs = Vector{NTuple{2, Int}}(undef, n₁ * n₂) + + for i=1:n₁, j=1:n₂ + c += 1 + pairs[c] = (i, j) + end + + return pairs + + end +end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/large/src/data.jl b/examples/H2O/python_interface/large/src/data.jl new file mode 100644 index 0000000..e284668 --- /dev/null +++ b/examples/H2O/python_interface/large/src/data.jl @@ -0,0 +1,777 @@ +module MatrixManipulation + +using LinearAlgebra: norm, pinv +using ACEhamiltonians +using JuLIP: Atoms + +export BlkIdx, atomic_block_idxs, repeat_atomic_block_idxs, filter_on_site_idxs, + filter_off_site_idxs, filter_upper_idxs, filter_lower_idxs, get_sub_blocks, + filter_idxs_by_bond_distance, set_sub_blocks!, get_blocks, set_blocks!, + locate_and_get_sub_blocks + +# ╔═════════════════════╗ +# ║ Matrix Manipulation ║ +# ╚═════════════════════╝ + +# ╭─────────────────────┬────────╮ +# │ Matrix Manipulation │ BlkIdx │ +# ╰─────────────────────┴────────╯ +""" +An alias for `AbstractMatrix` used to signify a block index matrix. Given the frequency & +significance of block index matrices it became prudent to create an alias for it. This +helps to i) make it clear when and where a block index is used and ii) prevent having to +repeatedly explained what a block index matrix was each time one was used. + +As the name suggests, these are matrices which specifies the indices of a series of atomic +blocks. The 1ˢᵗ row specifies the atomic indices of the 1ˢᵗ atom in each block and the 2ⁿᵈ +row the indices of the 2ⁿᵈ atom. That is to say `block_index_matrix[:,i]` would yield the +atomic indices of the atoms associated with the iᵗʰ block listed in `block_index_matrix`. +The 3ʳᵈ row, if present, specifies the index of the cell in in which the second atom lies; +i.e. it indexes the cell translation vector list. + +For example; `BlkIdx([1 3; 2 4])` specifies two atomic blocks, the first being between +atoms 1&2, and the second between atoms 3&4; `BlkIdx([5; 6; 10])` represents the atomic +block between atoms 5&6, however in this case there is a third number 10 which give the +cell number. The cell number is can be used to help 3D real-space matrices or indicate +which cell translation vector should be applied. + +It is important to note that the majority of functions that take `BlkIdx` as an argument +will assume the first and second species are consistent across all atomic-blocks. +""" +BlkIdx = AbstractMatrix + + +# ╭─────────────────────┬─────────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Constructors │ +# ╰─────────────────────┴─────────────────────╯ +# These are the main methods by which `BlkIdx` instances are are constructed & expanded. + +""" + atomic_block_idxs(z_1, z_2, z_s[; order_invariant=false]) + +Construct a block index matrix listing all atomic-blocks present in the supplied system +where the first and second species are `z_1` and `z_2` respectively. + +# Arguments +- `z_1::Int`: first interacting species +- `z_2::Int`: second interacting species +- `z_s::Vector`: atomic numbers present in system +- `order_invariant::Bool`: by default, `block_idxs` only indexes atomic blocks in which the + 1ˢᵗ & 2ⁿᵈ species are `z_1` & `z_2` respectively. However, if `order_invariant` is enabled + then `block_idxs` will index all atomic blocks between the two species irrespective of + which comes first. + +# Returns +- `block_idxs::BlkIdx`: a 2×N matrix in which each column represents the index of an atomic block. + With `block_idxs[:, i]` yielding the atomic indices associated with the iᵗʰ atomic-block. + +# Notes +Enabling `order_invariant` is analogous to the following compound call: +`hcat(atomic_block_idxs(z_1, z_2, z_s), atomic_block_idxs(z_2, z_1, z_s))` +Furthermore, only indices associated with the origin cell are returned; if extra-cellular +blocks are required then `repeat_atomic_block_idxs` should be used. + +# Examples +``` +julia> atomic_numbers = [1, 1, 8] +julia> atomic_block_idxs(1, 8, atomic_numbers) +2×2 Matrix{Int64}: + 1 2 + 3 3 + +julia> atomic_block_idxs(8, 1, atomic_numbers) +2×2 Matrix{Int64}: + 3 3 + 1 2 + +julia> atomic_block_idxs(8, 1, atomic_numbers; order_invariant=true) +2×4 Matrix{Int64}: + 3 3 1 2 + 1 2 3 3 +``` +""" +function atomic_block_idxs(z_1::I, z_2::I, z_s::Vector; order_invariant::Bool=false) where I<:Integer + # This function uses views, slices and reshape operations to construct the block index + # list rather than an explicitly nested for-loop to reduce speed. + z_1_idx, z_2_idx = findall(==(z_1), z_s), findall(==(z_2), z_s) + n, m = length(z_1_idx), length(z_2_idx) + if z_1 ≠ z_2 && order_invariant + res = Matrix{I}(undef, 2, n * m * 2) + @views let res = res[:, 1:end ÷ 2] + @views reshape(res[1, :], (m, n)) .= z_1_idx' + @views reshape(res[2, :], (m, n)) .= z_2_idx + end + + @views let res = res[:, 1 + end ÷ 2:end] + @views reshape(res[1, :], (n, m)) .= z_2_idx' + @views reshape(res[2, :], (n, m)) .= z_1_idx + end + else + res = Matrix{I}(undef, 2, n * m) + @views reshape(res[1, :], (m, n)) .= z_1_idx' + @views reshape(res[2, :], (m, n)) .= z_2_idx + end + return res +end + +function atomic_block_idxs(z_1, z_2, z_s::Atoms; kwargs...) + return atomic_block_idxs(z_1, z_2, convert(Vector{Int}, z_s.Z); kwargs...) +end + + +""" + repeat_atomic_block_idxs(block_idxs, n) + +Repeat the atomic blocks indices `n` times and adds a new row specifying image number. +This is primarily intended to be used as a way to extend an atom block index list to +account for periodic images as they present in real-space matrices. + +# Arguments +- `block_idxs::BlkIdx`: the block indices which are to be expanded. + +# Returns +- `block_indxs_expanded::BlkIdx`: expanded block indices. + +# Examples +``` +julia> block_idxs = [10 10 20 20; 10 20 10 20] +julia> repeat_atomic_block_idxs(block_idxs, 2) +3×8 Matrix{Int64}: + 10 10 20 20 10 10 20 20 + 10 20 10 20 10 20 10 20 + 1 1 1 1 2 2 2 2 +``` +""" +function repeat_atomic_block_idxs(block_idxs::BlkIdx, n::T) where T<:Integer + @assert size(block_idxs, 1) != 3 "`block_idxs` has already been expanded." + m = size(block_idxs, 2) + res = Matrix{T}(undef, 3, m * n) + @views reshape(res[3, :], (m, n)) .= (1:n)' + @views reshape(res[1:2, :], (2, m, n)) .= block_idxs + return res +end + + +# ╭─────────────────────┬──────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Ancillary │ +# ╰─────────────────────┴──────────────────╯ +# Internal functions that operate on `BlkIdx`. + +""" + _block_starts(block_idxs, atoms, basis_def) + +This function takes a series of atomic block indices, `block_idxs`, and returns the index +of the first element in each atomic block. This is helpful when wanting to locate atomic +blocks in a Hamiltonian or overlap matrix associated with a given pair of atoms. + +# Arguments +- `block_idxs::BlkIdx`: block indices specifying the blocks whose starts are to be returned. +- `atoms::Atoms`: atoms object of the target system. +- `basis_def::BasisDef`: corresponding basis set definition. + +# Returns +- `block_starts::Matrix`: A copy of `block_idxs` where the first & second rows now provide an + index specifying where the associated block starts in the Hamiltonian/overlap matrix. The + third row, if present, is left unchanged. +""" +function _block_starts(block_idxs::BlkIdx, atoms::Atoms, basis_def::BasisDef) + n_orbs = Dict(k=>sum(2v .+ 1) for (k,v) in basis_def) # N∘ orbitals per species + n_orbs_per_atom = [n_orbs[z] for z in atoms.Z] # N∘ of orbitals on each atom + block_starts = copy(block_idxs) + @views block_starts[1:2, :] = ( + cumsum(n_orbs_per_atom) - n_orbs_per_atom .+ 1)[block_idxs[1:2, :]] + + return block_starts +end + +""" + _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) + + +Get the index of the first element of the sub-block formed between shells `s_i` and `s_j` +of species `z_1` and `z_2` respectively. The results of this method are commonly added to +those of `_block_starts` to give the first index of a desired sub-block in some arbitrary +Hamiltonian or overlap matrix. + +# Arguments +- `z_1::Int`: species on which shell `s_i` resides. +- `z_2::Int`: species on which shell `s_j` resides. +- `s_i::Int`: first shell of the sub-block. +- `s_j::Int`: second shell of the sub-block. +- `basis_def::BasisDef`: corresponding basis set definition. + +# Returns +- `sub_block_starts::Vector`: vector specifying the index in a `z_1`-`z_2` atom-block at which + the first element of the `s_i`-`s_j` sub-block is found. + +""" +function _sub_block_starts(z_1, z_2, s_i::I, s_j::I, basis_def::BasisDef) where I<:Integer + sub_block_starts = Vector{I}(undef, 2) + sub_block_starts[1] = sum(2basis_def[z_1][1:s_i-1] .+ 1) + 1 + sub_block_starts[2] = sum(2basis_def[z_2][1:s_j-1] .+ 1) + 1 + return sub_block_starts +end + + +# ╭─────────────────────┬────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Filters │ +# ╰─────────────────────┴────────────────╯ +# Filtering operators to help with differentiating between and selecting specific block +# indices or collections thereof. + +""" + filter_on_site_idxs(block_idxs) + +Filter out all but the on-site block indices. + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only on-site block indices remaining. + +""" +function filter_on_site_idxs(block_idxs::BlkIdx) + # When `block_idxs` is a 2×N matrix then the only requirement for an interaction to be + # on-site is that the two atomic indices are equal to one another. If `block_idxs` is a + # 3×N matrix then the interaction must lie within the origin cell. + if size(block_idxs, 1) == 2 + return block_idxs[:, block_idxs[1, :] .≡ block_idxs[2, :]] + else + return block_idxs[:, block_idxs[1, :] .≡ block_idxs[2, :] .&& block_idxs[3, :] .== 1] + end +end + +""" + filter_off_site_idxs(block_idxs) + +Filter out all but the off-site block indices. + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only off-site block indices remaining. + +""" +function filter_off_site_idxs(block_idxs::BlkIdx) + if size(block_idxs, 1) == 2 # Locate where atomic indices not are equal + return block_idxs[:, block_idxs[1, :] .≠ block_idxs[2, :]] + else # Find where atomic indices are not equal or the cell≠1. + return block_idxs[:, block_idxs[1, :] .≠ block_idxs[2, :] .|| block_idxs[3, :] .≠ 1] + end +end + + +""" + filter_upper_idxs(block_idxs) + +Filter out atomic-blocks that reside in the lower triangle of the matrix. This is useful +for removing duplicate data in some cases. (blocks on the diagonal are retained) + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only blocks from the upper + triangle remaining. +""" +filter_upper_idxs(block_idxs::BlkIdx) = block_idxs[:, block_idxs[1, :] .≤ block_idxs[2, :]] + + +""" + filter_lower_idxs(block_idxs) + +Filter out atomic-blocks that reside in the upper triangle of the matrix. This is useful +for removing duplicate data in some cases. (blocks on the diagonal are retained) + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only blocks from the lower + triangle remaining. +""" +filter_lower_idxs(block_idxs::BlkIdx) = block_idxs[:, block_idxs[1, :] .≥ block_idxs[2, :]] + +""" + filter_idxs_by_bond_distance(block_idxs, distance, atoms[, images]) + +Filters out atomic-blocks associated with interactions between paris of atoms that are +separated by a distance greater than some specified cutoff. + + +# Arguments +- `block_idx::BlkIdx`: block index matrix holding the off-site atom-block indices that + are to be filtered. +- `distance::AbstractFloat`: maximum bond distance; atom blocks representing interactions + between atoms that are more than `distance` away from one another will be filtered out. +- `atoms::Atoms`: system in which the specified blocks are located. +- `images::Matrix{<:Integer}`: cell translation index lookup list, this is only relevant + when `block_idxs` supplies and cell index value. The cell translation index for the iᵗʰ + state will be taken to be `images[block_indxs[i, 3]]`. + +# Returns +- `filtered_block_idx::BlkIdx`: a copy of `block_idx` in which all interactions associated + with interactions separated by a distance greater than `distance` have been filtered out. + +# Notes +It is only appropriate to use this method to filter `block_idx` instance in which all +indices pertain to off-site atom-blocks. + +""" +function filter_idxs_by_bond_distance( + block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, + images::Union{Nothing, AbstractMatrix{<:Integer}}=nothing) + + let mask = _distance_mask(block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + return block_idxs[:, mask] + end +end + + +function _filter_block_indices(block_indices, focus::AbstractVector) + mask = ∈(focus).(block_indices[1, :]) .& ∈(focus).(block_indices[2, :]) + return block_indices[:, mask] + +end + +function _filter_block_indices(block_indices, focus::AbstractMatrix) + mask = ∈(collect(eachcol(focus))).(collect(eachcol(block_indices[1:2, :]))) + return block_indices[:, mask] +end + +# ╭─────────────────────┬─────────────────╮ +# │ Matrix Manipulation │ Data Assignment │ +# ╰─────────────────────┴─────────────────╯ + +# The `_get_blocks!` methods are used to collect either atomic-blocks or sub-blocks from +# a Hamiltonian or overlap matrix. + +""" + _get_blocks!(src, target, starts) + +Gather blocks from a `src` matrix and store them in the array `target`. This method is +to be used when gathering data from two-dimensional, single k-point, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _get_blocks!(src::Matrix{T}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[:, :, i] = src[ + starts[1, i]:starts[1, i] + size(target, 1) - 1, + starts[2, i]:starts[2, i] + size(target, 2) - 1] + end +end + +""" + _get_blocks!(src, target, starts) + +Gather blocks from a `src` matrix and store them in the array `target`. This method is +to be used when gathering data from three-dimensional, real-space, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. Note that in + this case the cell index, i.e. the third row, specifies the cell index. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _get_blocks!(src::AbstractArray{T, 3}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[:, :, i] = src[ + starts[1, i]:starts[1, i] + size(target, 1) - 1, + starts[2, i]:starts[2, i] + size(target, 2) - 1, + starts[3, i]] + end +end + +# The `_set_blocks!` methods perform the inverted operation of their `_get_blocks!` +# counterparts as they place data **into** the Hamiltonian or overlap matrix. + +""" + _set_blocks!(src, target, starts) + +Scatter blocks from the `src` matrix into the `target`. This method is to be used when +assigning data to two-dimensional, single k-point, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _set_blocks!(src::AbstractArray{T, 3}, target::Matrix{T}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[ + starts[1, i]:starts[1, i] + size(src, 1) - 1, + starts[2, i]:starts[2, i] + size(src, 2) - 1, + ] = src[:, :, i] + end +end + +""" + _set_blocks!(src, target, starts) + +Scatter blocks from the `src` matrix into the `target`. This method is to be used when +assigning data to three-dimensional, real-space, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. Note that in + this case the cell index, i.e. the third row, specifies the cell index. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _set_blocks!(src::AbstractArray{T, 3}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[ + starts[1, i]:starts[1, i] + size(src, 1) - 1, + starts[2, i]:starts[2, i] + size(src, 2) - 1, + starts[3, i] + ] = src[:, :, i] + end +end + + + + +""" + get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def) + +Collect sub-blocks of a given type from select atom-blocks in a provided matrix. + +This method will collect, from `matrix`, the `s_i`-`s_j` sub-block of each atom-block +listed in `block_idxs`. It is assumed that all atom-blocks are between identical pairs +of species. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an array containing the collected sub-blocks. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function get_sub_blocks(matrix::AbstractArray{T}, block_idxs::BlkIdx, s_i, s_j, atoms::Atoms, basis_def) where T + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Shift `starts` so it points to the start of the **sub-blocks** rather than the block + starts[1:2, :] .+= _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) .- 1 + + data = Array{T, 3}( # Array in which the resulting sub-blocks are to be collected + undef, 2basis_def[z_1][s_i] + 1, 2basis_def[z_2][s_j] + 1, size(block_idxs, 2)) + + # Carry out the assignment operation. + _get_blocks!(matrix, data, starts) + + return data +end + + +""" + set_sub_blocks(matrix, values, block_idxs, s_i, s_j, atoms, basis_def) + +Place sub-block data from `values` representing the interaction between shells `s_i` & +`s_j` into the matrix at the atom-blocks listed in `block_idxs`. This is this performs +the inverse operation to `set_sub_blocks`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `values::Array`: sub-block values. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function set_sub_blocks!(matrix::AbstractArray, values, block_idxs::BlkIdx, s_i, s_j, atoms::Atoms, basis_def) + + if size(values, 3) != size(block_idxs, 2) + throw(DimensionMismatch( + "The last dimensions of `values` & `block_idxs` must be of the same length.")) + end + + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Shift `starts` so it points to the start of the **sub-blocks** rather than the block + starts[1:2, :] .+= _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) .- 1 + + # Carry out the scatter operation. + _set_blocks!(values, matrix, starts) +end + + +""" + get_blocks(matrix, block_idxs, atoms, basis_def) + +Collect, from `matrix`, the blocks listed in `block_idxs`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `block_idxs::BlkIdx`: the atomic-blocks to gathered. +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an array containing the collected sub-blocks. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function get_blocks(matrix::AbstractArray{T}, block_idxs::BlkIdx, atoms::Atoms, basis_def) where T + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + data = Array{T, 3}( # Array in which the resulting blocks are to be collected + undef, sum(2basis_def[z_1].+ 1), sum(2basis_def[z_2].+ 1), size(block_idxs, 2)) + + # Carry out the assignment operation. + _get_blocks!(matrix, data, starts) + + return data +end + + +""" + set_sub_blocks(matrix, values, block_idxs, s_i, s_j, atoms, basis_def) + +Place atom-block data from `values` into the matrix at the atom-blocks listed in `block_idxs`. +This is this performs the inverse operation to `set_blocks`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `values::Array`: sub-block values. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function set_blocks!(matrix::AbstractArray, values, block_idxs::BlkIdx, atoms::Atoms, basis_def) + + if size(values, 3) != size(block_idxs, 2) + throw(DimensionMismatch( + "The last dimensions of `values` & `block_idxs` must be of the same length.")) + end + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Carry out the scatter operation. + _set_blocks!(values, matrix, starts) +end + + +""" + locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms, basis_def) + +Collects sub-blocks from the supplied matrix that correspond to off-site interactions +between the `s_i`'th shell on species `z_1` and the `s_j`'th shell on species `z_2`. + +# Arguments +- `matrix`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `z_1`: 1ˢᵗ species (atomic number) +- `z_2`: 2ⁿᵈ species (atomic number) +- `s_i`: shell on 1ˢᵗ species +- `s_j`: shell on 2ⁿᵈ species +- `atoms`: target system's `JuLIP.Atoms` objects +- `basis_def`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an Nᵢ×Nⱼ×M array containing the collected sub-blocks; where Nᵢ & Nⱼ are + the number of orbitals on the `s_i`'th & `s_j`'th shells of species `z_1` & `z_2` + respectively, and M is the N∘ of sub-blocks found. +- `block_idxs`: A matrix specifying which atomic block each sub-block in `sub_blocks` + was taken from. If `matrix` is a 3D real space matrix then `block_idxs` will also + include the cell index. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; focus=nothing, no_reduce=false) = _locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms, basis_def; focus=focus, no_reduce=no_reduce) + +""" + locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms, basis_def) + +Collects sub-blocks from the supplied matrix that correspond to on-site interactions +between the `s_i`'th & `s_j`'th shells on species `z`. + +# Arguments +- `matrix`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `z_1`: target species (atomic number) +- `s_i`: 1ˢᵗ shell +- `s_j`: 2ⁿᵈ shell +- `atoms`: target system's `JuLIP.Atoms` objects +- `basis_def`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an Nᵢ×Nⱼ×M array containing the collected sub-blocks; where Nᵢ & Nⱼ are + the number of orbitals on the `s_i`'th & `s_j`'th shells of species `z_1` & `z_2` + respectively, and M is the N∘ of sub-blocks found. +- `block_idxs`: A matrix specifying which atomic block each sub-block in `sub_blocks` + was taken from. If `matrix` is a 3D real space matrix then `block_idxs` will also + include the cell index. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing, kwargs...) = _locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms, basis_def; focus=focus) + +# Multiple dispatch is used to avoid the type instability in `locate_and_get_sub_blocks` +# associated with the creation of the `block_idxs` variable. It is also used to help +# distinguish between on-site and off-site collection operations. The following +# `_locate_and_get_sub_blocks` functions differ only in how they construct `block_idxs`. + +# Off site _locate_and_get_sub_blocks functions +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 2}, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; + focus=nothing, no_reduce=false) where T + block_idxs = atomic_block_idxs(z_1, z_2, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_off_site_idxs(block_idxs) + + # Duplicate blocks present when gathering off-site homo-atomic homo-orbital interactions + # must be purged. + if (z_1 == z_2) && (s_i == s_j) && !no_reduce + block_idxs = filter_upper_idxs(block_idxs) + end + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 3}, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; + focus=nothing, no_reduce=false) where T + block_idxs = atomic_block_idxs(z_1, z_2, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = repeat_atomic_block_idxs(block_idxs, size(matrix, 3)) + block_idxs = filter_off_site_idxs(block_idxs) + + if size(block_idxs, 2) == 0 + return zeros(basis_def[z_1][s_i] * 2 + 1, basis_def[z_2][s_j] * 2 + 1, 0), block_idxs + end + + if (z_1 == z_2) && (s_i == s_j) && !no_reduce + block_idxs = filter_upper_idxs(block_idxs) + end + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +# On site _locate_and_get_sub_blocks functions +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 2}, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing) where T + block_idxs = atomic_block_idxs(z, z, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_on_site_idxs(block_idxs) + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 3}, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing) where T + block_idxs = atomic_block_idxs(z, z, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_on_site_idxs(block_idxs) + block_idxs = repeat_atomic_block_idxs(block_idxs, 1) + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + + + +# ╭─────────────────────┬─────────────────────────────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Miscellaneous Internal Functions │ +# ╰─────────────────────┴─────────────────────────────────────────╯ + +# This function is tasked with constructing the boolean mask used to filter out atom-blocks +# associated with interactions between pairs of atoms separated by a distance greater than +# some specified cutoff. Note that this is intended for internal use only and is primarily +# used by the `filter_idxs_by_bond_distance` method. +function _distance_mask( + block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, + images::Union{Nothing, AbstractMatrix{<:Integer}}=nothing) + + if isnothing(images) + l = atoms.cell' + l_inv = pinv(l) + mask = Vector{Bool}(undef, size(block_idxs, 2)) + for i=1:size(block_idxs, 2) + mask[i] = norm(_wrap(atoms.X[block_idxs[2, i]] - atoms.X[block_idxs[1, i]], l, l_inv)) <= distance + end + else + shift_vectors = collect(eachrow(images' * atoms.cell)) + mask = norm.(atoms.X[block_idxs[2, :]] - atoms.X[block_idxs[1, :]] + shift_vectors[block_idxs[3, :]]) .<= distance + end + return mask +end + +# Internal method used exclusively by _distance_mask. +function _wrap(x_vec, l, l_inv) + x_vec_frac = l_inv * x_vec + return l * (x_vec_frac .- round.(x_vec_frac)) +end + + +end diff --git a/examples/H2O/python_interface/large/src/datastructs.jl b/examples/H2O/python_interface/large/src/datastructs.jl new file mode 100644 index 0000000..f609928 --- /dev/null +++ b/examples/H2O/python_interface/large/src/datastructs.jl @@ -0,0 +1,474 @@ +module DataSets +using ACEhamiltonians +using Random: shuffle +using LinearAlgebra: norm +using JuLIP: Atoms +using ACEhamiltonians.MatrixManipulation: BlkIdx, _distance_mask +using ACEhamiltonians.States: reflect, _get_states, _neighbours, _locate_minimum_image, _locate_target_image + +import ACEhamiltonians.Parameters: ison + +# ╔══════════╗ +# ║ DataSets ║ +# ╚══════════╝ + +# `AbstractFittingDataSet` based structures contain all data necessary to perform a fit. +abstract type AbstractFittingDataSet end + +export DataSet, filter_sparse, filter_bond_distance, get_dataset, AbstractFittingDataSet, random_split, random_sample, random_distance_sample + +""" + DataSet(values, blk_idxs, states) + +A structure for storing collections of sub-blocks & states representing the environments +from which they came. These are intended to be used during the model fitting process. +While the block index matrix is not strictly necessary to the fitting process it is useful +enough to merit inclusion. + +# Fields +- `values::AbstractArray`: an i×j×n matrix containing extracted the sub-block values; + where i & j are the number of orbitals associate with the two shells, and n the number + of sub-blocks. +- `blk_idxs::BlkIdx`: a block index matrix specifying from which block each sub-block in + `values` was taken from. This acts mostly as meta-data. +- `states`: states representing the atomic-block from which each sub-block was taken. + +# Notes +Some useful properties of `DataSet` instances have been highlighted below: +- Addition can be used to combine one or more datasets. +- Indexing can be used to take a sub-set of a dataset. +- `size` acts upon `DataSet.values`. +- `length` returns `size(DataSet.values, 3)`, i.e. the number of sub-blocks. +- The adjoint of a `DataSet` will return a copy where: + - `values` is the hermitian conjugate of its parent. + - atomic indices `blk_idxs` have been exchanged, i.e. rows 1 & 2 are swapped. + - `states` are reflected (where appropriate) . + +""" +struct DataSet{V<:AbstractArray{<:Any, 3}, B<:BlkIdx, S<:AbstractVector} <: AbstractFittingDataSet + values::V + blk_idxs::B + states::S +end + +# ╭──────────┬───────────────────────╮ +# │ DataSets │ General Functionality │ +# ╰──────────┴───────────────────────╯ + +function Base.show(io::IO, data_set::T) where T<:AbstractFittingDataSet + F = valtype(data_set.values) + mat_shape = join(size(data_set), '×') + print(io, "$(nameof(T)){$F}($mat_shape)") +end + +function Base.:(==)(x::T, y::T) where T<:DataSet + return x.blk_idxs == y.blk_idxs && x.values == y.values && x.states == y.states +end + +# Two or more AbstractFittingDataSet entities can be added together via the `+` operator +Base.:(+)(x::T, y::T) where T<:AbstractFittingDataSet = T( + (cat(getfield(x, fn), getfield(y, fn), dims=ndims(getfield(x, fn))) + for fn in fieldnames(T))...) + +# Allow AbstractFittingDataSet objects to be indexed so that a subset may be selected. This is +# mostly used when filtering data. +function Base.getindex(data_set::T, idx::UnitRange) where T<:AbstractFittingDataSet + return T((_getindex_helper(data_set, fn, idx) for fn in fieldnames(T))...) +end + +function Base.getindex(data_set::T, idx::Vararg{<:Integer}) where T<:AbstractFittingDataSet + return data_set[collect(idx)] +end + +function Base.getindex(data_set::T, idx) where T<:AbstractFittingDataSet + return T((_getindex_helper(data_set, fn, idx) for fn in fieldnames(T))...) +end + +function _getindex_helper(data_set, fn, idx) + # This abstraction helps to speed up calls to Base.getindex. + return let data = getfield(data_set, fn) + collect(selectdim(data, ndims(data), idx)) + end +end + +Base.lastindex(data_set::AbstractFittingDataSet) = length(data_set) +Base.length(data_set::AbstractFittingDataSet) = size(data_set, 3) +Base.size(data_set::AbstractFittingDataSet, dim::Integer) = size(data_set.values, dim) +Base.size(data_set::AbstractFittingDataSet) = size(data_set.values) + +""" +Return a copy of the provided `DataSet` in which i) the sub-blocks (i.e. the `values` +field) have set to their adjoint, ii) the atomic indices in the `blk_idxs` field have +been exchanged, and iii) all `BondStates` in the `states` field have been reflected. +""" +function Base.adjoint(data_set::T) where T<:DataSet + swapped_blk_idxs = copy(data_set.blk_idxs); _swaprows!(swapped_blk_idxs, 1, 2) + return T( + # Transpose and take the complex conjugate of the sub-blocks + conj(permutedims(data_set.values, (2, 1, 3))), + # Swap the atomic indices in `blk_idxs` i.e. atom block 1-2 is now block 2-1 + swapped_blk_idxs, + # Reflect bond states across the mid-point of the bond. + [reflect.(i) for i in data_set.states]) +end + +function _swaprows!(matrix::AbstractMatrix, i::Integer, j::Integer) + @inbounds for k = 1:size(matrix, 2) + matrix[i, k], matrix[j, k] = matrix[j, k], matrix[i, k] + end +end + + +""" + ison(dataset) + +Return a boolean indicating whether the `DataSet` entity contains on-site data. +""" +ison(x::T) where T<:DataSet = ison(x.states[1][1]) + + + +# ╭──────────┬─────────╮ +# │ DataSets │ Filters │ +# ╰──────────┴─────────╯ + + +""" + filter_sparse(dataset[, threshold=1E-8]) + +Filter out data-points with fully sparse sub-blocks. Only data-points with sub-blocks +containing at least one element whose absolute value is greater than the specified +threshold will be retained. + +# Arguments +- `dataset::AbstractFittingDataSet`: the dataset that is to be filtered. +- `threshold::AbstractFloat`: value below which an element will be considered to be + sparse. This will defaudfgflt to 1E-8 if omitted. + +# Returns +- `filtered_dataset`: a copy of the, now filtered, dataset. + +""" +function filter_sparse(dataset::AbstractFittingDataSet, threshold::AbstractFloat=1E-8) + return dataset[vec(any(abs.(dataset.values) .>= threshold, dims=(1,2)))] +end + + +""" +filter_bond_distance(dataset, distance) + +Filter out data-points whose bond-vectors exceed the supplied cutoff distance. This allows +for states that will not be used to be removed during the data selection process rather +than during evaluation. Note that this is only applicable to off-site datasets. + +# Arguments +- `dataset::AbstractFittingDataSet`: tje dataset that is to be filtered. +- `distance::AbstractFloat`: data-points with bond distances exceeding this value will be + filtered out. + +# Returns +- `filtered_dataset`: a copy of the, now filtered, dataset. + +""" +function filter_bond_distance(dataset::AbstractFittingDataSet, distance::AbstractFloat) + if length(dataset) != 0 # Don't try and check the state unless one exists. + # Throw an error if the user tries to apply the a bond filter to an on-site dataset + # where there is no bond to be filtered. + @assert !ison(dataset) "Only applicable to off-site datasets" + end + return dataset[[norm(i[1].rr0) <= distance for i in dataset.states]] +end + +""" + random_sample(dataset, n[; with_replacement]) + +Select a random subset of size `n` from the supplied `dataset`. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be sampled. +- `n::Integer`: number of sample points +- `with_replacement::Bool`: if true (default) then duplicate samples will not be drawn. + +# Returns +- `sample_dataset::AbstractFittingDataSet`: a randomly selected subset of `dataset` + of size `n`. +""" +function random_sample(dataset::AbstractFittingDataSet, n::Integer; with_replacement=true) + if with_replacement + @assert n ≤ length(dataset) "Sample size cannot exceed dataset size" + return dataset[shuffle(1:length(dataset))[1:n]] + else + return dataset[rand(1:length(dataset), n)] + end +end + +""" + random_split(dataset, x) + +Split the `dataset` into a pair of randomly selected subsets. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be partitioned. +- `x::AbstractFloat`: partition ratio; the fraction of samples to be + placed into the first subset. + +""" +function random_split(dataset::DataSet, x::AbstractFloat) + split_index = Int(round(length(dataset)x)) + idxs = shuffle(1:length(dataset)) + return dataset[idxs[1:split_index]], dataset[idxs[split_index + 1:end]] + +end + + + +""" + random_distance_sample(dataset, n[; with_replacement=true, rng=rand]) + +Select a random subset of size `n` from the supplied `dataset` via distances. + +This functions in a similar manner to `random_sample` but selects points based on their +bond length. This is intended to ensure a more even sample. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be sampled. +- `n::Integer`: number of sample points +- `with_replacement::Bool`: if true (default) then duplicate samples will not be drawn. +- `rng::Function`: function to generate random numbers. + +""" +function random_distance_sample(dataset, n; with_replacement=true, rng=rand) + + @assert length(dataset) ≥ n + # Construct an array storing the bond lengths of each state, sort it and generate + # the sort permutation array to allow elements in r̄ to be mapped back to their + # corresponding state in the dataset. + r̄ = [norm(i[1].rr0) for i in dataset.states] + r̄_perm = sortperm(r̄) + r̄[:] = r̄[r̄_perm] + m = length(dataset) + + # Work out the maximum & minimum bond distance as well as the range + r_min = minimum(r̄) + r_max = maximum(r̄) + r_range = r_max - r_min + + # Preallocate transient index storage array + selected_idxs = zeros(Int, n) + + for i=1:n + # Select a random distance r ∈ [min(r̄), max(r̄)] + r = rng() * r_range + r_min + + # Identify the first element of r̄ ≥ r and the last element ≤ r + idxs = searchsorted(r̄, r) + idx_i, idx_j = minmax(first(idxs), last(idxs)) + + # Expand the window by one each side, but don't exceed the array's bounds + idx_i = max(idx_i-1, 1) + idx_j = min(idx_j+1, m) + + # Identify which element is closest to r and add the associated index + # to the selected index array. + idx = last(findmin(j->abs(r-j), r̄[idx_i:idx_j])) + idx_i - 1 + + # If this state has already been selected then replace it with the next + # closest one. + + if with_replacement && r̄_perm[idx] ∈ selected_idxs + + # Identify the indices corresponding to the first states with longer and shorter + # bond lengths than the current, duplicate, state. + lb = max(idx-1, 1) + ub = min(idx+1, m) + + while lb >= 1 && r̄_perm[lb] ∈ selected_idxs + lb -= 1 + end + + while ub <= m && r̄_perm[ub] ∈ selected_idxs + ub += 1 + end + + # Select the closets valid state + new_idx = 0 + dx = Inf + + if lb != 0 && lb != idx + new_idx = lb + dx = abs(r̄[lb] - r) + end + + if ub != m+1 && (abs(r̄[ub] - r) < dx) && ub != idx + new_idx = ub + end + + idx = new_idx + + end + + + selected_idxs[i] = r̄_perm[idx] + end + + return dataset[selected_idxs] + +end + +# ╭──────────┬───────────╮ +# │ DataSets │ Factories │ +# ╰──────────┴───────────╯ +# This section will hold the factory methods responsible for automating the construction +# of `DataSet` entities. The `get_dataset` methods will be implemented once the `AHSubModel` +# structures have been implemented. + + +# This is just a reimplementation of `filter_idxs_by_bond_distance` that allows for `blocks` +# to get filtered as well +function _filter_bond_idxs(blocks, block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + let mask = _distance_mask(block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + return blocks[:, :, mask], block_idxs[:, mask] + end +end + + + +""" +# Todo + - This could be made more performant. +""" +function _filter_sparse(values, block_idxs, tolerance) + mask = vec(any(abs.(values) .>= tolerance, dims=(1,2))) + return values[:, :, mask], block_idxs[:, mask] +end + + +""" +get_dataset(matrix, atoms, submodel, basis_def[, images; tolerance, filter_bonds, focus]) + +Construct and return a `DataSet` entity containing the minimal data required to fit a +`AHSubModel` entity. + +# Arguments +- `matrix`: matrix from which sub-blocks are to be gathered. +- `atoms`: atoms object representing the system to which the matrix pertains. +- `submodel`: `AHSubModel` entity for the desired sub-block; the `id` field is used to identify + which sub-blocks should be gathered and how they should be gathered. +- `basis_def`: a basis definition specifying what orbitals are present on each species. +- `images`: cell translation vectors associated with the matrix, this is only required + when the `matrix` is in the three-dimensional real-space form. + +# Keyword Arguments +- `tolerance`: specifying a float value will enact sparse culling in which only sub-blocks + with at least one element greater than the permitted tolerance will be included. This + is used to remove all zero, or near zero, sub-blocks. This is disabled by default. +- `filter_bonds`: if set to `true` then only interactions within the permitted cutoff + distance will be returned. This is only valid off-site interactions and is disabled + by default. The cut-off distance is extracted from the bond envelope contained within + the `basis` object. This defaults to `true` for off-site interactions. +- `focus`: the `focus` argument allows the `get_dataset` call to return only a sub-set + of possible data-points. If a vector of atomic indices is provided then only on/off- + site sub-blocks for/between those atoms will be returned; i.e. [1, 2] would return + on-sites 1-on, 2-on and off-sites 1-1-off, 1-2-off, 2-1-off, & 2-2-off. If a matrix + is provided, like so [1 2; 3 4] then only the associated off-site sub-blocks will be + returned, i.e. 1-2-off and 3-4-off. Note that the matrix form is only valid when + retrieving off-site sub-blocks. +- `no_reduce`: by default symmetrically redundant sub-blocks will not be gathered; this + equivalent blocks from be extracted from the upper and lower triangles of the Hamiltonian + and overlap matrices. This will default to `false`, however it is sometimes useful to + disable this when debugging. + +# Todo: + - Warn that only the upper triangle is returned and discuss how this effects "focus". + +""" +function get_dataset( + matrix::AbstractArray, atoms::Atoms, submodel::AHSubModel, basis_def, + images::Union{Matrix, Nothing}=nothing; + tolerance::Union{Nothing, <:AbstractFloat}=nothing, filter_bonds::Bool=true, + focus::Union{Vector{<:Integer}, Matrix{<:Integer}, Nothing}=nothing, + no_reduce=false) + + if ndims(matrix) == 3 && isnothing(images) + throw("`images` must be provided when provided with a real space `matrix`.") + end + + # Locate and gather the sub-blocks correspond the interaction associated with `basis` + blocks, block_idxs = locate_and_get_sub_blocks(matrix, submodel.id..., atoms, basis_def; focus=focus, no_reduce=no_reduce) + + if !isnothing(focus) + mask = ∈(focus).(block_idxs[1, :]) .& ∈(focus).(block_idxs[2, :]) + block_idxs = block_idxs[:, mask] + blocks = blocks[:, :, mask] + end + + # If gathering off-site data and `filter_bonds` is `true` then remove data-points + # associated with interactions between atom pairs whose bond-distance exceeds the + # cutoff as specified by the bond envelope. This prevents having to construct states + # (which is an expensive process) for interactions which will just be deleted later + # on. Enabling this can save a non-trivial amount of time and memory. + if !ison(submodel) && filter_bonds + blocks, block_idxs = _filter_bond_idxs( + blocks, block_idxs, envelope(submodel).r0cut, atoms, images) + end + + if !isnothing(tolerance) # Filter out sparse sub-blocks; but only if instructed to + blocks, block_idxs = _filter_sparse(blocks, block_idxs, tolerance) + end + + # Construct states for each of the sub-blocks. + if ison(submodel) + # For on-site states the cutoff radius is provided; this results in redundant + # information being culled here rather than later on; thus saving on memory. + states = _get_states(block_idxs, atoms; r=radial(submodel).R.ru) + else + if size(block_idxs, 2) == 0 + states = zeros(0) + else + states = _get_states(block_idxs, atoms, envelope(submodel), images) + end + # # For off-site states the basis' bond envelope must be provided. + # states = _get_states(block_idxs, atoms, envelope(submodel), images) + end + + # Construct and return the requested DataSet object + dataset = DataSet(blocks, block_idxs, states) + + return dataset +end + + + +""" +Construct a collection of `DataSet` instances storing the information required to fit +their associated `AHSubModel` entities. This convenience function will call the original +`get_dataset` method for each and every basis in the supplied model and return a +dictionary storing once dataset for each basis in the model. + +""" +function get_dataset( + matrix::AbstractArray, atoms::Atoms, model::Model, + images::Union{Matrix, Nothing}=nothing; kwargs...) + + basis_def = model.basis_definition + on_site_data = Dict( + basis.id => get_dataset(matrix, atoms, basis, basis_def, images; kwargs...) + for basis in values(model.on_site_submodels)) + + off_site_data = Dict( + basis.id => get_dataset(matrix, atoms, basis, basis_def, images; kwargs...) + for basis in values(model.off_site_submodels)) + + return on_site_data, off_site_data +end + + + +end + +# Notes +# - The matrix and array versions of `get_dataset` could easily be combined. +# - The `get_dataset` method is likely to suffer from type instability issues as it is +# unlikely that Julia will know ahead of time whether the `DataSet` structure returned +# will contain on or off-states states; each having different associated structures. +# Thus type ambiguities in the `AHSubModel` structures should be alleviated. diff --git a/examples/H2O/python_interface/large/src/fitting.jl b/examples/H2O/python_interface/large/src/fitting.jl new file mode 100644 index 0000000..d8aa956 --- /dev/null +++ b/examples/H2O/python_interface/large/src/fitting.jl @@ -0,0 +1,474 @@ +module Fitting +using HDF5, ACE, ACEbase, ACEhamiltonians, StaticArrays, Statistics, LinearAlgebra, SparseArrays, IterativeSolvers +using ACEfit: linear_solve, SKLEARN_ARD, SKLEARN_BRR +using HDF5: Group +using JuLIP: Atoms +using ACE: ACEConfig, evaluate, scaling, AbstractState, SymmetricBasis +using ACEhamiltonians.Common: number_of_orbitals +using ACEhamiltonians.Bases: envelope +using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma +using ACEatoms:AtomicNumber +using LowRankApprox: pqrfact + +using ACEhamiltonians: DUAL_BASIS_MODEL + +export fit! + +# set abs(a::AtomicNumber) = 0 as it is called in the `scaling` function but should not change the output +Base.abs(a::AtomicNumber) = 0 # a.z + +# Once the bond inversion issue has been resolved the the redundant models will no longer +# be required. The changes needed to be made in this file to remove the redundant model +# are as follows: +# - Remove inverted state condition in single model `fit!` method. +# - `_assemble_ls` should take `AHSubModel` entities. +# - Remove inverted state condition from the various `predict` methods. + +# Todo: +# - Need to make sure that the acquire_B! function used by ACE does not actually modify the +# basis function. Otherwise there may be some issues with sharing basis functions. +# - ACE should be modified so that `valtype` inherits from Base. This way there should be +# no errors caused when importing it. +# - Remove hard coded matrix type from the predict function. + +# The _assemble_ls and _evaluate_real methods should be rewritten to use PseudoBlockArrays +# this will prevent redundant allocations from being made and will mean that _preprocessA +# and _preprocessY can also be removed. This should help speed up the code and should +# significantly improve memory usage. +function _preprocessA(A) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + # S1: number of sites; S2: number of basis, SS1: 2L1+1, SS2: 2L2+1 + S1,S2 = size(A) + SS1,SS2 = size(A[1]) + A_temp = zeros(ComplexF64, S1*SS1*SS2, S2) + for i = 1:S1, j = 1:S2 + A_temp[SS1*SS2*(i-1)+1:SS1*SS2*i,j] = reshape(A[i,j],SS1*SS2,1) + end + return real(A_temp) +end + +function _preprocessY(Y) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + Len = length(Y) + SS1,SS2 = size(Y[1]) + Y_temp = zeros(ComplexF64,Len*SS1*SS2) + for i = 1:Len + Y_temp[SS1*SS2*(i-1)+1:SS1*SS2*i] = reshape(Y[i],SS1*SS2,1) + end + return real(Y_temp) +end + + +function solve_ls(A, Y, λ, Γ, solver = "LSQR"; niter = 10, inner_tol = 1e-3) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + A = _preprocessA(A) + Y = _preprocessY(Y) + + num = size(A)[2] + A = [A; λ*Γ] + Y = [Y; zeros(num)] + if solver == "QR" + return real(qr(A) \ Y) + elseif solver == "LSQR" + # The use of distributed arrays is still causing a memory leak. As such the following + # code has been disabled until further notice. + # Ad, Yd = distribute(A), distribute(Y) + # res = real(IterativeSolvers.lsqr(Ad, Yd; atol = 1e-6, btol = 1e-6)) + # close(Ad), close(Yd) + res = real(IterativeSolvers.lsqr(A, Y; atol = 1e-6, btol = 1e-6)) + return res + elseif solver == "ARD" + return linear_solve(SKLEARN_ARD(;n_iter = niter, tol = inner_tol), A, Y)["C"] + elseif solver == "BRR" + return linear_solve(SKLEARN_BRR(;n_iter = niter, tol = inner_tol), A, Y)["C"] + elseif solver == "RRQR" + AP = A / I + θP = pqrfact(A, rtol = inner_tol) \ Y + return I \ θP + elseif solver == "NaiveSolver" + return real((A'*A) \ (A'*Y)) + end + + end + + +function _ctran(l::Int64,m::Int64,μ::Int64) + if abs(m) ≠ abs(μ) + return 0 + elseif abs(m) == 0 + return 1 + elseif m > 0 && μ > 0 + return 1/sqrt(2) + elseif m > 0 && μ < 0 + return (-1)^m/sqrt(2) + elseif m < 0 && μ > 0 + return - im * (-1)^m/sqrt(2) + else + return im/sqrt(2) + end +end + +_ctran(l::Int64) = sparse(Matrix{ComplexF64}([ _ctran(l,m,μ) for m = -l:l, μ = -l:l ])) + +function _evaluate_real(Aval) + L1,L2 = size(Aval[1]) + L1 = Int((L1-1)/2) + L2 = Int((L2-1)/2) + C1 = _ctran(L1) + C2 = _ctran(L2) + return real([ C1 * Aval[i].val * C2' for i = 1:length(Aval)]) +end + +""" +""" +# function _assemble_ls(basis::SymmetricBasis, data::T, enable_mean::Bool=false) where T<:AbstractFittingDataSet +# # This will be rewritten once the other code has been refactored. + +# # Should `A` not be constructed using `acquire_B!`? + +# n₁, n₂, n₃ = size(data) +# # Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is +# # the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and +# # Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored +# # at a later data if this layout is not found to be strictly necessary. +# cfg = ACEConfig.(data.states) +# Aval = evaluate.(Ref(basis), cfg) +# A = permutedims(reduce(hcat, _evaluate_real.(Aval)), (2, 1)) + +# Y = [data.values[:, :, i] for i in 1:n₃] + +# # Calculate the mean value x̄ +# if enable_mean && n₁ ≡ n₂ && ison(data) +# x̄ = mean(diag(mean(Y)))*I(n₁) +# else +# x̄ = zeros(n₁, n₂) +# end + +# Y .-= Ref(x̄) +# return A, Y, x̄ + +# end + + +using SharedArrays +using Distributed +using TensorCast +function _assemble_ls(basis::SymmetricBasis, data::T, enable_mean::Bool=false) where T<:AbstractFittingDataSet + # This will be rewritten once the other code has been refactored. + + # Should `A` not be constructed using `acquire_B!`? + + n₁, n₂, n₃ = size(data) + # Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is + # the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and + # Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored + # at a later data if this layout is not found to be strictly necessary. + + type = ACE.valtype(basis).parameters[5] + Avalr = SharedArray{real(type), 4}(n₃, length(basis), n₁, n₂) + np = length(procs(Avalr)) + nstates = length(data.states) + nstates_pp = ceil(Int, nstates/np) + np = ceil(Int, nstates/nstates_pp) + idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np] + idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)] + push!(idx_ends, nstates) + @sync begin + for (i, id) in enumerate(procs(Avalr)[begin:np]) + # @async begin + @spawnat id begin + cfg = ACEConfig.(data.states[idx_begins[i]:idx_ends[i]]) + Aval_ele = evaluate.(Ref(basis), cfg) + Avalr_ele = _evaluate_real.(Aval_ele) + Avalr_ele = permutedims(reduce(hcat, Avalr_ele), (2, 1)) + @cast M[i,j,k,l] := Avalr_ele[i,j][k,l] + Avalr[idx_begins[i]: idx_ends[i], :, :, :] .= M + end + # end + end + end + @cast A[i,j][k,l] := Avalr[i,j,k,l] + + Y = [data.values[:, :, i] for i in 1:n₃] + + # Calculate the mean value x̄ + if enable_mean && n₁ ≡ n₂ && ison(data) + x̄ = mean(diag(mean(Y)))*I(n₁) + else + x̄ = zeros(n₁, n₂) + end + + Y .-= Ref(x̄) + return A, Y, x̄ + +end + + +################### +# Fitting Methods # +################### + +""" + fit!(submodel, data;[ enable_mean]) + +Fits a specified model with the supplied data. + +# Arguments +- `submodel`: a specified submodel that is to be fitted. +- `data`: data that the basis is to be fitted to. +- `enable_mean::Bool`: setting this flag to true enables a non-zero mean to be + used. +- `λ::AbstractFloat`: regularisation term to be used (default=1E-7). +- `solver::String`: solver to be used (default="LSQR") +""" +function fit!(submodel::T₁, data::T₂; enable_mean::Bool=false, λ=1E-7, solver="LSQR") where {T₁<:AHSubModel, T₂<:AbstractFittingDataSet} + + # Get the basis function's scaling factor + Γ = Diagonal(scaling(submodel.basis, 2)) + + # Setup the least squares problem + Φ, Y, x̄ = _assemble_ls(submodel.basis, data, enable_mean) + + # Assign the mean value to the basis set + submodel.mean .= x̄ + + # Solve the least squares problem and get the coefficients + + submodel.coefficients .= collect(solve_ls(Φ, Y, λ, Γ, solver)) + + @static if DUAL_BASIS_MODEL + if T₁<:AnisoSubModel + Γ = Diagonal(scaling(submodel.basis_i, 2)) + Φ, Y, x̄ = _assemble_ls(submodel.basis_i, data', enable_mean) + submodel.mean_i .= x̄ + submodel.coefficients_i .= collect(solve_ls(Φ, Y, λ, Γ, solver)) + end + end + + nothing +end + + +# Convenience function for appending data to a dictionary +function _append_data!(dict, key, value) + if haskey(dict, key) + dict[key] = dict[key] + value + else + dict[key] = value + end +end + + + +""" + fit!(model, systems;[ on_site_filter, off_site_filter, tolerance, recentre, refit, target]) + +Fits a specified model to the supplied data. + +# Arguments +- `model::Model`: Model to be fitted. +- `systems::Vector{Group}`: HDF5 groups storing data with which the model should + be fitted. +- `on_site_filter::Function`: the on-site `DataSet` entities will be passed through this + filter function prior to fitting; defaults `identity`. +- `off_site_filter::Function`: the off-site `DataSet` entities will be passed through this + filter function prior to fitting; defaults `identity`. +- `tolerance::AbstractFloat`: only sub-blocks where at least one value is greater than + or equal to `tolerance` will be fitted. This argument permits sparse blocks to be + ignored. +- `recentre::Bool`: Enabling this will re-wrap atomic coordinates to be consistent with + the geometry layout used internally by FHI-aims. This should be used whenever loading + real-space matrices generated by FHI-aims. +- `refit::Bool`: By default already fitted bases will not be refitted, but this behaviour + can be suppressed by setting `refit=true`. +- `target::String`: a string indicating which matrix should be fitted. This may be either + `H` or `S`. If unspecified then the model's `.label` field will be read and used. +""" +function fit!( + model::Model, systems::Vector{Group}; + on_site_filter::Function = identity, + off_site_filter::Function = identity, + tolerance::Union{F, Nothing}=nothing, + recentre::Bool=false, + target::Union{String, Nothing}=nothing, + refit::Bool=false, solver = "LSQR") where F<:AbstractFloat + + # Todo: + # - Add fitting parameters options which uses a `Params` instance to define fitting + # specific parameters such as regularisation, solver method, whether or not mean is + # used when fitting, and so on. + # - Modify so that redundant data is not extracted; i.e. both A[0,0,0] -> A[1,0,0] and + # A[0,0,0] -> A[-1,0,0] + # - The approach currently taken limits io overhead by reducing redundant operations. + # However, this will likely use considerably more memory. + + # Section 1: Gather the data + + # If no target has been specified; then default to that given by the model's label. + + target = isnothing(target) ? model.label : target + + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[target] + + fitting_data = Dict{Any, DataSet}() + + # Loop over the specified systems + for system in systems + + # Load the required data from the database entry + matrix, atoms = get_matrix(system), load_atoms(system; recentre=recentre) + images = ndims(matrix) == 2 ? nothing : load_cell_translations(system) + + # Loop over the on site bases and collect the appropriate data + for basis in values(model.on_site_submodels) + data_set = get_dataset(matrix, atoms, basis, model.basis_definition, images; + tolerance=tolerance) + + # Don't bother filtering and adding empty datasets + if length(data_set) != 0 + # Apply the on-site data filter function + data_set = on_site_filter(data_set) + # Add the selected data to the fitting data-set. + _append_data!(fitting_data, basis.id, data_set) + end + end + + # Repeat for the off-site models + for basis in values(model.off_site_submodels) + data_set = get_dataset(matrix, atoms, basis, model.basis_definition, images; + tolerance=tolerance, filter_bonds=true) + + if length(data_set) != 0 + data_set = off_site_filter(data_set) + _append_data!(fitting_data, basis.id, data_set) + end + + end + end + + # Fit the on/off-site models + fit!(model, fitting_data; refit=refit, solver = solver) + +end + + +""" + fit!(model, fitting_data[; refit]) + + +Fit the specified model using the provided data. + +# Arguments +- `model::Model`: the model that should be fitted. +- `fitting_data`: dictionary providing the data to which the supplied model should be + fitted. This should hold one entry for each submodel that is to be fitted and should take + the form `{SubModel.id, DataSet}`. +- `refit::Bool`: By default, already fitted bases will not be refitted, but this behaviour + can be suppressed by setting `refit=true`. +""" +function fit!( + model::Model, fitting_data; refit::Bool=false, solver="LSQR") + + @debug "Fitting off site bases:" + for (id, basis) in model.off_site_submodels + if !haskey(fitting_data, id) + @debug "Skipping $(id): no fitting data provided" + elseif is_fitted(basis) && !refit + @debug "Skipping $(id): submodel already fitted" + elseif length(fitting_data) ≡ 0 + @debug "Skipping $(id): fitting dataset is empty" + else + @debug "Fitting $(id): using $(length(fitting_data[id])) fitting points" + fit!(basis, fitting_data[id]; solver = solver) + end + end + + @debug "Fitting on site bases:" + for (id, basis) in model.on_site_submodels + if !haskey(fitting_data, id) + @debug "Skipping $(id): no fitting data provided" + elseif is_fitted(basis) && !refit + @debug "Skipping $(id): submodel already fitted" + elseif length(fitting_data) ≡ 0 + @debug "Skipping $(id): fitting dataset is empty" + else + @debug "Fitting $(id): using $(length(fitting_data[id])) fitting points" + fit!(basis, fitting_data[id]; enable_mean=ison(basis), solver = solver) + end + end +end + + +# The following code was added to `fitting.jl` to allow data to be fitted on databases +# structured using the original database format. +using ACEhamiltonians.DatabaseIO: _load_old_atoms, _load_old_hamiltonian, _load_old_overlap +using Serialization + +function old_fit!( + model::Model, systems, target::Symbol; + tolerance::F=0.0, filter_bonds::Bool=true, recentre::Bool=false, + refit::Bool=false) where F<:AbstractFloat + + # Todo: + # - Check that the relevant data exists before trying to extract it; i.e. don't bother + # trying to gather carbon on-site data from an H2 system. + # - Currently the basis set definition is loaded from the first system under the + # assumption that it is constant across all systems. However, this will break down + # if different species are present in each system. + # - The approach currently taken limits io overhead by reducing redundant operations. + # However, this will likely use considerably more memory. + + # Section 1: Gather the data + + get_matrix = Dict( # Select an appropriate function to load the target matrix + :H=>_load_old_hamiltonian, :S=>_load_old_overlap)[target] + + fitting_data = IdDict{AHSubModel, DataSet}() + + # Loop over the specified systems + for (database_path, index_data) in systems + + # Load the required data from the database entry + matrix, atoms = get_matrix(database_path), _load_old_atoms(database_path) + + println("Loading: $database_path") + + # Loop over the on site bases and collect the appropriate data + if haskey(index_data, "atomic_indices") + println("Gathering on-site data:") + for basis in values(model.on_site_submodels) + println("\t- $basis") + data_set = get_dataset( + matrix, atoms, basis, model.basis_definition; + tolerance=tolerance, focus=index_data["atomic_indices"]) + _append_data!(fitting_data, basis, data_set) + + end + println("Finished gathering on-site data") + end + + # Repeat for the off-site models + if haskey(index_data, "atom_block_indices") + println("Gathering off-site data:") + for basis in values(model.off_site_submodels) + println("\t- $basis") + data_set = get_dataset( + matrix, atoms, basis, model.basis_definition; + tolerance=tolerance, filter_bonds=filter_bonds, focus=index_data["atom_block_indices"]) + _append_data!(fitting_data, basis, data_set) + end + println("Finished gathering off-site data") + end + end + + # Fit the on/off-site models + fit!(model, fitting_data; refit=refit) + +end + +end diff --git a/examples/H2O/python_interface/large/src/io.jl b/examples/H2O/python_interface/large/src/io.jl new file mode 100644 index 0000000..3eeacdf --- /dev/null +++ b/examples/H2O/python_interface/large/src/io.jl @@ -0,0 +1,429 @@ +# All general io functionality should be placed in this file. With the exception of the +# `read_dict` and `write_dict` methods. + + + +""" +This module contains a series of functions that are intended to aid in the loading of data +from HDF5 structured databases. The code within this module is primarily intended to be +used only during the fitting of new models. Therefore, i) only loading methods are +currently supported, ii) the load target for each function is always the top level group +for each system. + +A brief outline of the expected HDF5 database structure is provided below. Note that +arrays **must** be stored in column major format! + +Database +├─System-1 +│ ├─Structure +│ │ ├─atomic_numbers: +│ │ │ > A vector of integers specifying the atomic number of each atom present in the +│ │ │ > target system. Read by the function `load_atoms`. +│ │ │ +│ │ ├─positions: +│ │ │ > A 3×N matrix, were N is the number of atoms present in the system, specifying +│ │ │ > cartesian coordinates for each atom. Read by the function `load_atoms`. +│ │ │ +│ │ ├─lattice: +│ │ │ > A 3×3 matrix specifying the lattice systems lattice vector in column-wise +│ │ │ > format; i.e. columns loop over vectors. Read by the function `load_atoms` +│ │ │ > if and when present. +│ │ └─pbc: +│ │ > A boolean, or a vector of booleans, indicating if, or along which dimensions, +│ │ > periodic conditions are enforced. This is only read when the lattice is given. +│ │ > This defaults to true for non-molecular/cluster cases. Read by `load_atoms`. +│ │ +│ ├─Info +│ │ ├─Basis: +│ │ │ > This group contains one dataset for each species that specifies the +│ │ │ > azimuthal quantum numbers of each shell present on that species. Read +│ │ │ > by the `load_basis_set_definition` method. +│ │ │ +│ │ │ +│ │ ├─Translations: +│ │ │ > A 3×N matrix specifying the cell translation vectors associated with the real +│ │ │ > space Hamiltonian & overlap matrices. Only present when Hamiltonian & overlap +│ │ │ > matrices are given in their M×M×N real from. Read by `load_cell_translations`. +│ │ │ > Should be integers specifying the cell indices, rather than cartesian vectors. +│ │ │ > The origin cell, [0, 0, 0], must always be first! +│ │ │ +│ │ └─k-points: +│ │ > A 4×N matrix where N is the number of k-points. The first three rows specify +│ │ > k-points themselves with the final row specifying their associated weights. +│ │ > Read by `load_k_points_and_weights`. Only present for multi-k-point calculations. +│ │ +│ └─Data +│ ├─H: +│ │ > The Hamiltonian. Either an M×M matrix or an M×M×N real-space tensor; where M is +│ │ > is the number of orbitals and N then number of primitive cell equivalents. Read +│ │ > in by the `load_hamiltonian` function. +│ │ +│ ├─S: +│ │ > The Overlap matrix; identical in format to the Hamiltonian matrix. This is read +│ │ > in by the `load_overlap` function. +│ │ +│ ├─total_energy: +│ │ > A single float value specifying the total system energy. +│ │ +│ ├─fermi_level: +│ │ > A single float value specifying the total fermi level. +│ │ +│ ├─forces: +│ │ > A 3×N matrix specifying the force vectors on each atom. +│ │ +│ ├─H_gamma: +│ │ > This can be used to store the gamma point only Hamiltonian matrix when 'H' is +│ │ > used to store the real space matrix. This is mostly for debugging & testing. +│ │ +│ └─S_gamma: +| > Overlap equivalent of `H_gamma` +│ +├─System-2 +│ ├─Structure +│ │ └─ ... +│ │ +│ ├─Info +│ │ └─ ... +│ │ +│ └─Data +│ └─ ... +... +└─System-n + └─ ... + + +Datasets and groups should provide information about what units the data they contain are +given in. This can be done through the of the HDF5 metadata `attributes`. When calling +the various load methods within `DatabaseIO` the `src` Group argument must always point +to the target systems top level Group. In the example structure tree given above these +would be 'System-1', 'System-2', and 'System-n'. +""" +module DatabaseIO +using ACEhamiltonians +using HDF5: Group, h5open, Dataset +using JuLIP: Atoms +using HDF5 # ← Can be removed once support for old HDF5 formats is dropped +using LinearAlgebra: pinv + +using SparseArrays + +function read_sparse(data_sparse:: Group) + + if haskey(data_sparse, "sizes") + data, indices, indptr, shape, sizes = read(data_sparse, "data"), read(data_sparse, "indices"), read(data_sparse, "indptr"), + read(data_sparse, "shape"), read(data_sparse, "sizes") + indx = cumsum(sizes) + pushfirst!(indx, 0) + indx_r = cumsum(shape[1, :].+1) + pushfirst!(indx_r, 0) + + matrices_list = [] + for i in 1:(length(indx)-1) + indx_s, indx_e = indx[i]+1, indx[i+1] + data_ele = data[indx_s: indx_e] + indices_ele = indices[indx_s: indx_e] + indx_r_s, indx_r_e = indx_r[i]+1, indx_r[i+1] + indptr_ele = indptr[indx_r_s: indx_r_e] + shape_ele = shape[:, i] + matrix = SparseMatrixCSC(shape_ele..., indptr_ele.+1, indices_ele.+1, data_ele) + push!(matrices_list, matrix) + end + + dense_matrices = [Matrix(mat) for mat in matrices_list] + return cat(dense_matrices..., dims=3) + + else + data, indices, indptr, shape = read(data_sparse, "data"), read(data_sparse, "indices"), read(data_sparse, "indptr"), read(data_sparse, "shape") + return Matrix(SparseMatrixCSC(shape..., indptr.+1, indices.+1, data)) + + end + +end + + +# Developers Notes: +# - The functions within this module are mostly just convenience wrappers for the HDF5 +# `read` method and as such very little logic is contained within. Thus no unit tests +# are provided for this module at this time. However, this will change when and if +# unit and write functionality are added. +# +# Todo: +# - A version number flag should be added to each group when written to prevent the +# breaking compatibility with existing databases every time an update is made. +# - The unit information provided in the HDF5 databases should be made use of once +# a grand consensus as to what internal units should be used. + +export load_atoms, load_hamiltonian, load_overlap, gamma_only, load_k_points_and_weights, load_cell_translations, load_basis_set_definition, load_density_of_states, load_fermi_level, load_density_matrix + + +# Booleans stored by python are interpreted as Int8 by Julia rather than as booleans. Thus +# a cleaner is required. +_clean_bool(bool::I) where I<:Integer = Bool(bool) +_clean_bool(bool::Vector{<:Integer}) = convert(Vector{Bool}, bool) +_clean_bool(bool) = bool + + +function _recentre!(x, l, l_inv) + x[:] = l_inv' * x .- 1E-8 + x[:] = l' * (x - round.(x) .+ 1E-8) + nothing +end + +""" + load_atoms(src) + +Instantiate a `JuLIP.Atoms` object from an HDF5 `Group`. + + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose `Atoms` object is to be + returned. +- `recentre:Bool`: By default, atoms are assumed to span the fractional coordinate domain + [0.0, 1.0). Setting `recentre` to `true` will remap atomic positions to the fractional + coordinate domain of [-0.5, 0.5). This is primarily used when interacting with real-space + matrices produced by the FHI-aims code base. + +# Returns +- `atoms::Atoms`: atoms object representing the structure of the target system. + +""" +function load_atoms(src::Group; recentre=false) + # Developers Notes: + # - Currently non-molecular/cluster systems are assumed to be fully periodic + # along each axis if no `pbc` condition is explicitly specified. + # Todo: + # - Use unit information provided by the "positions" & "lattice" datasets. + + # All system specific structural data should be contained within the "Structure" + # sub-group. Extract the group to a variable for ease of access. + src = src["Structure"] + + # Species and positions are always present to read them in + species, positions = read(src["atomic_numbers"]), read(src["positions"]) + + if haskey(src, "lattice") # If periodic + l = collect(read(src["lattice"])') + if recentre + l_inv = pinv(l) + for x in eachcol(positions) + _recentre!(x, l, l_inv) + end + end + + pbc = haskey(src, "pbc") ? _clean_bool(read(src["pbc"])) : true + return Atoms(; Z=species, X=positions, cell=l, pbc=pbc) + else # If molecular/cluster + return Atoms(; Z=species, X=positions) + end +end + + +""" + load_basis_set_definition(src) + +Load the basis definition of the target system. + +This returns a `BasisDef` dictionary which specifies the azimuthal quantum number of each +shell for each species. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose basis set definition is + to be read. + +# Returns +- `basis_def::BasisDef`: a dictionary keyed by species and valued by a vector specifying + the azimuthal quantum number for each shell on said species. + +""" +function load_basis_set_definition(src::Group) + # Basis set definition is stored in "/Info/Basis" relative to the system's top level + # group. + src = src["Info/Basis"] + # Basis set definition is stored as a series of vector datasets with names which + # correspond the associated atomic number. + return BasisDef{Int}(parse(Int, k) => read(v)[2, :] for (k, v) in zip(keys(src), src)) +end + +""" + load_k_points_and_weights(src) + +Parse k-points and their weights. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose k-points & k-weights + are to be returned. + +# Returns +- `k_points::Matrix`: a 3×n matrix where n is the number of k-points. +- `k_weights::Vector`: a vector with a weight for each k-point. + +# Warnings +This will error out for gamma-point only calculations. + +""" +function load_k_points_and_weights(src::Group) + # Read in the k-points and weights from the "/Info/k-points" matrix. + knw = read(src["Info/k-points"]) + return knw[1:3, :], knw[4, :] +end + + +""" + load_cell_translations(src) + +Load the cell translation vectors associated with the real Hamiltonian & overlap matrices. +Relevant when Hamiltonian & overlap matrices are stored in their real N×N×M form, where N +is the number of orbitals per primitive cell and M is the number of cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose cell translation vectors + are to be returned. + +# Returns +- `T::Matrix`: a 3×M matrix where M is the number of primitive cell equivalents. + +# Notes +There is one translation vector for each translated cell; i.e. if the Hamiltonian matrix +is N×N×M then there will be M cell translation vectors. Here, the first translation +vector is always that of the origin cell, i.e. [0, 0, 0]. + +# Warnings +This will error out for gamma point only calculations or datasets in which the real +matrices are not stored. +""" +load_cell_translations(src::Group) = read(src["Info/Translations"]) + +""" + load_hamiltonian(src) + +Load the Hamiltonian matrix stored for the target system. This may be either an N×N single +k-point (commonly the gamma point) matrix, or an N×N×M real space matrix; where N is the +number or orbitals and M the number of unit cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose Hamiltonian matrix is + to be returned. + +# Returns +- `H::Array`: Hamiltonian matrix. This may be either an N×N matrix, as per the single + k-point case, or an N×N×M array for the real space case. +""" +# load_hamiltonian(src::Group) = read(src["Data/H"]) +load_hamiltonian(src::Group) = isa(src["Data/H"], Dataset) ? read(src["Data/H"]) : read_sparse(src["Data/H"]) +# Todo: add unit conversion to `load_hamiltonian` + +""" + load_overlap(src) + +Load the overlap matrix stored for the target system. This may be either an N×N single +k-point (commonly the gamma point) matrix, or an N×N×M real space matrix; where N is the +number or orbitals and M the number of unit cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose overlap matrix is to be + returned. + +# Returns +- `H::Array`: overlap matrix. This may be either an N×N matrix, as per the single + k-point case, or an N×N×M array for the real space case. +""" +# load_overlap(src::Group) = read(src["Data/S"]) +load_overlap(src::Group) = isa(src["Data/S"], Dataset) ? read(src["Data/S"]) : read_sparse(src["Data/S"]) +# Todo: add unit conversion to `load_overlap` + +load_density_matrix(src::Group) = isa(src["Data/dm"], Dataset) ? read(src["Data/dm"]) : read_sparse(src["Data/dm"]) + + +""" + gamma_only(src) + +Returns `true` if the stored Hamiltonian & overlap matrices are for a single k-point only. +Useful for determining whether or not one should attempt to read cell translations or +k-points, etc. +""" +gamma_only(src::Group) = !haskey(src, "Info/Translations") + +# Get the gamma point only matrix; these are for debugging and are will be removed later. +# load_hamiltonian_gamma(src::Group) = read(src["Data/H_gamma"]) +load_hamiltonian_gamma(src::Group) = isa(src["Data/H_gamma"], Dataset) ? read(src["Data/H_gamma"]) : read_sparse(src["Data/H_gamma"]) +# Todo: add unit conversion to `load_hamiltonian_gamma` +# load_overlap_gamma(src::Group) = read(src["Data/S_gamma"]) +load_overlap_gamma(src::Group) = isa(src["Data/S_gamma"], Dataset) ? read(src["Data/S_gamma"]) : read_sparse(src["Data/S_gamma"]) +# Todo: add unit conversion to `load_overlap_gamma` +load_density_matrix_gamma(src::Group) = isa(src["Data/dm_gamma"], Dataset) ? read(src["Data/dm_gamma"]) : read_sparse(src["Data/dm_gamma"]) + + +""" + load_density_of_states(src) + +Load the density of states associated with the target system. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system for which the density of + states is to be returned. + +# Returns +- `values::Vector`: density of states. +- `energies::Vector`: energies at which densities of states were evaluated relative to + the fermi-level. +- `broadening::AbstractFloat`: broadening factor used by the smearing function. + +""" +function load_density_of_states(src::Group) + # Todo: + # - This currently returns units of eV for energy and 1/(eV.V_unit_cell) for DoS. + return ( + read(src, "Data/DoS/values"), + read(src, "Data/DoS/energies"), + read(src, "Data/DoS/broadening")) +end + + +""" + load_fermi_level(src) + +Load the calculated Fermi level (chemical potential). + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system for which the fermi level is + to be returned. + +# Returns +- `fermi_level::AbstractFloat`: the fermi level. +""" +function load_fermi_level(src) + # Todo: + # - This really should make use of unit attribute that is provided. + return read(src, "Data/fermi_level") +end + + +# These functions exist to support backwards compatibility with previous database structures. +# They are not intended to be called by general users as they will eventually be excised. +function _load_old_hamiltonian(path::String) + return h5open(path) do database + read(database, "aitb/H")[:, :] + end +end + +function _load_old_overlap(path::String) + return h5open(path) do database + read(database, "aitb/S")[:, :] + end +end + +function _load_old_atoms(path::String; groupname=nothing) + h5open(path, "r") do fd + groupname === nothing && (groupname = HDF5.name(first(fd))) + positions = HDF5.read(fd, string(groupname,"/positions")) + unitcell = HDF5.read(fd, string(groupname,"/unitcell")) + species = HDF5.read(fd, string(groupname,"/species")) + atoms = Atoms(; X = positions, Z = species, + cell = unitcell, + pbc = [true, true, true]) + return atoms + end + +end + +end diff --git a/examples/H2O/python_interface/large/src/models.jl b/examples/H2O/python_interface/large/src/models.jl new file mode 100644 index 0000000..9ba500f --- /dev/null +++ b/examples/H2O/python_interface/large/src/models.jl @@ -0,0 +1,241 @@ +module Models + +using ACEhamiltonians, ACE, ACEbase +import ACEbase: read_dict, write_dict +using ACEhamiltonians.Parameters: OnSiteParaSet, OffSiteParaSet +using ACEhamiltonians.Bases: AHSubModel, is_fitted +using ACEhamiltonians: DUAL_BASIS_MODEL +# Once we change the keys of basis_def from Integer to AtomicNumber, we will no +# longer need JuLIP here +using JuLIP + + +export Model + + +# ╔═══════╗ +# ║ Model ║ +# ╚═══════╝ + + +# Todo: +# - On-site and off-site components should be optional. +# - Document +# - Clean up +struct Model + + on_site_submodels + off_site_submodels + on_site_parameters + off_site_parameters + basis_definition + + label::String + + meta_data::Dict{String, Any} + + function Model( + on_site_submodels, off_site_submodels, on_site_parameters::OnSiteParaSet, + off_site_parameters::OffSiteParaSet, basis_definition, label::String, + meta_data::Union{Dict, Nothing}=nothing) + + # If no meta-data is supplied then just default to a blank dictionary + meta_data = isnothing(meta_data) ? Dict{String, Any}() : meta_data + + new(on_site_submodels, off_site_submodels, on_site_parameters, off_site_parameters, + basis_definition, label, meta_data) + end + + function Model( + basis_definition::BasisDef, on_site_parameters::OnSiteParaSet, + off_site_parameters::OffSiteParaSet, label::String, + meta_data::Union{Dict, Nothing}=nothing) + + # Developers Notes + # This makes the assumption that all z₁-z₂-ℓ₁-ℓ₂ interactions are represented + # by the same model. + + # get a species list from basis_definition + species = AtomicNumber.([ keys(basis_definition)... ]) + + # Discuss use of the on/off_site_cache entities + + on_sites = Dict{NTuple{3, keytype(basis_definition)}, AHSubModel}() + off_sites = Dict{NTuple{4, keytype(basis_definition)}, AHSubModel}() + + # Caching the basis functions of the functions is faster and allows us to reuse + # the same basis function for similar interactions. + ace_basis_on = with_cache(on_site_ace_basis) + ace_basis_off = with_cache(off_site_ace_basis) + + # Sorting the basis definition makes avoiding interaction doubling easier. + # That is to say, we don't create models for both H-C and C-H interactions + # as they represent the same thing. + basis_definition_sorted = sort(collect(basis_definition), by=first) + + @debug "Building model" + # Loop over all unique species pairs then over all combinations of their shells. + for (zₙ, (zᵢ, shellsᵢ)) in enumerate(basis_definition_sorted) + for (zⱼ, shellsⱼ) in basis_definition_sorted[zₙ:end] + homo_atomic = zᵢ == zⱼ + for (n₁, ℓ₁) in enumerate(shellsᵢ), (n₂, ℓ₂) in enumerate(shellsⱼ) + + # Skip symmetrically equivalent interactions. + homo_atomic && n₁ > n₂ && continue + + if homo_atomic + id = (zᵢ, n₁, n₂) + @debug "Building on-site model : $id" + ace_basis = ace_basis_on( # On-site bases + ℓ₁, ℓ₂, on_site_parameters[id]...; species = species) + + on_sites[(zᵢ, n₁, n₂)] = AHSubModel(ace_basis, id) + end + + id = (zᵢ, zⱼ, n₁, n₂) + @debug "Building off-site model: $id" + + ace_basis = ace_basis_off( # Off-site bases + ℓ₁, ℓ₂, off_site_parameters[id]...; species = species) + + @static if DUAL_BASIS_MODEL + if homo_atomic && n₁ == n₂ + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, id) + else + ace_basis_i = ace_basis_off( + ℓ₂, ℓ₁, off_site_parameters[(zⱼ, zᵢ, n₂, n₁)]...) + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, ace_basis_i, id) + end + else + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, id) + end + end + end + end + + # If no meta-data is supplied then just default to a blank dictionary + meta_data = isnothing(meta_data) ? Dict{String, Any}() : meta_data + new(on_sites, off_sites, on_site_parameters, off_site_parameters, basis_definition, label, meta_data) + end + +end + +# Associated methods + +Base.:(==)(x::Model, y::Model) = ( + x.on_site_submodels == y.on_site_submodels && x.off_site_submodels == y.off_site_submodels + && x.on_site_parameters == y.on_site_parameters && x.off_site_parameters == y.off_site_parameters) + + +# ╭───────┬──────────────────╮ +# │ Model │ IO Functionality │ +# ╰───────┴──────────────────╯ + +function ACEbase.write_dict(m::Model) + # ACE bases are stored as hash values which are checked against the "bases_hashes" + # dictionary during reading. This avoids saving multiple copies of the same object; + # which is common as `AHSubModel` objects tend to share basis functions. + + + bases_hashes = Dict{String, Any}() + + function add_basis(basis) + # Store the hash/basis pair in the bases_hashes dictionary. As the `write_dict` + # method can be quite costly to evaluate it is best to only call it when strictly + # necessary; hence this function exists. + basis_hash = string(hash(basis)) + if !haskey(bases_hashes, basis_hash) + bases_hashes[basis_hash] = write_dict(basis) + end + end + + for basis in union(values(m.on_site_submodels), values(m.off_site_submodels)) + add_basis(basis.basis) + end + + # Serialise the meta-data + meta_data = Dict{String, Any}( + # Invoke the `read_dict` method on values as and where appropriate + k => hasmethod(write_dict, (typeof(v),)) ? write_dict(v) : v + for (k, v) in m.meta_data + ) + + dict = Dict( + "__id__"=>"HModel", + "on_site_submodels"=>Dict(k=>write_dict(v, true) for (k, v) in m.on_site_submodels), + "off_site_submodels"=>Dict(k=>write_dict(v, true) for (k, v) in m.off_site_submodels), + "on_site_parameters"=>write_dict(m.on_site_parameters), + "off_site_parameters"=>write_dict(m.off_site_parameters), + "basis_definition"=>Dict(k=>write_dict(v) for (k, v) in m.basis_definition), + "bases_hashes"=>bases_hashes, + "label"=>m.label, + "meta_data"=>meta_data) + + return dict +end + + +function ACEbase.read_dict(::Val{:HModel}, dict::Dict)::Model + + function set_bases(target, basis_functions) + for v in values(target) + v["basis"] = basis_functions[v["basis"]] + end + end + + # Replace basis object hashs with the appropriate object. + set_bases(dict["on_site_submodels"], dict["bases_hashes"]) + set_bases(dict["off_site_submodels"], dict["bases_hashes"]) + + ensure_int(v) = v isa String ? parse(Int, v) : v + + # Parse meta-data + if haskey(dict, "meta_data") + meta_data = Dict{String, Any}() + for (k, v) in dict["meta_data"] + if typeof(v) <: Dict && haskey(v, "__id__") + meta_data[k] = read_dict(v) + else + meta_data[k] = v + end + end + else + meta_data = nothing + end + + # One of the important entries present in the meta-data dictionary is the `occupancy` + # data. This should be keyed by integers; however the serialisation/de-serialisation + # process converts this into a string. A hard-coded fix is implemented here, but it + # would be better to create a more general way of handling this later on. + if !isnothing(meta_data) && haskey(meta_data, "occupancy") && (keytype(meta_data["occupancy"]) ≡ String) + meta_data["occupancy"] = Dict(parse(Int, k)=>v for (k, v) in meta_data["occupancy"]) + end + + return Model( + Dict(parse_key(k)=>read_dict(v) for (k, v) in dict["on_site_submodels"]), + Dict(parse_key(k)=>read_dict(v) for (k, v) in dict["off_site_submodels"]), + read_dict(dict["on_site_parameters"]), + read_dict(dict["off_site_parameters"]), + Dict(ensure_int(k)=>read_dict(v) for (k, v) in dict["basis_definition"]), + dict["label"], + meta_data) +end + + +# Todo: this is mostly to stop terminal spam and should be updated +# with more meaningful information later on. +function Base.show(io::IO, model::Model) + + # Work out if the on/off site bases are fully, partially or un-fitted. + f = b -> if all(b) "no" elseif all(!, b) "yes" else "partially" end + on = f([!is_fitted(i) for i in values(model.on_site_submodels)]) + off = f([!is_fitted(i) for i in values(model.off_site_submodels)]) + + # Identify the species present + species = join(sort(unique(getindex.(collect(keys(model.on_site_submodels)), 1))), ", ", " & ") + + print(io, "Model(fitted=(on: $on, off: $off), species: ($species))") +end + + +end diff --git a/examples/H2O/python_interface/large/src/parameters.jl b/examples/H2O/python_interface/large/src/parameters.jl new file mode 100644 index 0000000..6a54de3 --- /dev/null +++ b/examples/H2O/python_interface/large/src/parameters.jl @@ -0,0 +1,902 @@ +module Parameters +using Base, ACEbase +export NewParams, GlobalParams, AtomicParams, AzimuthalParams, ShellParams, ParaSet, OnSiteParaSet, OffSiteParaSet, ison + +# The `Params` structure has been temporarily renamed to `NewParams` to avoid conflicts +# with the old code. However, this will be rectified when the old code is overridden. + +# ╔════════════╗ +# ║ Parameters ║ +# ╚════════════╝ +# Parameter related code. + + +# +# Todo: +# - Parameters: +# - Need to enforce limits on key values, shells must be larger than zero and +# azimuthal numbers must be non-negative. +# - All Params should be combinable, with compound classes generated when combining +# different Params types. Compound types should always check the more refined +# struct first (i.e. ShellParamsLabel conversion + t2l(val::Pair{NTuple{N, I}, V}) where {N, I<:Integer, V} = convert(Pair{Label{N, I}, V}, val) + + if with_basis + return quote + function $(esc(name))(b_def, arg::$T1) where {K<:Label{$N, <:Integer}, V, N} + $(Expr(:call, Expr(:curly, esc(:new), :K, :V), Expr(:call, Dict, :arg), :b_def)) + end + + function $(esc(name))(b_def, arg::$T1) where {K<:NTuple{$N, <:Integer}, V, N} + $(Expr(:call, esc(name), :b_def, Expr(:(...), Expr(:call, :map, esc(t2l), :arg)))) + end + end + else + return quote + function $(esc(name))(arg::$T1) where {K<:Label{$N, <:Integer}, V, N} + $(Expr(:call, Expr(:curly, esc(:new), :K, :V), Expr(:call, Dict, :arg))) + end + + function $(esc(name))(arg::$T1) where {K<:NTuple{$N, <:Integer}, V, N} + $(Expr(:call, esc(name), Expr(:(...), Expr(:call, :map, esc(t2l), :arg)))) + end + end + end +end + + +# ╭────────┬────────────╮ +# │ Params │ Definition │ +# ╰────────┴────────────╯ +""" +Dictionary-like structures for specifying model parameters. + +These are used to provide the parameters needed when constructing models within the +`ACEhamiltonians` framework. There are currently four `Params` type structures, namely +`GlobalParams`, `AtomicParams`, `AzimuthalParams`, and `ShellParams`, each offering +varying levels of specificity. + +Each parameter, correlation order, maximum polynomial degree, environmental cutoff +distance, etc. may be specified using any of the available `Params` based structures. +However, i) each `Params` instance may represent one, and only one, parameter, and ii) +on/off-site parameters must not be mixed. +""" +abstract type NewParams{K, V} end + +""" + GlobalParams(val) + +A `GlobalParams` instance indicates that a single value should be used for all relevant +interactions. Querying such instances will always return the value `val`; so long as the +query is valid. For example: +``` +julia> p = GlobalParams(10.) +GlobalParams{Float64} with 1 entries: + () => 10.0 + +julia> p[1] # <- query parameter associated with H +10. +julia> p[(1, 6)] # <- query parameter associated with H-C interaction +10. +julia> p[(1, 6, 1, 2)] # <- interaction between 1ˢᵗ shell on H and 2ⁿᵈ shell on C +10. +``` +As can be seen the specified value `10.` will always be returned so long as the query is +valid. These instances are useful when specifying parameters that are constant across all +bases, such as the internal cutoff distance, as it avoids having to repeatedly specify it +for each and every interaction. + +# Arguments + - `val::Any`: value of the parameter + +""" +struct GlobalParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build GlobalParams 0 false + # Catch for special case where a single value passed + GlobalParams(arg) = GlobalParams(Label()=>arg) +end + + +""" + AtomicParams(k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +These instances allow for parameters to be specified on a species by species basis. This +equates to one parameter per species for on-site interactions and one parameter per species +pair for off-site interactions. This will then result in all associated bases associated +with a specific species/species-pair all using a common value, like so: +``` +julia> p_on = AtomicParams(1=>9., 6=>11.) +AtomicParams{Float64} with 2 entries: + 6 => 11.0 + 1 => 9.0 + +julia> p_off = AtomicParams((1, 1)=>9., (1, 6)=>10., (6, 6)=>11.) +AtomicParams{Float64} with 3 entries: + (6, 6) => 11.0 + (1, 6) => 10.0 + (1, 1) => 9.0 + +# The value 11. is returned for all on-site C interaction queries +julia> p_on[(6, 1, 1)] == p_on[(6, 1, 2)] == p_on[(6, 2, 2)] == 11. +true +# The value 10. is returned for all off-site H-C interaction queries +julia> p_off[(1, 6, 1, 1)] == p_off[(6, 1, 2, 1)] == p_off[(6, 1, 2, 2)] == 10. +true +``` +These instances are instantiated in a similar manner to dictionaries and offer a finer +degree of control over the parameters than `GlobalParams` structures but are not as +granular as `AzimuthalParams` structures. + +# Arguments +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each species + or species-pair. Valid parameter forms are: + + - on-site: `z₁=>v` or `(z,)=>v` for on-sites + - off-site: `(z₁, z₂)=>v` + + where `zᵢ` represents the atomic number of species `i` and `v` the parameter valued + associated with this species or specie pair. + + +# Notes +It is important to note that atom pair keys are permutationally invariant, i.e. the keys +`(1, 6)` and `(6, 1)` are redundant and will overwrite one another like so: +``` +julia> test = AtomicParams((1, 6)=>10., (6, 1)=>1000.) +AtomicParams{Float64} with 1 entries: + (1, 6) => 1000.0 + +julia> test[(1, 6)] == test[(6, 1)] == 1000.0 +true +``` +Finally atomic numbers will be sorted so that the lowest atomic number comes first. However, +this is only a superficial visual change and queries will still be invariant to permutation. +""" +struct AtomicParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build AtomicParams 1 false + @build AtomicParams 2 false + # Catch for special case where keys are integers rather than tuples + AtomicParams(arg::Vararg{Pair{I, V}, N}) where {I<:Integer, V, N} = AtomicParams( + map(i->((first(i),)=>last(i)), arg)...) + +end + + +""" + AzimuthalParams(basis_definition, k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +Parameters specified for each azimuthal quantum number of each species. This allows for a +finer degree of control and is a logical extension of the `AtomicParams` structure. It is +important to note that `AzimuthalParams` instances must be supplied with a basis definition. +This allows it to work out the azimuthal quantum number associated with each shell during +lookup. + +``` +# Basis definition describing a H_1s C_2s1p basis set +julia> basis_def = Dict(1=>[0], 6=>[0, 0, 1]) +julia> p_on = AzimuthalParams( + basis_def, (1, 0, 0)=>1, (6, 0, 0)=>2, (6, 0, 1)=>3, (6, 1, 1)=>4) +AzimuthalParams{Int64} with 4 entries: + (6, 0, 0) => 2 + (1, 0, 0) => 1 + (6, 1, 1) => 4 + (6, 0, 1) => 3 + +julia> p_off = AzimuthalParams( + basis_def, (1, 1, 0, 0)=>1, (6, 6, 0, 0)=>2, (6, 6, 0, 1)=>3, (6, 6, 1, 1)=>4, + (1, 6, 0, 0)=>6, (1, 6, 0, 1)=>6) + +AzimuthalParams{Int64} with 6 entries: + (1, 6, 0, 1) => 6 + (6, 6, 0, 1) => 3 + (1, 6, 0, 0) => 6 + (1, 1, 0, 0) => 1 + (6, 6, 1, 1) => 4 + (6, 6, 0, 0) => 2 + +# on-site interactions involving shells 1 % 2 will return 2 as they're both s-shells. +julia> p_on[(6, 1, 1)] == p_on[(6, 1, 2)] == p_on[(6, 2, 2)] == 2 +true + +``` + +# Arguments +- `basis_definition::BasisDef`: basis definition specifying the bases present on each + species. This is used to work out the azimuthal quantum number associated with each + shell when queried. +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each unique + atomic-number/azimuthal-number pair. Valid forms are: + + - on-site: `(z, ℓ₁, ℓ₂)=>v` + - off-site: `(z₁, z₂, ℓ₁, ℓ₂)=>v` + + where `zᵢ` and `ℓᵢ` represents the atomic and azimuthal numbers of species `i` to which + the parameter `v` is associated. + +# Notes +While keys are agnostic to the ordering of the azimuthal numbers; the first atomic number +`z₁` will always correspond to the first azimuthal number `ℓ₁`, i.e.: + - `(z₁, ℓ₁, ℓ₂) == (z₁, ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) == (z₂, z₁, ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) ≠ (z₁, z₂ ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) ≠ (z₂, z₁ ℓ₁, ℓ₂)` + +""" +struct AzimuthalParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + _basis_def + + @build AzimuthalParams 3 true + @build AzimuthalParams 4 true +end + +""" + ShellParams(k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +`ShellParams` structures allow for individual values to be provided for each and every +unique interaction. While this proved the finest degree of control it can quickly become +untenable for systems with large basis sets or multiple species due the shear number of +variable required. +``` +# For H1s C2s1p basis set. +julia> p_on = ShellParams( + (1, 1, 1)=>1, (6, 1, 1)=>2, (6, 1, 2)=>3, (6, 1, 3)=>4, + (6, 2, 2)=>5, (6, 2, 3)=>6, (6, 3, 3)=>7) + +ShellParams{Int64} with 7 entries: + (6, 3, 3) => 7 + (1, 1, 1) => 1 + (6, 1, 3) => 4 + (6, 2, 2) => 5 + (6, 1, 1) => 2 + (6, 1, 2) => 3 + (6, 2, 3) => 6 + +julia> p_off = ShellParams( + (1, 1, 1, 1)=>1, (1, 6, 1, 1)=>2, (1, 6, 1, 2)=>3, (1, 6, 1, 3)=>4, + (6, 6, 1, 1)=>5, (6, 6, 1, 2)=>6, (6, 6, 1, 3)=>74, (6, 6, 2, 2)=>8, + (6, 6, 2, 3)=>9, (6, 6, 3, 3)=>10) + +ShellParams{Int64} with 10 entries: + (6, 6, 2, 2) => 8 + (6, 6, 3, 3) => 10 + (6, 6, 1, 3) => 74 + (1, 1, 1, 1) => 1 + (1, 6, 1, 2) => 3 + (1, 6, 1, 1) => 2 + (1, 6, 1, 3) => 4 + (6, 6, 1, 1) => 5 + (6, 6, 1, 2) => 6 + (6, 6, 2, 3) => 9 + +``` + +# Arguments +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each unique + shell pair: + - on-site: `(z, s₁, s₂)=>v`, interaction between shell numbers `s₁` & `s₂` on species `z` + - off-site: `(z₁, z₂, s₁, s₂)=>v`, interaction between shell number `s₁` on species + `zᵢ` and shell number `s₂` on species `z₂`. + +""" +struct ShellParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build ShellParams 3 false + @build ShellParams 4 false +end + +# ╭────────┬───────────────────────╮ +# │ Params │ General Functionality │ +# ╰────────┴───────────────────────╯ +# Return the key and value types of the internal dictionary. +Base.valtype(::NewParams{K, V}) where {K, V} = V +Base.keytype(::NewParams{K, V}) where {K, V} = K +Base.valtype(::Type{NewParams{K, V}}) where {K, V} = V +Base.keytype(::Type{NewParams{K, V}}) where {K, V} = K + +# Extract keys and values from the internal dictionary (and number of elements) +Base.keys(x::NewParams) = keys(x._vals) +Base.values(x::NewParams) = values(x._vals) +Base.length(x::T) where T<:NewParams = length(x._vals) + +# Equality check, mostly use during testing +function Base.:(==)(x::T₁, y::T₂) where {T₁<:NewParams, T₂<:NewParams} + dx, dy = x._vals, y._vals + # Different type Params are not comparable + if T₁ ≠ T₂ + return false + # Different key sets means x & y are different + elseif keys(dx) ≠ keys(dy) + return false + # If any key yields a different value in x from x then x & y are different + else + for key in keys(dx) + if dx[key] ≠ dy[key] + return false + end + end + # Otherwise there is no difference between x and y, thus return true + return true + end +end + + + +# ╭────────┬────────────────────╮ +# │ Params │ Indexing Functions │ +# ╰────────┴────────────────────╯ +""" + params_object[key] + +This function makes `Params` structures indexable in the same way that dictionaries are. +This will not only check the `Params` object `params` for the specified key `key` but will +also check for i) permutationally equivalent matches, i.e. (1, 6)≡(6, 1), and ii) keys +that `key` is a subtype of i.e. (1, 6, 1, 1) ⊆ (1, 6). + +Valid key types are: + - z/(z,): single atomic number + - (z₁, z₂): pair of atomic numbers + - (z, s₁, s₂): single atomic number with pair of shell numbers + - (z₁, z₂, s₁, s₂): pair of atomic numbers with pair of shell numbers + +This is primarily intended to be used by the code internally, but is left accessible to the +user. +""" +function Base.getindex(x::T, key::K) where {T<:NewParams, K} + # This will not only match the specified key but also any superset it is a part of; + # i.e. the key (z₁, z₂, s₁, s₂) will match (z₁, z₂). + + # Block 1: convert shell numbers to azimuthal numbers for the AzimuthalParams case. + if T<:AzimuthalParams && !(K<:Integer) + if length(key) == 3 + key = (key[1], [x._basis_def[key[1]][i] for i in key[2:3]]...) + else + key = (key[1:2]..., x._basis_def[key[1]][key[3]], x._basis_def[key[2]][key[4]]) + end + end + + # Block 2: identify closest viable key. + super_key = filter(k->(key ⊆ k), keys(x)) + + # Block 3: retrieve the selected key. + if length(super_key) ≡ 0 + throw(KeyError(key)) + else + return x._vals[first(super_key)] + end +end + + +# ╭────────┬──────────────────╮ +# │ Params │ IO Functionality │ +# ╰────────┴──────────────────╯ +"""Full, multi-line string representation of a `Param` type objected""" +function _multi_line(io, x::T) where T<:NewParams + i = length(keytype(x._vals).types[1].types) ≡ 1 ? 1 : Base.:(:) + indent = repeat(" ", get(io, :indent, 0)+2) + v_string = join(["$(indent)$(k[i]) => $v" for (k, v) in x._vals], "\n") + # Make convert "()" to "(All)" for to make GlobalParams more readable + v_string = replace(v_string, "()" => "All") + return "$(nameof(T)){$(valtype(x))} with $(length(x._vals)) entries:\n$(v_string)" +end + + +function Base.show(io::O, x::T) where {T<:NewParams, O<:IO} + # If printing an isolated Params instance, just use the standard multi-line format + # if !haskey(io.dict, :SHOWN_SET) + # print(io, _multi_line(x)) + if !get(io, :compact, false) && !haskey(io.dict, :SHOWN_SET) + print(io, _multi_line(io, x)) + # If the Params is being printed as part of a group then a more compact + # representation is needed. + else + # Create a slicer remove braces from tuples of length 1 if needed + s = length(keytype(x)) == 1 ? 1 : Base.:(:) + # Sort the keys to ensure consistency + keys_s = sort([j.id for j in keys(x._vals)]) + # Only show first and last keys (or just the first if there is only one) + targets = length(x) != 1 ? [[1, lastindex(keys_s)]] : [1:1] + # Build the key list and print the message out + k_string = join([k[s] for k in keys_s[targets...]], " … ") + # Make convert "()" to "(All)" for to make GlobalParams more readable + k_string = replace(k_string, "()" => "All") + indent = repeat(" ", get(io, :indent, 0)) + print(io, "$(indent)$(nameof(T))($(k_string))") + end +end + +# Special show case: Needed as Base.TTY has no information dictionary +Base.show(io::Base.TTY, x::T) where T<:NewParams = print(io, _multi_line(x)) + + +function ACEbase.write_dict(p::T) where T<:NewParams{K, V} where {K, V} + # Recursive and arbitrary value type storage to be implemented later + # value_parsable = hasmethod(ACEbase.write_dict, (V)) + + dict = Dict( + "__id__"=>"NewParams", + "vals"=>Dict(string(k)=>v for (k, v) in p._vals)) + + if T<:AzimuthalParams + dict["basis_def"] = p._basis_def + end + + return dict +end + +function ACEbase.read_dict(::Val{:NewParams}, dict::Dict) + vals = Dict(Label(k)=>v for (k,v) in dict["vals"]) + n = length(keytype(vals)) + + if n ≡ 0 + return GlobalParams(vals...) + elseif n ≤ 2 + return AtomicParams(vals...) + elseif haskey(dict, "basis_def") + return AzimuthalParams(dict["basis_def"], vals...) + else + return ShellParams(vals...) + end + +end + + +# ╔═════════╗ +# ║ ParaSet ║ +# ╚═════════╝ +# Containers for collections of `Params` instances. These exist mostly to ensure that +# all the required parameters are specified and provide a single location where user +# specified parameters can be collected and checked. + +# ╭─────────┬────────────╮ +# │ ParaSet │ Definition │ +# ╰─────────┴────────────╯ +""" +`ParaSet` instances are structures which collect all the required parameter definitions +for a given interaction type in once place. Once instantiated, the `OnSiteParaSet` and +`OffSiteParaSet` structures should contain all parameters required to construct all of +the desired on/off-site bases. +""" +abstract type ParaSet end + + +""" + OnSiteParaSet(ν, deg, e_cut_out, r0) + +This structure holds all the `Params` instances required to construct the on-site +bases. + + +# Arguments +- `ν::Params{K, Int}`: correlation order, for on-site interactions the body order is one + more than the correlation order. +- `deg::Params{K, Int}`: maximum polynomial degree. +- `e_cut_out::Parameters{K, Float}`: environment's external cutoff distance. +- `r0::Parameters{K, Float}`: scaling parameter (typically set to the nearest neighbour distances). + + +# Todo + - check that r0 is still relevant +""" +struct OnSiteParaSet <: ParaSet + ν + deg + e_cut_out + r0 + + function OnSiteParaSet(ν::T₁, deg::T₂, e_cut_out::T₃, r0::T₄ + ) where {T₁<:NewParams, T₂<:NewParams, T₃<:NewParams, T₄<:NewParams} + ν::NewParams{<:Label, <:Integer} + deg::NewParams{<:Label, <:Integer} + e_cut_out::NewParams{<:Label, <:AbstractFloat} + r0::NewParams{<:Label, <:AbstractFloat} + new(ν, deg, e_cut_out, r0) + end + +end + +""" + OffSiteParaSet(ν, deg, b_cut, e_cut_out, r0) + +This structure holds all the `Params` instances required to construct the off-site +bases. + +# Arguments +- `ν::Params{K, Int}`: correlation order, for off-site interactions the body order is two + more than the correlation order. +- `deg::Params{K, Int}`: maximum polynomial degree. +- `b_cut::Params{K, Float}`: cutoff distance for off-site interactions. +- `e_cut_out::Params{K, Float}`: environment's external cutoff distance. + +# Todo: +- add λₙ & λₗ as parameters +- generate constructor to allow for arbitrary fields +""" +struct OffSiteParaSet <: ParaSet + ν + deg + b_cut + e_cut_out + + function OffSiteParaSet(ν::T₁, deg::T₂, b_cut::T₃, e_cut_out::T₄ + ) where {T₁<:NewParams, T₂<:NewParams, T₃<:NewParams, T₄<:NewParams} + ν::NewParams{<:Label, <:Integer} + deg::NewParams{<:Label, <:Integer} + b_cut::NewParams{<:Label, <:AbstractFloat} + e_cut_out::NewParams{<:Label, <:AbstractFloat} + new(ν, deg, b_cut, e_cut_out) + end + +end + +# ╭─────────┬───────────────────────╮ +# │ ParaSet │ General Functionality │ +# ╰─────────┴───────────────────────╯ +function Base.:(==)(x::T, y::T) where T<:ParaSet + # Check that all fields are equal to one another + for field in fieldnames(T) + # If any do not match then return false + if getfield(x, field) ≠ getfield(y, field) + return false + end + end + + # If all files match then return true + return true +end + + +# ╭─────────┬────────────────────────────────╮ +# │ ParaSet │ Miscellaneous Helper Functions │ +# ╰─────────┴────────────────────────────────╯ +# Returns true if a `ParaSet` corresponds to an on-site interaction. +ison(::OnSiteParaSet) = true +ison(::OffSiteParaSet) = false + + +# ╭─────────┬──────────────────╮ +# │ ParaSet │ IO Functionality │ +# ╰─────────┴──────────────────╯ +function ACEbase.write_dict(p::T) where T<:ParaSet + dict = Dict( + "__id__"=>"ParaSet", + (string(fn)=>write_dict(getfield(p, fn)) for fn in fieldnames(T))...) + return dict +end + + +function ACEbase.read_dict(::Val{:ParaSet}, dict::Dict) + if haskey(dict, "b_cut") + return OffSiteParaSet(( + ACEbase.read_dict(dict[i]) for i in + ["ν", "deg", "b_cut", "e_cut_out"])...) + else + return OnSiteParaSet(( + ACEbase.read_dict(dict[i]) for i in + ["ν", "deg", "e_cut_out", "r0"])...) + end +end + + +function Base.show(io::O, x::T) where {T<:ParaSet, O<:IO} + print(io, "$(nameof(T))") + if !get(io, :compact, false) && !haskey(io.dict, :SHOWN_SET) + new_io = IOContext(io.io, :indent=>get(io, :indent, 0)+4, :compact=>get(io, :compact, false)) + for f in fieldnames(T) + print(new_io, join(["\n", repeat(" ", get(io, :indent, 0)+2)]), "$f: ", getfield(x, f)) + end + + else + for (i, f) in enumerate(fieldnames(T)) + print(io, "$f: ") + show(io, getfield(x, f)) + if i != length(fieldnames(T)) + print(io, ", ") + end + end + print(io, ")") + + end + nothing +end + + + +# ╭─────────┬────────────────────╮ +# │ ParaSet │ Indexing Functions │ +# ╰─────────┴────────────────────╯ +""" + on_site_para_set[key] + +Indexing an `OnSiteParaSet` instance will index each of the internal fields and return +their results in a tuple, i.e. calling `res = on_site_para_set[key]` equates to calling +``` +res = ( + on_site_para_set.ν[key], on_site_para_set.deg[key], + on_site_para_set.e_cut_out[key], on_site_para_set.e_cut_in[key]) +``` + +This is mostly intended as a convenience function. +""" +function Base.getindex(para::OnSiteParaSet, key) + return ( + para.ν[key], para.deg[key], + para.e_cut_out[key], para.r0[key]) +end + + +""" + off_site_para_set[key] + +Indexing an `OffSiteParaSet` instance will index each of the internal fields and return +their results in a tuple, i.e. calling `res = off_site_para_set[key]` equates to calling +``` +res = ( + off_site_para_set.ν[key], off_site_para_set.deg[key], off_site_para_set.b_cut[key], + off_site_para_set.e_cut_out[key], off_site_para_set.e_cut_in[key]) +``` + +This is mostly intended as a convenience function. +""" +function Base.getindex(para::OffSiteParaSet, key) + return ( + para.ν[key], para.deg[key], para.b_cut[key], + para.e_cut_out[key]) +end + + + +# ╔═══════════════════════════╗ +# ║ Internal Helper Functions ║ +# ╚═══════════════════════════╝ + +""" +Sort `Label` tuples so that the lowest atomic-number/shell-number comes first for the +two/one atom interaction labels. If more than four integers are specified then an error +is raised. +""" + +""" + _process_ctuple(tuple) + +Preprocess tuples prior to their conversion into `Label` instances. This ensures that +tuples are ordered so that: + 1. the lowest atomic number comes first, but only if multiple atomic numbers are specified. + 2. the lowest shell number comes first, but only where this does not conflict with point 1. + +An error is then raised if the tuple is of an unexpected length. permitted lengths are: + - 1/(z) single atomic number. + - 2/(z₁, z₂) pair of atomic numbers + - 3/(z, s₁, s₂) single atomic number and pair of shell numbers + - 4/(z₁, z₂, s₁, s₂) pair of atomic numbers and a pair of shell numbers. + +Note that in the latter case s₁ & s₂ correspond to shells on z₁ & z₂ respectively thus +if z₁ and z₂ are flipped due to z₁>z₂ then s₁ & s₂ must also be shuffled. + +This os intended only to be used internally and only during the construction of `Label` +instances. +""" +function _process_tuple(x::NTuple{N, I}) where {N, I<:Integer} + if N <= 1; x + elseif N ≡ 2; x[1] ≤ x[2] ? x : reverse(x) + elseif N ≡ 3; x[2] ≤ x[3] ? x : x[[1, 3, 2]] + elseif N ≡ 4 + if x[1] > x[2] || ((x[1] ≡ x[2]) && (x[3] > x[4])); x[[2, 1, 4, 3]] + else; x + end + else + error( + "Label may contain no more than four integers, valid formats are:\n"* + " ()\n (z₁,)\n (z₁, s₁, s₂)\n (z₁, z₂)\n (z₁, z₂, s₁, s₂)") + end +end + + +# # Guards type conversion of dictionaries keyed with `Label` entities. This is done to +# # ensure that a meaningful message is given to the user when a key-collision occurs. +# function _guarded_convert(t::Type{Dict{Label{N, I}, V}}, x::Dict{NTuple{N, I}, V}) where {N, I<:Integer, V} +# try +# return convert(t, x) +# catch e +# if e.msg == "key collision during dictionary conversion" +# r_keys = _redundant_keys([k for k in keys(x)]) +# error("Redundant keys found:\n$(join([" - $(join(i, ", "))" for i in r_keys], "\n"))") +# else +# rethrow(e) +# end +# end +# end + +# # Collisions cannot occur when input dictionary is keyed by integers not tuples +# _guarded_convert(t::Type{Dict{Label{1, I}, V}}, x::Dict{I, V}) where {N, I<:Integer, V} = convert(t, x) + + +# function _redundant_keys(keys_in::Vector{NTuple{N, I}}) where {I<:Integer, N} +# duplicates = [] +# while length(keys_in) ≥ 1 +# key = Label(pop!(keys_in)) +# matches = [popat!(keys_in, i) for i in findall(i -> i == key, keys_in)] +# if length(matches) ≠ 0 +# append!(duplicates, Ref((key, matches...))) +# end +# end +# return duplicates +# end + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/large/src/predicting.jl b/examples/H2O/python_interface/large/src/predicting.jl new file mode 100644 index 0000000..c06baf3 --- /dev/null +++ b/examples/H2O/python_interface/large/src/predicting.jl @@ -0,0 +1,987 @@ +module Predicting + +using ACE, ACEbase, ACEhamiltonians, LinearAlgebra +using JuLIP: Atoms, neighbourlist +using ACE: ACEConfig, AbstractState, SymmetricBasis, evaluate + +using ACEhamiltonians.States: _get_states +using ACEhamiltonians.Fitting: _evaluate_real + +using ACEhamiltonians: DUAL_BASIS_MODEL + +using SharedArrays, Distributed, TensorCast + +export predict, predict!, cell_translations + + +""" + cell_translations(atoms, cutoff) + +Translation indices of all cells in within range of the origin. Multiplying any translation +index by the lattice vector will return the cell translation vector associated with said +cell. The results of this function are most commonly used in constructing the real space +matrix. + +# Arguments +- `atoms::Atoms`: system for which the cell translation index vectors are to be constructed. +- `cutoff::AbstractFloat`: cutoff distance for diatomic interactions. + +# Returns +- `cell_indices::Matrix{Int}`: a 3×N matrix specifying all cell images that are within + the cutoff distance of the origin cell. + +# Notes +The first index provided in `cell_indices` is always that of the origin cell; i.e. [0 0 0]. +A cell is included if, and only if, at least one atom within it is within range of at least +one atom in the origin cell. Mirror image cell are always included, that is to say if the +cell [i, j, k] is present then the cell [-i, -j, -k] will also be present. +""" +function cell_translations(atoms::Atoms{T}, cutoff) where T<:AbstractFloat + + l⃗, x⃗ = atoms.cell, atoms.X + # n_atoms::Int = size(x⃗, 2) + n_atoms::Int = size(x⃗, 1) + + # Identify how many cell images can fit within the cutoff distance. + aₙ, bₙ, cₙ = convert.(Int, cld.(cutoff, norm.(eachrow(l⃗)))) + + # Matrix in which the resulting translation indices are to be stored + cellᵢ = Matrix{Int}(undef, 3, (2aₙ + 1) * (2bₙ + 1) * (2cₙ + 1)) + + # The first cell is always the origin cell + cellᵢ[:, 1] .= 0 + i = 1 + + # Loop over all possible cells within the cutoff range. + for n₁=-aₙ:aₙ, n₂=-bₙ:bₙ, n₃=-cₙ:cₙ + + # Origin cell is skipped over when encountered as it is already defined. + if n₁ ≠ 0 || n₂ ≠ 0 || n₃ ≠ 0 + + # Construct the translation vector + t⃗ = l⃗[1, :]n₁ + l⃗[2, :]n₂ + l⃗[3, :]n₃ + + # Check if any atom in the shifted cell, n⃗, is within the cutoff distance of + # any other atom in the origin cell, [0,0,0]. + min_distance = 2cutoff + for atomᵢ=1:n_atoms, atomⱼ=1:n_atoms + min_distance = min(min_distance, norm(x⃗[atomᵢ] - x⃗[atomⱼ] + t⃗)) + end + + # If an atom in the shifted cell is within the cutoff distance of another in + # the origin cell then the cell should be included. + if min_distance ≤ cutoff + i += 1 + cellᵢ[:, i] .= n₁, n₂, n₃ + end + end + end + + # Return a truncated view of the cell translation index matrix. + return cellᵢ[:, 1:i] + +end + +""" +cell_translations(atoms, model) + +Translation indices of all cells in within range of the origin. Note, this is a wrapper +for the base `cell_translations(atoms, cutoff)` method which automatically selects an +appropriate cutoff distance. See the base method for more info. + + +# Argument +- `atoms::Atoms`: system for which the cell translation index vectors are to be constructed. +- `model::Model`: model instance from which an appropriate cutoff distance is to be derived. + + +# Returns +- `cell_indices::Matrix{Int}`: a 3×N matrix specifying all cell images that are within + the cutoff distance of the origin cell. + +""" +function cell_translations(atoms::Atoms, model::Model) + # Loop over the interaction cutoff distances and identify the maximum recognise + # interaction distance and use that as the cutoff. + return cell_translations( + atoms, maximum(values(model.off_site_parameters.b_cut))) +end + + +""" + predict!(values, basis, state) + +Predict the values for a given sub-block by evaluating the provided basis on the specified +state; or more accurately the descriptor that is to be constructed from said state. Results +are placed directly into the supplied matrix `values.` + +# Arguments + - `values::AbstractMatrix`: matrix into which the results should be placed. + - `basis::AHSubModel`: basis to be evaluated. + - `state::Vector{States}`: state upon which the `basis` should be evaluated. +""" +function predict!(values::AbstractMatrix, submodel::T, state::Vector{S}) where {T<:AHSubModel, S<:AbstractState} + # If the model has been fitted then use it to predict the results; otherwise just + # assume the results are zero. + if is_fitted(submodel) + # Construct a descriptor representing the supplied state and evaluate the + # basis on it to predict the associated sub-block. + A = evaluate(submodel.basis, ACEConfig(state)) + B = _evaluate_real(A) + values .= (submodel.coefficients' * B) + submodel.mean + + @static if DUAL_BASIS_MODEL + if T<: AnisoSubModel + A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) + B = _evaluate_real(A) + values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 + elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) + # If the dual basis model is being used then it is assumed that the symmetry + # issue has not been resolved thus an additional symmetrisation operation is + # required. + A = evaluate(submodel.basis, ACEConfig(reflect.(state))) + B = _evaluate_real(A) + values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 + end + end + + else + fill!(values, 0.0) + end +end + + +#########for parallelization + +# function get_discriptors(basis::SymmetricBasis, states::Vector{<:Vector{<:AbstractState}}) +# # This will be rewritten once the other code has been refactored. + +# # Should `A` not be constructed using `acquire_B!`? + +# n₁, n₂, type = ACE.valtype(basis).parameters[3:5] +# # Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is +# # the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and +# # Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored +# # at a later data if this layout is not found to be strictly necessary. +# n₃ = length(states) + +# Avalr = SharedArray{real(type), 4}(n₃, length(basis), n₁, n₂) +# np = length(procs(Avalr)) +# nstates = length(states) +# nstates_pp = ceil(Int, nstates/np) +# np = ceil(Int, nstates/nstates_pp) +# idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np] +# idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)] +# push!(idx_ends, nstates) +# @sync begin +# for (i, id) in enumerate(procs(Avalr)[begin:np]) +# @async begin +# @spawnat id begin +# cfg = ACEConfig.(states[idx_begins[i]:idx_ends[i]]) +# Aval_ele = evaluate.(Ref(basis), cfg) +# Avalr_ele = _evaluate_real.(Aval_ele) +# Avalr_ele = permutedims(reduce(hcat, Avalr_ele), (2, 1)) +# @cast M[i,j,k,l] := Avalr_ele[i,j][k,l] +# Avalr[idx_begins[i]: idx_ends[i], :, :, :] .= M +# end +# end +# end +# end +# @cast A[i,j][k,l] := Avalr[i,j,k,l] + +# return A + +# end + + +# function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::SubArray{<:Any}) +# return coefficients' * B + mean +# end + +# function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Matrix{<:Any}) +# @cast B_broadcast[i][j] := B[i,j] +# result = infer.(Ref(coefficients), Ref(mean), B_broadcast) +# return result +# end + + + +function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::SubArray{<:Any}) + return coefficients' * B + mean +end + +function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Vector{Matrix{Float64}}) + return coefficients' * B + mean +end + + +# function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Array{<:Any, 4}, type::String="sub") +# @cast B_sub[i,j][k,l] := B[i,j,k,l] +# @cast B_broadcast[i][j] := B_sub[i,j] +# result = infer.(Ref(coefficients), Ref(mean), B_broadcast) +# return result +# end + +# function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Array{<:Any, 4}, type::String="sub") +# @cast B_sub[i,j][k,l] := B[i,j,k,l] +# @cast B_broadcast[i][j] := B_sub[i,j] +# result = infer.(Ref(coefficients), Ref(mean), B_broadcast) +# return result +# end + + + +function predict_single(basis::SymmetricBasis, states::Vector{<:Vector{<:AbstractState}}, coefficients::Vector{Float64}, mean::Matrix{Float64}) +# function predict_single(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + + #get_discriptors(basis::SymmetricBasis, states::Vector{<:Vector{<:AbstractState}}) + # basis = submodel.basis + + # This will be rewritten once the other code has been refactored. + + # Should `A` not be constructed using `acquire_B!`? + + n₁, n₂, type = ACE.valtype(basis).parameters[3:5] + # Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is + # the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and + # Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored + # at a later data if this layout is not found to be strictly necessary. + n₃ = length(states) + + values_sub = SharedArray{real(type), 3}(n₁, n₂, n₃) + np = length(procs(values_sub)) + nstates = length(states) + nstates_pp = ceil(Int, nstates/np) + np = ceil(Int, nstates/nstates_pp) + idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np] + idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)] + push!(idx_ends, nstates) + @sync begin + for (i, id) in enumerate(procs(values_sub)[begin:np]) + # @async begin + @spawnat id begin + cfg = ACEConfig.(states[idx_begins[i]:idx_ends[i]]) + Aval_ele = evaluate.(Ref(basis), cfg) + Avalr_ele = _evaluate_real.(Aval_ele) + Avalr_ele = permutedims(reduce(hcat, Avalr_ele), (2, 1)) + result = infer.(Ref(coefficients), Ref(mean), collect(eachrow(Avalr_ele))) + values_sub[:, :, idx_begins[i]: idx_ends[i]] .= cat(result..., dims=3) + # @cast M[i,j,k,l] := Avalr_ele[i,j][k,l] + # Avalr[idx_begins[i]: idx_ends[i], :, :, :] .= M + end + # end + end + end + # @cast A[i,j][k,l] := Avalr[i,j,k,l] + + return values_sub #A + +end + + + +# function infer(coefficients::Vector{Float64}, mean::Matrix{Float64}, B::Array{<:Any, 4}) +# np = length(workers()) +# nstates = size(B, 1) +# nstates_pp = ceil(Int, nstates/np) +# np = ceil(Int, nstates/nstates_pp) +# worker_id = workers() +# idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np] +# idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)] +# push!(idx_ends, nstates) +# result = [] +# @sync begin +# for i in 1:np +# id = worker_id[i] +# B_sub = @view B[idx_begins[i]:idx_ends[i],:,:,:] +# push!(result, @spawnat id infer(coefficients, mean, B_sub, "sub")) +# end +# end +# # result = fetch.(result) +# result = vcat(result...) +# result = cat(result..., dims=3) +# return result +# end + + + +function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + # If the model has been fitted then use it to predict the results; otherwise just + # assume the results are zero. + if is_fitted(submodel) + # Construct a descriptor representing the supplied state and evaluate the + # basis on it to predict the associated sub-block. + # A = evaluate(submodel.basis, ACEConfig(state)) + # B = _evaluate_real(A) + # B_batch = get_discriptors(submodel.basis, states) + # # coeffs_expanded = repeat(submodel.coefficients', size(B_batch, 1), 1) + # # means_expanded = fill(submodel.mean, size(B_batch, 1)) + # # values .= dropdims(sum(coeffs_expanded .* B_batch, dims=2), dims=2) + means_expanded + # values .= cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3) + values .= predict_single(submodel.basis, states, submodel.coefficients, submodel.mean) + # values .= cat([(submodel.coefficients' * B_batch[i,:]) + submodel.mean for i in 1: size(B_batch,1)]..., dims=3) + # values = (submodel.coefficients' * B) + submodel.mean + + @static if DUAL_BASIS_MODEL + if typeof(submodel) <: AnisoSubModel + # A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) + # B = _evaluate_real(A) + # B_batch = get_discriptors(submodel.basis_i, [reflect.(state) for state in states]) + # # values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 + # # values .= (values + cat([((submodel.coefficients_i' * B_batch[i,:]) + submodel.mean_i)' for + # # i in 1: size(B_batch,1)]..., dims=3))/2.0 + # values .= (values + permutedims(cat(infer(submodel.coefficients_i, submodel.mean_i, B_batch)..., dims=3), (2,1,3))) / 2.0 + values .= (values + permutedims(predict_single(submodel.basis_i, [reflect.(state) for state in states], + submodel.coefficients_i, submodel.mean_i), (2,1,3))) / 2.0 + + elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) + # If the dual basis model is being used then it is assumed that the symmetry + # issue has not been resolved thus an additional symmetrisation operation is + # required. + # A = evaluate(submodel.basis, ACEConfig(reflect.(state))) + # B = _evaluate_real(A) + # B_batch = get_discriptors(submodel.basis, [reflect.(state) for state in states]) + # # values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 + # # values .= (values + cat([((submodel.coefficients' * B_batch[i,:]) + submodel.mean)' for + # # i in 1: size(B_batch,1)]..., dims=3))/2.0 + # values .= (values + permutedims(cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3), (2,1,3))) / 2.0 + values .= (values + permutedims(predict_single(submodel.basis, [reflect.(state) for state in states], + submodel.coefficients, submodel.mean), (2,1,3))) / 2.0 + end + end + + else + fill!(values, 0.0) + end +end + + + +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# # If the model has been fitted then use it to predict the results; otherwise just +# # assume the results are zero. +# if is_fitted(submodel) +# # Construct a descriptor representing the supplied state and evaluate the +# # basis on it to predict the associated sub-block. +# # A = evaluate(submodel.basis, ACEConfig(state)) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis, states) +# # coeffs_expanded = repeat(submodel.coefficients', size(B_batch, 1), 1) +# # means_expanded = fill(submodel.mean, size(B_batch, 1)) +# # values .= dropdims(sum(coeffs_expanded .* B_batch, dims=2), dims=2) + means_expanded +# values .= cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3) +# # values .= cat([(submodel.coefficients' * B_batch[i,:]) + submodel.mean for i in 1: size(B_batch,1)]..., dims=3) +# # values = (submodel.coefficients' * B) + submodel.mean + +# @static if DUAL_BASIS_MODEL +# if typeof(submodel) <: AnisoSubModel +# # A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis_i, [reflect.(state) for state in states]) +# # values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 +# # values .= (values + cat([((submodel.coefficients_i' * B_batch[i,:]) + submodel.mean_i)' for +# # i in 1: size(B_batch,1)]..., dims=3))/2.0 +# values .= (values + permutedims(cat(infer(submodel.coefficients_i, submodel.mean_i, B_batch)..., dims=3), (2,1,3))) / 2.0 + +# elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) +# # If the dual basis model is being used then it is assumed that the symmetry +# # issue has not been resolved thus an additional symmetrisation operation is +# # required. +# # A = evaluate(submodel.basis, ACEConfig(reflect.(state))) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis, [reflect.(state) for state in states]) +# # values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 +# # values .= (values + cat([((submodel.coefficients' * B_batch[i,:]) + submodel.mean)' for +# # i in 1: size(B_batch,1)]..., dims=3))/2.0 +# values .= (values + permutedims(cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3), (2,1,3))) / 2.0 +# end +# end + +# else +# fill!(values, 0.0) +# end +# end + + + + +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# # If the model has been fitted then use it to predict the results; otherwise just +# # assume the results are zero. +# if is_fitted(submodel) +# # Construct a descriptor representing the supplied state and evaluate the +# # basis on it to predict the associated sub-block. +# # A = evaluate(submodel.basis, ACEConfig(state)) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis, states) +# # coeffs_expanded = repeat(submodel.coefficients', size(B_batch, 1), 1) +# # means_expanded = fill(submodel.mean, size(B_batch, 1)) +# # values .= dropdims(sum(coeffs_expanded .* B_batch, dims=2), dims=2) + means_expanded +# values .= cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3) +# # values .= cat([(submodel.coefficients' * B_batch[i,:]) + submodel.mean for i in 1: size(B_batch,1)]..., dims=3) +# # values = (submodel.coefficients' * B) + submodel.mean + +# @static if DUAL_BASIS_MODEL +# if typeof(submodel) <: AnisoSubModel +# # A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis_i, [reflect.(state) for state in states]) +# # values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 +# # values .= (values + cat([((submodel.coefficients_i' * B_batch[i,:]) + submodel.mean_i)' for +# # i in 1: size(B_batch,1)]..., dims=3))/2.0 +# values .= (values + permutedims(cat(infer(submodel.coefficients_i, submodel.mean_i, B_batch)..., dims=3), (2,1,3))) / 2.0 + +# elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) +# # If the dual basis model is being used then it is assumed that the symmetry +# # issue has not been resolved thus an additional symmetrisation operation is +# # required. +# # A = evaluate(submodel.basis, ACEConfig(reflect.(state))) +# # B = _evaluate_real(A) +# B_batch = get_discriptors(submodel.basis, [reflect.(state) for state in states]) +# # values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 +# # values .= (values + cat([((submodel.coefficients' * B_batch[i,:]) + submodel.mean)' for +# # i in 1: size(B_batch,1)]..., dims=3))/2.0 +# values .= (values + permutedims(cat(infer(submodel.coefficients, submodel.mean, B_batch)..., dims=3), (2,1,3))) / 2.0 +# end +# end + +# else +# fill!(values, 0.0) +# end +# end + + + +# function predict_state(submodel::T, state::Vector{S}) where {T<:AHSubModel, S<:AbstractState} +# # If the model has been fitted then use it to predict the results; otherwise just +# # assume the results are zero. +# if is_fitted(submodel) +# # Construct a descriptor representing the supplied state and evaluate the +# # basis on it to predict the associated sub-block. +# A = evaluate(submodel.basis, ACEConfig(state)) +# B = _evaluate_real(A) +# values = (submodel.coefficients' * B) + submodel.mean + +# @static if DUAL_BASIS_MODEL +# if T<: AnisoSubModel +# A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) +# B = _evaluate_real(A) +# values = (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 +# elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) +# # If the dual basis model is being used then it is assumed that the symmetry +# # issue has not been resolved thus an additional symmetrisation operation is +# # required. +# A = evaluate(submodel.basis, ACEConfig(reflect.(state))) +# B = _evaluate_real(A) +# values = (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 +# end +# end + +# else +# values = zeros(size(submodel)) +# end +# end + + + + +# Construct and fill a matrix with the results of a single state + +""" +""" +function predict(submodel::AHSubModel, states::Vector{<:AbstractState}) + # Create a results matrix to hold the predicted values. The shape & type information + # is extracted from the basis. However, complex types will be converted to their real + # equivalents as results in ACEhamiltonians are always real. With the current version + # of ACE this is the the easiest way to reliably identify the shape and float type of + # the sub-blocks; at least that Julia is happy with. + n, m, type = ACE.valtype(submodel.basis).parameters[3:5] + values = Matrix{real(type)}(undef, n, m) + predict!(values, submodel, states) + return values +end + + +""" +Predict the values for a collection of sub-blocks by evaluating the provided basis on the +specified states. This is a the batch operable variant of the primary `predict!` method. + +""" +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# for i=1:length(states) +# @views predict!(values[:, :, i], submodel, states[i]) +# end +# end + + +# using Base.Threads +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# @threads for i=1:length(states) +# @views predict!(values[:, :, i], submodel, states[i]) +# end +# end + + +# using Distributed +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# @sync begin +# for i=1:length(states) +# @async begin +# @spawn @views predict!(values[:, :, i], submodel, states[i]) +# end +# end +# end +# end + + + +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# np = length(procs(values)) +# nstates = length(states) +# nstates_pp = ceil(Int, nstates/np) +# np = ceil(Int, nstates/nstates_pp) +# idx_begins = [nstates_pp*(idx-1)+1 for idx in 1:np] +# idx_ends = [nstates_pp*(idx) for idx in 1:(np-1)] +# push!(idx_ends, nstates) +# @sync begin +# for (i, id) in enumerate(procs(values)[begin:np]) +# @async begin +# @spawnat id begin +# values[:, :, idx_begins[i]: idx_ends[i]] = cat(predict_state.(Ref(submodel), states[idx_begins[i]:idx_ends[i]])..., dims=3) +# end +# end +# end +# end +# end + + +# function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# for i=1:length(states) +# @views predict!(values[:, :, i], submodel, states[i]) +# end +# end + + + +""" +""" +function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + # Construct and fill a matrix with the results from multiple states + n, m, type = ACE.valtype(submodel.basis).parameters[3:5] + values = Array{real(type), 3}(undef, n, m, length(states)) + predict!(values, submodel, states) + return values +end + + +# function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# # Construct and fill a matrix with the results from multiple states +# n, m, type = ACE.valtype(submodel.basis).parameters[3:5] +# # values = Array{real(type), 3}(undef, n, m, length(states)) +# values = SharedArray{real(type), 3}(n, m, length(states)) +# predict!(values, submodel, states) +# return values +# end + + +# Special version of the batch operable `predict!` method that is used when scattering data +# into a Vector of AbstractMatrix types rather than into a three dimensional tensor. This +# is implemented to facilitate the scattering of data into collection of sub-view arrays. +function predict!(values::Vector{<:Any}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + for i=1:length(states) + @views predict!(values[i], submodel, states[i]) + end +end + +# using Base.Threads +# function predict!(values::Vector{<:Any}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) +# @threads for i=1:length(states) +# @views predict!(values[i], submodel, states[i]) +# end +# end + +""" +""" +function predict(model::Model, atoms::Atoms, cell_indices::Union{Nothing, AbstractMatrix}=nothing; kwargs...) + # Pre-build neighbour list to avoid edge case which can degrade performance + _preinitialise_neighbour_list(atoms, model) + + if isnothing(cell_indices) + return _predict(model, atoms; kwargs...) + else + return _predict(model, atoms, cell_indices; kwargs...) + end +end + + +function _predict(model, atoms, cell_indices) + + # Todo:- + # - use symmetry to prevent having to compute data for cells reflected + # cell pairs; i.e. [ 0, 0, 1] & [ 0, 0, -1] + # - Setting the on-sites to an identity should be determined by the model + # rather than just assuming that the user always wants on-site overlap + # blocks to be identity matrices. + + basis_def = model.basis_definition + n_orbs = number_of_orbitals(atoms, basis_def) + + # Matrix into which the final results will be placed + matrix = zeros(n_orbs, n_orbs, size(cell_indices, 2)) + + # Mirror index map array required by `_reflect_block_idxs!` + mirror_idxs = _mirror_idxs(cell_indices) + + # The on-site blocks of overlap matrices are approximated as identity matrix. + if model.label ≡ "S" + matrix[1:n_orbs+1:n_orbs^2] .= 1.0 + end + + for (species₁, species₂) in species_pairs(atoms::Atoms) + + # Matrix containing the block indices of all species₁-species₂ atom-blocks + blockᵢ = repeat_atomic_block_idxs( + atomic_block_idxs(species₁, species₂, atoms), size(cell_indices, 2)) + + # Identify on-site sub-blocks now as they as static over the shell pair loop. + # Note that when `species₁≠species₂` `length(on_blockᵢ)≡0`. + on_blockᵢ = filter_on_site_idxs(blockᵢ) + + for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def) + + + # Get the off-site basis associated with this interaction + basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)] + + # Identify off-site sub-blocks with bond-distances less than the specified cutoff + off_blockᵢ = filter_idxs_by_bond_distance( + filter_off_site_idxs(blockᵢ), + envelope(basis_off).r0cut, atoms, cell_indices) + + # Blocks in the lower triangle are redundant in the homo-orbital interactions + if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ + off_blockᵢ = filter_upper_idxs(off_blockᵢ) + end + + + if size(off_blockᵢ, 2) != 0 + + off_site_states = _get_states( # Build states for the off-site atom-blocks + off_blockᵢ, atoms, envelope(basis_off), cell_indices) + + # Don't try to compute off-site interactions if none exist + if length(off_site_states) > 0 + let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks + set_sub_blocks!( # Assign off-site sub-blocks to the matrix + matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + _reflect_block_idxs!(off_blockᵢ, mirror_idxs) + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( # Assign data to symmetrically equivalent blocks + matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + + end + + + # off_site_states = _get_states( # Build states for the off-site atom-blocks + # off_blockᵢ, atoms, envelope(basis_off), cell_indices) + + # # Don't try to compute off-site interactions if none exist + # if length(off_site_states) > 0 + # let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks + # set_sub_blocks!( # Assign off-site sub-blocks to the matrix + # matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + # _reflect_block_idxs!(off_blockᵢ, mirror_idxs) + # values = permutedims(values, (2, 1, 3)) + # set_sub_blocks!( # Assign data to symmetrically equivalent blocks + # matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + # end + # end + + + # Evaluate on-site terms for homo-atomic interactions; but only if not instructed + # to approximate the on-site sub-blocks as identify matrices. + if species₁ ≡ species₂ && model.label ≠ "S" + # Get the on-site basis and construct the on-site states + basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)] + on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru) + + # Don't try to compute on-site interactions if none exist + if length(on_site_states) > 0 + let values = predict(basis_on, on_site_states) # Predict on-site sub-blocks + set_sub_blocks!( # Assign on-site sub-blocks to the matrix + matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( # Assign data to the symmetrically equivalent blocks + matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + end + + end + end + + return matrix +end + + +# function _predict(model, atoms, cell_indices) + +# # Todo:- +# # - use symmetry to prevent having to compute data for cells reflected +# # cell pairs; i.e. [ 0, 0, 1] & [ 0, 0, -1] +# # - Setting the on-sites to an identity should be determined by the model +# # rather than just assuming that the user always wants on-site overlap +# # blocks to be identity matrices. + +# basis_def = model.basis_definition +# n_orbs = number_of_orbitals(atoms, basis_def) + +# # Matrix into which the final results will be placed +# matrix = zeros(n_orbs, n_orbs, size(cell_indices, 2)) + +# # Mirror index map array required by `_reflect_block_idxs!` +# mirror_idxs = _mirror_idxs(cell_indices) + +# # The on-site blocks of overlap matrices are approximated as identity matrix. +# if model.label ≡ "S" +# matrix[1:n_orbs+1:n_orbs^2] .= 1.0 +# end + +# for (species₁, species₂) in species_pairs(atoms::Atoms) + +# # Matrix containing the block indices of all species₁-species₂ atom-blocks +# blockᵢ = repeat_atomic_block_idxs( +# atomic_block_idxs(species₁, species₂, atoms), size(cell_indices, 2)) + +# # Identify on-site sub-blocks now as they as static over the shell pair loop. +# # Note that when `species₁≠species₂` `length(on_blockᵢ)≡0`. +# on_blockᵢ = filter_on_site_idxs(blockᵢ) + +# Threads.@threads for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def) + + +# # Get the off-site basis associated with this interaction +# basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)] + +# # Identify off-site sub-blocks with bond-distances less than the specified cutoff +# off_blockᵢ = filter_idxs_by_bond_distance( +# filter_off_site_idxs(blockᵢ), +# envelope(basis_off).r0cut, atoms, cell_indices) + +# # Blocks in the lower triangle are redundant in the homo-orbital interactions +# if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ +# off_blockᵢ = filter_upper_idxs(off_blockᵢ) +# end + +# off_site_states = _get_states( # Build states for the off-site atom-blocks +# off_blockᵢ, atoms, envelope(basis_off), cell_indices) + +# # Don't try to compute off-site interactions if none exist +# if length(off_site_states) > 0 +# let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks +# set_sub_blocks!( # Assign off-site sub-blocks to the matrix +# matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + +# _reflect_block_idxs!(off_blockᵢ, mirror_idxs) +# values = permutedims(values, (2, 1, 3)) +# set_sub_blocks!( # Assign data to symmetrically equivalent blocks +# matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) +# end +# end + + +# # Evaluate on-site terms for homo-atomic interactions; but only if not instructed +# # to approximate the on-site sub-blocks as identify matrices. +# if species₁ ≡ species₂ && model.label ≠ "S" +# # Get the on-site basis and construct the on-site states +# basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)] +# on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru) + +# # Don't try to compute on-site interactions if none exist +# if length(on_site_states) > 0 +# let values = predict(basis_on, on_site_states) # Predict on-site sub-blocks +# set_sub_blocks!( # Assign on-site sub-blocks to the matrix +# matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + +# values = permutedims(values, (2, 1, 3)) +# set_sub_blocks!( # Assign data to the symmetrically equivalent blocks +# matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) +# end +# end +# end + +# end +# end + +# return matrix +# end + + +function _predict(model, atoms) + # Currently this method has the tendency to produce non-positive definite overlap + # matrices when working with the aluminum systems, however this is not observed in + # the silicon systems. As such this function should not be used for periodic systems + # until the cause of this issue can be identified. + @warn "This function is not to be trusted" + # TODO: + # - It seems like the `filter_idxs_by_bond_distance` method is not working as intended + # as results change based on whether this is enabled or disabled. + + # See comments in the real space matrix version of `predict` more information. + basis_def = model.basis_definition + n_orbs = number_of_orbitals(atoms, basis_def) + + matrix = zeros(n_orbs, n_orbs) + + # If constructing an overlap matrix then the on-site blocks can just be set to + # an identify matrix. + if model.label ≡ "S" + matrix[1:n_orbs+1:end] .= 1.0 + end + + for (species₁, species₂) in species_pairs(atoms::Atoms) + + blockᵢ = atomic_block_idxs(species₁, species₂, atoms) + + on_blockᵢ = filter_on_site_idxs(blockᵢ) + + for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def) + + basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)] + + off_blockᵢ = filter_idxs_by_bond_distance( + filter_off_site_idxs(blockᵢ), + envelope(basis_off).r0cut, atoms) + + if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ + off_blockᵢ = filter_upper_idxs(off_blockᵢ) + end + + off_site_states = _get_states( + off_blockᵢ, atoms, envelope(basis_off)) + + if length(off_site_states) > 0 + let values = predict(basis_off, off_site_states) + set_sub_blocks!( + matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + _reflect_block_idxs!(off_blockᵢ) + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( + matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + + + if species₁ ≡ species₂ && model.label ≠ "S" + basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)] + on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru) + + + if length(on_site_states) > 0 + let values = predict(basis_on, on_site_states) + set_sub_blocks!( + matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( + matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + end + + end + end + + return matrix +end + + +# ╭───────────────────────────╮ +# │ Internal Helper Functions │ +# ╰───────────────────────────╯ + +""" +Construct the mirror index map required by `_reflect_block_idxs!`. +""" +function _mirror_idxs(cell_indices) + mirror_idxs = Vector{Int}(undef, size(cell_indices, 2)) + let cell_to_index = Dict(cell=>idx for (idx, cell) in enumerate(eachcol(cell_indices))) + for i=1:length(mirror_idxs) + mirror_idxs[i] = cell_to_index[cell_indices[:, i] * -1] + end + end + return mirror_idxs +end + + +""" +This function takes in a `BlkIdx` entity as an argument & swaps the atomic block indices; +i.e. [1, 2] → [2, 1]. +""" +function _reflect_block_idxs!(block_idxs::BlkIdx) + @inbounds for i=1:size(block_idxs, 2) + block_idxs[1, i], block_idxs[2, i] = block_idxs[2, i], block_idxs[1, i] + end + nothing +end + +""" +Inverts a `BlkIdx` instance by swapping the atomic-indices and substitutes the cell index +for its reflected counterpart; i.e. [i, j, k] → [j, i, idx_mirror[k]]. +""" +function _reflect_block_idxs!(block_idxs::BlkIdx, idx_mirror::AbstractVector) + @inbounds for i=1:size(block_idxs, 2) + block_idxs[1, i], block_idxs[2, i] = block_idxs[2, i], block_idxs[1, i] + block_idxs[3, i] = idx_mirror[block_idxs[3, i]] + end + nothing +end + + +function _maximum_distance_estimation(model::Model) + # Maximum radial distance (on-site) + max₁ = maximum(values(model.on_site_parameters.e_cut_out)) + # Maximum radial distance (off-site) + max₂ = maximum(values(model.off_site_parameters.e_cut_out)) + # Maximum effective envelope distance + max₃ = maximum( + [sqrt((env.r0cut + env.zcut)^2 + (env.rcut/2)^2) + for env in envelope.(values(model.off_site_submodels))]) + + return max(max₁, max₂, max₃) + +end + +""" +The construction of neighbour lists can be computationally intensive. As such lists are +used frequently by the lower levels of the code, they are cached & only every recomputed +recomputed if the requested cutoff distance exceeds that used when building the cached +version. It has been found that because each basis can have a different cutoff distance +it is possible that, due to the effects of evaluation order, that the neighbour list can +end up being reconstructed many times. This can be mitigated by anticipating what the +largest cutoff distance is likely to be and pre-building the neighbour list ahead of time. +Hence this function. +""" +function _preinitialise_neighbour_list(atoms::Atoms, model::Model) + # Get a very rough approximation for what the largest cutoff distance might be when + # constructing the neighbour list. + r = _maximum_distance_estimation(model) * 1.1 + + # Construct construct and cache the maximum likely neighbour list + neighbourlist(atoms, r; fixcell=false); + + nothing +end + + +end diff --git a/examples/H2O/python_interface/large/src/properties.jl b/examples/H2O/python_interface/large/src/properties.jl new file mode 100644 index 0000000..e16670a --- /dev/null +++ b/examples/H2O/python_interface/large/src/properties.jl @@ -0,0 +1,216 @@ + + +# Warning this module is not release ready +module Properties + +using ACEhamiltonians, LinearAlgebra + +export real_to_complex!, real_to_complex, band_structure, density_of_states, eigenvalue_confidence_interval + + +const _π2im = -2.0π * im + +function eigvals_at_k(H::A, S::A, T, k_point; kws...) where A<:AbstractArray{<:AbstractFloat, 3} + return real(eigvals(real_to_complex(H, T, k_point), real_to_complex(S, T, k_point); kws...)) +end + + +phase(k::AbstractVector, T::AbstractVector) = exp(_π2im * (k ⋅ T)) +phase(k::AbstractVector, T::AbstractMatrix) = exp.(_π2im * (k' * T)) + +function real_to_complex!(A_complex::AbstractMatrix{C}, A_real::AbstractArray{F, 3}, T::AbstractMatrix, k_point; sym=false) where {C<:Complex, F<:AbstractFloat} + for i=1:size(T, 2) + @views A_complex .+= A_real[:, :, i] .* phase(k_point, T[:, i]) + end + if sym + A_complex .= (A_complex + A_complex') * 0.5 + end + nothing +end + + +""" + real_to_complex(A_real, T, k_point[; sym=false]) + +Compute the complex matrix at a given k-point for a given real-space matrix. + +# Arguments + - `A_real`: real-space matrix of size N×N×T, where N is the number of atomic + orbitals and T the number of cell translation vectors. + - `T`: cell translation vector matrix of size 3×T. + - `k_point`: the k-points for which the complex matrix should be returned. + - `sym`: if true the resulting matrix will be symmetrised prior to its return. + +# Returns + - `A_complex`: the real space matrix evaluated at the requested k-point. + +""" +function real_to_complex(A_real::AbstractArray{F, 3}, T, k_point::Vector; sym=false) where F<:AbstractFloat + A_complex = zeros(Complex{F}, size(A_real, 2), size(A_real, 2)) + real_to_complex!(A_complex, A_real, T, k_point; sym=sym) + return A_complex +end + + +function eigenvalue_confidence_interval(H, H̃, S, S̃, T, k_points, posterior=false) + n = size(H, 1) + C = complex(valtype(H)) + + H_k = Matrix{C}(undef, n, n) + S_k = Matrix{C}(undef, n, n) + H̃_k = Matrix{C}(undef, n, n) + S̃_k = Matrix{C}(undef, n, n) + ΔH = Matrix{C}(undef, n, n) + ΔS = Matrix{C}(undef, n, n) + + results = Matrix{valtype(H)}(undef, n, size(k_points, 2)) + + for (i, k_point) in enumerate(eachcol(k_points)) + fill!(H_k, zero(C)) + fill!(S_k, zero(C)) + fill!(H̃_k, zero(C)) + fill!(S̃_k, zero(C)) + + real_to_complex!(H_k, H, T, k_point) + real_to_complex!(S_k, S, T, k_point) + real_to_complex!(H̃_k, H̃, T, k_point) + real_to_complex!(S̃_k, S̃, T, k_point) + + ΔH[:, :] = H̃_k - H_k + ΔS[:, :] = S̃_k - S_k + + ϵ, φ = eigen!(H_k, S_k); + + ϵₜ = let + if !posterior + ϵ + else + eigen!(H̃_k, S̃_k).values() + end + end + + for (j, (ϵᵢ, φᵢ)) in enumerate(zip(ϵₜ, eachcol(φ))) + + results[j, i] = real.(φᵢ' * ((ΔH - ϵᵢ * ΔS) * φᵢ)) + end + + end + + return results +end + + +function gaussian_broadening(E, ϵ, σ) + return exp(-((E - ϵ) / σ)^2) / (sqrt(π) * σ) +end + +# function gaussian_broadening(E, dE, ϵ, σ) +# # Broadens in an identical manner to FHI-aims; not that this wil require +# # SpecialFunctions.erf to work. While the results returned by this method +# # match with FHI-aims the final DoS is off by a factor of 0.5 for some +# # reason; so double counting is happening somewhere. +# ga = erf((E - ϵ + (dE/2)) / (sqrt(2.0)σ)) +# gb = erf((E - ϵ - (dE/2)) / (sqrt(2.0)σ)) +# return (ga - gb) / 2dE +# end + + +""" +Density of states (k-point independent) +""" +function density_of_states(E::T, ϵ::T, σ) where T<:Vector{<:AbstractFloat} + dos = T(undef, length(E)) + for i=1:length(E) + dos[i] = sum(gaussian_broadening.(E[i], ϵ, σ)) + end + return dos +end + + + +""" +Density of states (k-point dependant) +""" +function density_of_states(E::V, ϵ::Matrix{F}, k_weights::V, σ; fermi=0.0) where V<:Vector{F} where F<:AbstractFloat + # A non-zero fermi value indicates that the energies `E` are relative to the fermi + # level. The most efficient and less user intrusive way to deal with this is create + # and operate on an offset copy of `E`. + if fermi ≠ 0.0 + E = E .+ fermi + end + + dos = zeros(F, length(E)) + let temp_array = zeros(F, size(ϵ)...) + for i=1:length(E) + temp_array .= gaussian_broadening.(E[i], ϵ, σ) + temp_array .*= k_weights' + dos[i] = sum(temp_array) + end + end + + return dos +end + + +function density_of_states(E::Vector, H::M, S::M, σ) where {M<:AbstractMatrix} + return density_of_states(E, eigvals(H, S), σ) +end + +function density_of_states(E::V, H::A, S::A, k_point::V, T, σ) where {V<:Vector{<:AbstractFloat}, A<:AbstractArray{<:AbstractFloat,3}} + return density_of_states(E, eigvals_at_k(H, S, T, k_point), σ) +end + +function density_of_states(E::Vector, H::A, S::A, k_points::AbstractMatrix, T, k_weights, σ; fermi=0.0) where {A<:AbstractArray{F, 3}} where F<:AbstractFloat + return density_of_states(E, band_structure(H, S, T, k_points), k_weights, σ; fermi) +end + + + + +""" + band_structure(H_real, S_real, T, k_points) + +# Arguments + - `H_real`: real space Hamiltonian matrix of size N×N×T, where N is the number of atomic + orbitals and T the number of cell translation vectors. + - `S_real`: real space overlap matrix of size N×N×T. + - `T`: cell translation vector matrix of size 3×T. + - `k_points`: a matrix specifying the k-points at which the eigenvalues should be evaluated. + +# Returns + - `eigenvalues`: eigenvalues evaluated along the specified k-point path. The columns of + this matrix loop over k-points and rows over states. + +""" +function band_structure(H_real::A, S_real::A, T, k_points) where A<:AbstractArray{F, 3} where F<:AbstractFloat + + C = Complex{F} + nₒ, nₖ = size(H_real, 2), size(k_points, 2) + + # Final results array. + ϵ = Matrix{F}(undef, nₒ, nₖ) + + # Construct the transient storage arrays + let H_complex = Matrix{C}(undef, nₒ, nₒ), S_complex = Matrix{C}(undef, nₒ, nₒ) + + # Loop over each k-point + for i=1:nₖ + # Clear the transient storage arrays + fill!(H_complex, zero(C)) + fill!(S_complex, zero(C)) + + # Evaluate the Hamiltonian and overlap matrices at the iᵗʰ k-point + real_to_complex!(H_complex, H_real, T, k_points[:, i]) + real_to_complex!(S_complex, S_real, T, k_points[:, i]) + + # Calculate the eigenvalues + ϵ[:, i] .= real(eigvals(H_complex, S_complex)) + end + + end + + return ϵ +end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/large/src/states.jl b/examples/H2O/python_interface/large/src/states.jl new file mode 100644 index 0000000..e42ccd0 --- /dev/null +++ b/examples/H2O/python_interface/large/src/states.jl @@ -0,0 +1,498 @@ +module States +using ACEhamiltonians, NeighbourLists, JuLIP +using ACEhamiltonians.MatrixManipulation: BlkIdx +using StaticArrays: SVector +using LinearAlgebra: norm, normalize +using ACE: AbstractState, CylindricalBondEnvelope, BondEnvelope, _evaluate_bond, _evaluate_env +using ACEhamiltonians: BOND_ORIGIN_AT_MIDPOINT + +import ACEhamiltonians.Parameters: ison +import ACE: _inner_evaluate + +export BondState, AtomState, reflect, get_state + +# ╔════════╗ +# ║ States ║ +# ╚════════╝ +""" + BondState(mu, mu_i, mu_j, rr, rr0, bond) + +State entities used when representing the environment about a bond. + +# Fields +- `mu`: AtomicNumber of the current atom +- `mu_i`: AtomicNumber of the first bonding atom +- `mu_j`: AtomicNumber of the second bonding atom +- `rr`: environmental atom's position relative to the midpoint of the bond. +- `rr0`: vector between the two "bonding" atoms, i.e. the bond vector. +- `bond`: a boolean, which if true indicates the associated `BondState` entity represents + the bond itself. If false, then the state is taken to represent an environmental atom + about the bond rather than the bond itself. + +# Notes +If `bond == true` then `rr` should be set to `rr0/2`. If an environmental atom lies too +close to bond's midpoint then ACE may crash. Thus a small offset may be required in some +cases. + +# Developers Notes +An additional field will be added at a later data to facilitate multi-species support. It +is possible that the `BondState` structure will have to be split into two sub-structures. + +# Todo + - Documentation should be updated to account for the fact that the bond origin has been + moved back to the first atoms position. + - Call for better names for the first three fields +""" +const AN_default = AtomicNumber(:X) + +struct BondState{T<:SVector{3, <:AbstractFloat}, B<:Bool} <: AbstractState + mu::AtomicNumber + mu_i::AtomicNumber + mu_j::AtomicNumber + rr::T + rr0::T + bond::B +end + +BondState{T, Bool}(rr,rr0,bond::Bool) where T<:SVector{3, <:AbstractFloat} = BondState(AN_default,AN_default,AN_default,T(rr),T(rr0),bond) +BondState(rr,rr0,bond::Bool) = BondState{SVector{3, Float64}, Bool}(rr,rr0,bond) +""" + AtomState(mu,mu_i,rr) + +State entity representing the environment about an atom. + +# Fields +- `mu`: AtomicNumber of the current atom +- `mu_i`: AtomicNumber of the centre atom +- `rr`: environmental atom's position relative to the host atom. + +""" +struct AtomState{T<:SVector{3, <:AbstractFloat}} <: AbstractState + mu::AtomicNumber + mu_i::AtomicNumber + rr::T +end + +AtomState{T}(rr) where T<:SVector{3, <:AbstractFloat} = AtomState(AN_default,AN_default,T(rr)) +AtomState(rr) = AtomState{SVector{3, Float64}}(rr) + +# ╭────────┬───────────────────────╮ +# │ States │ General Functionality │ +# ╰────────┴───────────────────────╯ + +# Display methods to help alleviate endless terminal spam. +function Base.show(io::IO, state::BondState) + mu = state.mu + mu_i = state.mu_i + mu_j = state.mu_j + rr = string([round.(state.rr, digits=5)...]) + rr0 = string([round.(state.rr0, digits=5)...]) + print(io, "BondState(mu:$mu, mu_i:$mu_i, mu_j:$mu_j, rr:$rr, rr0:$rr0, bond:$(state.bond))") +end + +function Base.show(io::IO, state::AtomState) + mu = state.mu + mu_i = state.mu_i + rr = string([round.(state.rr, digits=5)...]) + print(io, "AtomState(mu:$mu, mu_i:$mu_i, rr:$rr)") +end + +# Allow for equality checks (will otherwise default to equivalency) +Base.:(==)(x::T, y::T) where T<:BondState = x.mu == y.mu && x.mu_i == y.mu_i && x.mu_j == y.mu_j && x.rr == y.rr && y.rr0 == y.rr0 && x.bond == y.bond +Base.:(==)(x::T, y::T) where T<:AtomState = x.mu == y.mu && x.mu_i == y.mu_i && x.rr == y.rr + +# The ≈ operator is commonly of more use for State entities than the equality +Base.isapprox(x::T, y::T; kwargs...) where T<:BondState = x.mu == y.mu && x.mu_i == y.mu_i && x.mu_j == y.mu_j && isapprox(x.rr, y.rr; kwargs...) && isapprox(x.rr0, y.rr0; kwargs...) && x.bond == y.bond +Base.isapprox(x::T, y::T; kwargs...) where T<:AtomState = x.mu == y.mu && x.mu_i == y.mu_i && isapprox(x.rr, y.rr; kwargs...) +Base.isapprox(x::T, y::T; kwargs...) where T<:AbstractVector{<:BondState} = all(x .≈ y) +Base.isapprox(x::T, y::T; kwargs...) where T<:AbstractVector{<:AtomState} = all(x .≈ y) + +# ACE requires the `zero` method to be defined for states. +Base.zero(::Type{BondState{T, S}}) where {T, S} = BondState{T, S}(zero(T), zero(T), true) +Base.zero(::Type{AtomState{T}}) where T = AtomState{T}(zero(T)) +Base.zero(::B) where B<:BondState = zero(B) +Base.zero(::B) where B<:AtomState = zero(B) +# Todo: +# - Identify why `zero(BondState)` is being called and if the location where it is used +# is effected by always choosing `bond` to be false. [PRIORITY:LOW] + +""" + ison(state) + +Return a boolean indicating whether the state entity is associated with either an on-site +or off-site interaction. +""" +ison(::T) where T<:AtomState = true +ison(::T) where T<:BondState = false + + +""" + reflect(state) + +Reflect `BondState` across the bond's midpoint. Calling on a state representing the bond A→B +will return the symmetrically B→A state. For states where `bond=true` this will flip the +sign on `rr` & `rr0`; whereas only `rr0` is flipped for the `bond=false` case. + +# Arguments +- `state::BondState`: the state to be reflected. + +# Returns +- `reflected_state::BondState`: a view of `state` reflected across the midpoint. + +# Warnings +This is only valid for bond states whose atomic positions are given relative to the midpoint +of the bond; i.e. `envelope.λ≡0`. +""" +function reflect(state::T) where T<:BondState + @static if BOND_ORIGIN_AT_MIDPOINT + if state.bond + return T(state.mu, state.mu_j, state.mu_i, -state.rr, -state.rr0, true) + else + return T(state.mu, state.mu_j, state.mu_i, state.rr, -state.rr0, false) + end + else + if state.bond + return T(state.mu, state.mu_j, state.mu_i, -state.rr, -state.rr0, true) + else + return T(state.mu, state.mu_j, state.mu_i, state.rr - state.rr0, -state.rr0, false) + end + end +end + +# `reflect` is just an identify function for `AtomState` instances. This is included to +# alleviate the need for branching elsewhere in the code. +reflect(state::AtomState) = state + +# ╭───────┬───────────╮ +# │ State │ Factories │ +# ╰───────┴───────────╯ +""" + get_state(i, atoms[; r=16.0]) + +Construct a state representing the environment about atom `i`. + +# Arguments +- `i::Integer`: index of the atom whose state is to be constructed. +- `atoms::Atoms`: the `Atoms` object in which atom `i` resides. +- `r::AbstractFloat`: maximum distance up to which neighbouring atoms + should be considered. + +# Returns +- `state::Vector{::AtomState}`: state objects representing the environment about atom `i`. + +""" +function get_state(i::Integer, atoms::Atoms; r::AbstractFloat=16.0) + # Construct the neighbour list (this is cached so speed is not an issue) + pair_list = JuLIP.neighbourlist(atoms, r; fixcell=false) + + # Extract environment about each relevant atom from the pair list. These will be tuples + # of the form: (atomic-index, relative-position) + idxs, vecs, species = JuLIP.Potentials.neigsz(pair_list, atoms, i) + + # Construct the `AtomState`` vector + st = [ AtomState(species[j], atoms.Z[i], vecs[j]) for j = 1:length(species) ] + + # Return an AtomState vector without those outside the cutoff sphere. + return filter(k -> norm(k.rr) <= r, st) +end + + +""" + get_state(i, j, atoms, envelope[, image]) + +Construct a state representing the environment about the "bond" between atoms `i` & `j`. + +# Arguments +- `i::Int`: atomic index of the first bonding atom. +- `j::Int`: atomic index of the second bonding atom. +- `atoms::Atoms`: the `Atoms` object in which atoms `i` and `j` reside. +- `envelope::CylindricalBondEnvelope:` an envelope specifying the volume to consider when + constructing the state. This must be centred at the bond's midpoint; i.e. `envelope.λ≡0`. +- `image::Optional{Vector}`: a vector specifying the image in which atom `j` + should reside; i.e. the cell translation vector. This defaults to `nothing` which will + result in the closets periodic image of `j` being used. +- `r::AbstractFloat`: this can be used to manually specify the cutoff distance used when + building the neighbour list. This will override the locally computed value for `r` and + is primarily used to aid debugging. + +# Returns +- `state::Vector{::BondState}`: state objects representing the environment about the bond + between atoms `i` and `j`. + +# Notes +It is worth noting that a state will be constructed for the ij bond even when the distance +between them exceeds the bond-cutoff specified by the `envelope`. The maximum cutoff +distance for neighbour list construction is handled automatically. + +# Warnings +- The neighbour list for the bond is constructed by applying an offset to the first atom's + neighbour list. As such spurious states will be encountered when the ij distance exceeds + the bond cutoff value `envelope.r0cut`. Do not ignore this warning! +- It is vital to ensure that when an `image` is supplied that all atomic coordinates are + correctly wrapped into the unit cell. If fractional coordinates lie outside of the range + [0, 1] then the results of this function will not be correct. + +""" +function get_state( + i::I, j::I, atoms::Atoms, envelope::CylindricalBondEnvelope, + image::Union{AbstractVector{I}, Nothing}=nothing; r::Union{Nothing, <:AbstractFloat}=nothing) where {I<:Integer} + + # Todo: + # - Combine the neighbour lists of atom i and j rather than just the former. This + # will reduce the probably of spurious state construction. But will increase run + # time as culling of duplicate states and bond states will be required. + # - rr for the bond really should be halved and inverted to match up with the + # environmental coordinate system. + + # Neighbour list cutoff distance; accounting for distances being relative to atom `i` + # rather than the bond's mid-point + if isnothing(r) + r = sqrt((envelope.r0cut + envelope.zcut)^2 + envelope.rcut^2) + end + + # Neighbours list construction (about atom `i`) + idxs, vecs, cells, species = _neighbours(i, atoms, r) + + # Get the bond vector between atoms i & j; where i is in the origin cell & j resides + # in either i) closest periodic image, or ii) that specified by `image` if provided. + if isnothing(image) + # Identify the shortest i→j vector account for PBC. + idx = _locate_minimum_image(j, idxs, vecs) + rr0 = vecs[idx] + else + @assert length(image) == 3 "image must be a vector of length three" + # Find the vector between atom i in the origin cell and atom j in cell `image`. + idx = _locate_target_image(j, idxs, cells, image) + if idx != 0 + rr0 = vecs[idx] + else # Special case where the cutoff was too short to catch the desired i→j bond. + # In this case we must calculate rr0 manually. + rr0 = atoms.X[j] - atoms.X[i] + (adjoint(image .* atoms.pbc) * atoms.cell).parent + end + end + + # The i→j bond vector must be removed from `vecs` so that it does not get treated as + # an environmental atom in the for loop later on. This operation is done even if the + # `idx==0` to maintain type stability. + @views vecs_no_bond = vecs[1:end .!= idx] + @views species_no_bond = species[1:end .!= idx] + + # `BondState` entity vector + states = Vector{BondState{typeof(rr0), Bool}}(undef, length(vecs_no_bond) + 1) + + # Construct the bond vector state; i.e where `bond=true` + states[1] = BondState(atoms.Z[j],atoms.Z[i],atoms.Z[j],rr0, rr0, true) + + @static if BOND_ORIGIN_AT_MIDPOINT + # As the mid-point of the bond is used as the origin an offset is needed to shift + # vectors so they're relative to the bond's midpoint and not atom `i`. + offset = rr0 * 0.5 + end + + # Construct the environmental atom states; i.e. where `bond=false`. + for (k, v⃗) in enumerate(vecs_no_bond) + @static if BOND_ORIGIN_AT_MIDPOINT + # Offset the positions so that they are relative to the bond's midpoint. + states[k+1] = BondState{typeof(rr0), Bool}(species_no_bond[k], atoms.Z[i], atoms.Z[j], v⃗ - offset, rr0, false) + else + states[k+1] = BondState{typeof(rr0), Bool}(species_no_bond[k], atoms.Z[i], atoms.Z[j], v⃗, rr0, false) + end + + end + + # Cull states outside of the bond envelope using the envelope's filter operator. This + # task is performed manually here in an effort to reduce run time and memory usage. + @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 + @views n = sum(mask) + 1 + @views states[2:n] = states[2:end][mask] + + return states[1:n] + +end + +# Commonly one will need to collect multiple states rather than single states on their +# own. Hence the `get_state[s]` functions. These functions have been tagged for internal +# use only until they can be polished up. + +""" + _get_states(block_idxs, atoms, envelope[, images]) + +Get the states describing the environments about a collection of bonds as defined by the +block index list `block_idxs`. This is effectively just a fancy wrapper for `get_state'. + +# Arguments +- `block_idxs`: atomic index matrix in which the first & second rows specify the indices of the + two "bonding" atoms. The third row, if present, is used to index `images` to collect + cell in which the second atom lies. +- `atoms`: the `Atoms` object in which that atom pair resides. +- `envelope`: an envelope specifying the volume to consider when constructing the states. +- `images`: Cell translation index lookup list, this is only relevant when `block_idxs` + supplies and cell index value. The cell translation index for the iᵗʰ state will be + taken to be `images[block_indxs[i, 3]]`. + +# Returns +- `bond_states::Vector{::Vector{::BondState}}`: a vector providing the requested bond states. + +# Developers Notes +This is currently set to private until it is cleaned up. + +""" +function _get_states(block_idxs::BlkIdx, atoms::Atoms{T}, envelope::CylindricalBondEnvelope, + images::Union{AbstractMatrix{I}, Nothing}=nothing) where {I, T} + if isnothing(images) + if size(block_idxs, 1) == 3 && any block_idxs[3, :] != 1 + throw(ArgumentError("`idxs` provides non-origin cell indices but no + `images` argument was given!")) + end + return get_state.(block_idxs[1, :], block_idxs[2, :], Ref(atoms), Ref(envelope))::Vector{Vector{BondState{SVector{3, T}, Bool}}} + else + # If size(block_idxs,1) == 2, i.e. no cell index is supplied then this will error out. + # Thus not manual error handling is required. If images are supplied then block_idxs + # must contain the image index. + return get_state.( + block_idxs[1, :], block_idxs[2, :], Ref(atoms), + Ref(envelope), eachcol(images[:, block_idxs[3, :]]))::Vector{Vector{BondState{SVector{3, T}, Bool}}} + end +end + + +""" + _get_states(block_idxs, atoms[; r=16.0]) + +Get states describing the environments around each atom block specified in `block_idxs`. +Note that `block_idxs` is assumed to contain only on-site blocks. This is just a wrapper +for `get_state'. + +# Developers Notes +This is currently set to private until it is cleaned up. + +""" +function _get_states(block_idxs::BlkIdx, atoms::Atoms{T}; r=16.0) where T + if @views block_idxs[1, :] != block_idxs[2, :] + throw(ArgumentError( + "The supplied `block_idxs` represent a hetroatomic interaction. But the function + called is for retrieving homoatomic states.")) + end + # Type ambiguities in the JuLIP.Atoms structure means that Julia cannot determine the + # function's return type; specifically the value type of the static vector. Thus some + # pseudo type hard coding must be done here. + return get_state.(block_idxs[1, :], (atoms,); r=r)::Vector{Vector{AtomState{SVector{3, T}}}} +end + + + +# ╭────────┬──────────────────────────╮ +# │ States │ Factory Helper Functions │ +# ╰────────┴──────────────────────────╯ + + +""" + _neighbours(i, atoms, r) + +Identify and return information about atoms neighbouring atom `i` in system `atoms`. + +# Arguments +- `i::Int`: index of the atom for which the neighbour list is to be constructed. +- `atoms::Atoms`: system in which atom `i` is located. +- `r::AbstractFloat`: cutoff distance to for the neighbour list. Due to the effects off + cacheing this should be treated as if it were a lower bounds for the cutoff rather than + the cutoff itself. + +# Returns +- `idxs`: atomic index of each neighbour. +- `vecs`: distance vector to each neighbour. +- `cells`: index specifying the cell in which the neighbouring atom resides. + +# Warnings +Due to the effects of caching there is a high probably that the returned neighbour list +will contain neighbours at distances greater than `r`. + +""" +function _neighbours(i::Integer, atoms::Atoms, r::AbstractFloat) + pair_list = JuLIP.neighbourlist(atoms, r; fixcell=false) + return [ NeighbourLists.neigss(pair_list, i)..., JuLIP.Potentials.neigsz(pair_list,atoms,i)[3] ] +end + + +""" + _locate_minimum_image(j, idxs, vecs) + +Index of the closest `j` neighbour accounting for periodic boundary conditions. + +# Arguments +- `j::Integer`: Index of the atom for for whom the minimum image is to be identified. +- `idxs::Vector{::Integer}`: Integers specifying the indices of the atoms to two which + the distances in `vecs` correspond. +- `vecs::Vector{SVector{3, ::AbstractFloat}}`: Vectors between the the source atom and + the target atom. + +# Returns +- `index::Integer`: an index `k` for which `vecs[k]` will yield the vector between the + source atom and the closest periodic image of atom `j`. + +# Notes +If multiple minimal vectors are found, then the first one will be returned. + +# Todo +- This will error out when the cutoff distance is lower than the bond distance. While such + an occurrence is unlikely in smaller cells it will no doubt occur in larger ones. + +""" +function _locate_minimum_image(j::Integer, idxs::AbstractVector{<:Integer}, vecs::AbstractVector{<:AbstractVector{<:AbstractFloat}}) + # Locate all entries in the neighbour list that correspond to atom `j` + js = findall(==(j), idxs) + if length(js) == 0 + # See the "Todo" section in the docstring. + error("Neighbour not in range") + end + + # Identify which image of atom `j` is closest + return js[findmin(norm, vecs[js])[2]] +end + +""" + _locate_target_image(j, idxs, images, image) + +Search through the neighbour list for atoms with the atomic index `j` that reside in +the specified `image` and return its index. If no such match is found, as can happen +when the cutoff distance is too short, then an index of zero is returned. + +# Arguments +- `j`: index of the desired neighbour. +- `idxs`: atomic indices of atoms in the neighbour list. +- `images`: image in which the neighbour list atoms reside. +- `image`: image in which the target neighbour should reside. + +# Returns +- `idx::Int`: index of the first entry in the neighbour list representing an atom with the + atom index `j` residing in the image `image`. Zero if no matches are found. + +# Notes +The `images` argument is set vector of vectors here as this is represents the type returned +by the neighbour list constructor. Blocking other types prevents any misunderstandings. + +# Todo: +- Test for type instabilities + +""" +function _locate_target_image(j::I, idxs::AbstractVector{I}, images::AbstractVector{<:AbstractVector{I}}, image::AbstractVector{I})::I where I<:Integer + js = findall(==(j), idxs) + idx = findfirst(i -> all(i .== image), images[js]) + return isnothing(idx) ? zero(I) : js[idx] +end + + +# ╭────────┬───────────╮ +# │ States │ Overrides │ +# ╰────────┴───────────╯ +# Local override to account for the `BondState` field `be::Symbol` being replaced with the +# field `bond::Bool`. +function _inner_evaluate(env::BondEnvelope, state::BondState) + if state.bond + return _evaluate_bond(env, state) + else + return _evaluate_env(env, state) + end +end + +end diff --git a/examples/H2O/python_interface/small/script/function.jl b/examples/H2O/python_interface/small/script/function.jl new file mode 100644 index 0000000..57d05fc --- /dev/null +++ b/examples/H2O/python_interface/small/script/function.jl @@ -0,0 +1,31 @@ +using Distributed +addprocs(28) +@everywhere begin + using JuLIP: Atoms + using LinearAlgebra + using Statistics + using PyCall + using ACEhamiltonians + import ACEhamiltonians: predict, Model +end + +@everywhere function predictpp(atoms::Vector{PyObject}, model::Model) + atoms = [Atoms(Z=atom.get_atomic_numbers(), X=transpose(atom.positions), cell=collect(Float64.(I(3) * 100)), pbc=true) for atom in atoms] + n = nworkers() + len = length(atoms) + chunk_size = ceil(Int, len / n) + chunks_atoms = [atoms[(i-1)*chunk_size+1:min(i*chunk_size, len)] for i in 1:n] + + images = cell_translations.(atoms, Ref(model)) + chunks_images = [images[(i-1)*chunk_size+1:min(i*chunk_size, len)] for i in 1:n] + + predicted = Any[] + for (chunk_atoms, chunk_images) in zip(chunks_atoms, chunks_images) + task = @spawn predict.(Ref(model), chunk_atoms, chunk_images) + push!(predicted, task) + end + predicted = fetch.(predicted) + + predicted = vcat(predicted...) + return predicted +end diff --git a/examples/H2O/python_interface/small/script/py4ACE.py b/examples/H2O/python_interface/small/script/py4ACE.py new file mode 100644 index 0000000..d68ae17 --- /dev/null +++ b/examples/H2O/python_interface/small/script/py4ACE.py @@ -0,0 +1,50 @@ +############################configure the julia project and load the interface###################### +import os +import julia +import time + +# Specify the path to your Julia project or environment +os.environ["JULIA_PROJECT"] = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA" +julia.install() +from julia.api import Julia +jl = Julia(compiled_modules=False) + +from julia import Main, Serialization, Base +Main.include("/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/data_interface/debug_mono/function.jl") + + +##########################################load the model############################################# +model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_1_rcut_6/H2O_H_aims.bin" +model = Serialization.deserialize(model_path) + + + +################################define a python format ase.atoms object############################### +from ase import Atoms +import numpy as np +O_H_distance = 0.96 +angle = 104.5 + +angle_rad = np.radians(angle / 2) + +x = O_H_distance * np.sin(angle_rad) +y = O_H_distance * np.cos(angle_rad) +z = 0.0 + +positions = [ + (0, 0, 0), # Oxygen + (x, y, 0), # Hydrogen 1 + (x, -y, 0) # Hydrogen 2 +] +water = Atoms('OH2', positions=positions) + + + +####################################################################################################### +ts = time.time() +predicted = Main.predictpp([water]*28, model) +ti = time.time() +print(f"time to predict is {ti-ts} second.") + +predicted = [h for h in predicted] +np.save("/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/data_interface/debug_mono/predicted.npy", predicted) diff --git a/examples/H2O/python_interface/small/src/ACEhamiltonians.jl b/examples/H2O/python_interface/small/src/ACEhamiltonians.jl new file mode 100644 index 0000000..c4c3ec0 --- /dev/null +++ b/examples/H2O/python_interface/small/src/ACEhamiltonians.jl @@ -0,0 +1,95 @@ +module ACEhamiltonians + +using JuLIP, JSON, HDF5, Reexport, LinearAlgebra + +using ACE.SphericalHarmonics: SphericalCoords +import ACE.SphericalHarmonics: cart2spher + +export BasisDef, DUAL_BASIS_MODEL, BOND_ORIGIN_AT_MIDPOINT, SYMMETRY_FIX_ENABLED + +# Enabling this will activate the dual basis mode +const DUAL_BASIS_MODEL = true + +# If `true` then bond origin will be set to the bond's mid-point +const BOND_ORIGIN_AT_MIDPOINT = false + +# This can be used to enable/disable the symmetry fix code +const SYMMETRY_FIX_ENABLED = false + +if SYMMETRY_FIX_ENABLED && !BOND_ORIGIN_AT_MIDPOINT + @warn "Symmetry fix code is only viable when the bond origin is set to the midpoint " +end + + +if SYMMETRY_FIX_ENABLED && DUAL_BASIS_MODEL + @warn "It is ill advised to enable the symmetry fix when in dual basis mode" +end + +""" + BasisDef(atomic_number => [ℓ₁, ..., ℓᵢ], ...) + +Provides information about the basis set by specifying the azimuthal quantum numbers (ℓ) +of each shell on each species. Dictionary is keyed by atomic numbers & valued by vectors +of ℓs i.e. `Dict{atomic_number, [ℓ₁, ..., ℓᵢ]}`. + +A minimal basis set for hydrocarbon systems would be `BasisDef(1=>[0], 6=>[0, 0, 1])`. +This declares hydrogen atoms as having only a single s-shell and carbon atoms as having +two s-shells and one p-shell. +""" +BasisDef = Dict{I, Vector{I}} where I<:Integer + +@static if BOND_ORIGIN_AT_MIDPOINT + function cart2spher(r⃗::AbstractVector{T}) where T + @assert length(r⃗) == 3 + # When the length of the vector `r⃗` is zero then the signs can have a + # destabalising effect on the results; i.e. + # cart2spher([0., 0., 0.]) ≠ cart2spher([0. 0., -0.]) + # Hense the following catch: + if norm(r⃗) ≠ 0.0 + φ = atan(r⃗[2], r⃗[1]) + θ = atan(hypot(r⃗[1], r⃗[2]), r⃗[3]) + sinφ, cosφ = sincos(φ) + sinθ, cosθ = sincos(θ) + return SphericalCoords{T}(norm(r⃗), cosφ, sinφ, cosθ, sinθ) + else + return SphericalCoords{T}(0.0, 1.0, 0.0, 1.0, 0.0) + end + end +end + +include("common.jl") +@reexport using ACEhamiltonians.Common + +include("io.jl") +@reexport using ACEhamiltonians.DatabaseIO + +include("parameters.jl") +@reexport using ACEhamiltonians.Parameters + +include("data.jl") +@reexport using ACEhamiltonians.MatrixManipulation + +include("states.jl") +@reexport using ACEhamiltonians.States + +include("basis.jl") +@reexport using ACEhamiltonians.Bases + +include("models.jl") +@reexport using ACEhamiltonians.Models + +include("datastructs.jl") +@reexport using ACEhamiltonians.DataSets + +include("fitting.jl") +@reexport using ACEhamiltonians.Fitting + +include("predicting.jl") +@reexport using ACEhamiltonians.Predicting + +include("properties.jl") +@reexport using ACEhamiltonians.Properties + +include("api/dftbp_api.jl") + +end diff --git a/examples/H2O/python_interface/small/src/api/dftbp_api.jl b/examples/H2O/python_interface/small/src/api/dftbp_api.jl new file mode 100644 index 0000000..efbbb93 --- /dev/null +++ b/examples/H2O/python_interface/small/src/api/dftbp_api.jl @@ -0,0 +1,357 @@ +module DftbpApi +using ACEhamiltonians +using BlockArrays, StaticArrays, Serialization +using LinearAlgebra: norm, diagind +using ACEbase: read_dict, load_json + +using ACEhamiltonians.States: _inner_evaluate + + +export load_model, n_orbs_per_atom, offers_species, offers_species, species_name_to_id, max_interaction_cutoff, + max_environment_cutoff, shells_on_species!, n_shells_on_species, shell_occupancies!, build_on_site_atom_block!, + build_off_site_atom_block! + + + +# WARNING; THIS CODE IS NOT STABLE UNTIL THE SYMMETRY ISSUE HAS BEEN RESOLVED, DO NOT USE. + + +# Todo: +# - resolve unit mess. + +_Bohr2Angstrom = 1/0.188972598857892E+01 +_F64SV = SVector{3, Float64} +_FORCE_SHOW_ERROR = true + + +macro FSE(func) + # Force show error + if _FORCE_SHOW_ERROR + func.args[2] = quote + try + $(func.args[2].args...) + catch e + println("\nError encountered in Julia-DFTB+ API") + for (exc, bt) in current_exceptions() + showerror(stdout, exc, bt) + println(stdout) + println("Terminating....") + # Ensure streams are flushed prior to the `exit` call below + flush(stdout) + flush(stderr) + end + # The Julia thread must be explicitly terminated, otherwise the DFTB+ + # calculation will continue. + exit() + end + end + end + return func +end + + +_sub_block_sizes(species, basis_def) = 2basis_def[species] .+ 1 + +function _reshape_to_block(array, species_1, species_2, basis_definition) + return PseudoBlockArray( + reshape( + array, + number_of_orbitals(species_1, basis_definition), + number_of_orbitals(species_2, basis_definition) + ), + _sub_block_sizes(species_1, basis_definition), + _sub_block_sizes(species_2, basis_definition) + ) +end + +# Setup related function + +_s2n = Dict( + "H"=>1, "He"=>2, "Li"=>3, "Be"=>4, "B"=>5, "C"=>6, "N"=>7, "O"=>8, "F"=>9,"Ne"=>10, + "Na"=>11, "Mg"=>12, "Al"=>13, "Si"=>14, "P"=>15, "S"=>16, "Cl"=>17, "Ar"=>18, + "K"=>19, "Ca"=>20, "Sc"=>21, "Ti"=>22, "V"=>23, "Cr"=>24, "Mn"=>25, "Fe"=>26, + "Co"=>27, "Ni"=>28, "Cu"=>29, "Zn"=>30, "Ga"=>31, "Ge"=>32, "As"=>33, "Se"=>34, + "Br"=>35, "Kr"=>36, "Rb"=>37, "Sr"=>38, "Y"=>39, "Zr"=>40, "Nb"=>41, "Mo"=>42, + "Tc"=>43, "Ru"=>44, "Rh"=>45, "Pd"=>46, "Ag"=>47, "Cd"=>48, "In"=>49, "Sn"=>50, + "Sb"=>51, "Te"=>52, "I"=>53, "Xe"=>54, "Cs"=>55, "Ba"=>56, "La"=>57, "Ce"=>58, + "Pr"=>59, "Nd"=>60, "Pm"=>61, "Sm"=>62, "Eu"=>63, "Gd"=>64, "Tb"=>65, "Dy"=>66, + "Ho"=>67, "Er"=>68, "Tm"=>69, "Yb"=>70, "Lu"=>71, "Hf"=>72, "Ta"=>73, "W"=>74, + "Re"=>75, "Os"=>76, "Ir"=>77, "Pt"=>78, "Au"=>79, "Hg"=>80, "Tl"=>81, "Pb"=>82, + "Bi"=>83, "Po"=>84, "At"=>85, "Rn"=>86, "Fr"=>87, "Ra"=>88, "Ac"=>89, "Th"=>90, + "Pa"=>91, "U"=>92, "Np"=>93, "Pu"=>94, "Am"=>95, "Cm"=>96, "Bk"=>97, "Cf"=>98, + "Es"=>99, "Fm"=>10, "Md"=>10, "No"=>10, "Lr"=>10, "Rf"=>10, "Db"=>10, "Sg"=>10, + "Bh"=>10, "Hs"=>10, "Mt"=>10, "Ds"=>11, "Rg"=>111) + + +"""Return the unique id number associated with species of a given name.""" +@FSE function species_name_to_id(name, model) + # The current implementation of this function is only a temporary measure. + # Once atom names have been replaced with enums and the atomic number in "Basis.id" + # is replaced with a species id then this function will be able interrogate the + # model and its bases for information on what species it supports. Furthermore + # this will allow for multiple species of the same atomic number to be used. + return Int32(_s2n[name]) + +end + +"""Returns true of the supplied model supports the provided species""" +@FSE function offers_species(name, model) + if haskey(_s2n, name) + return _s2n[name] ∈ Set([id[1] for id in keys(model.on_site_bases)]) + else + return false + end +end + +@FSE function n_orbs_per_atom(species_id, model) + return Int32(sum(2model.basis_definition[species_id].+1)) +end + +"""Maximum environmental cutoff distance""" +@FSE function max_environment_cutoff(model) + max_on = maximum(values(model.on_site_parameters.e_cut_out)) + max_off = maximum(values(model.off_site_parameters.e_cut_out)) + distance = max(max_on, max_off) + return distance / _Bohr2Angstrom +end + + +"""Maximum interaction cutoff distance""" +@FSE function max_interaction_cutoff(model) + distance = maximum([env.r0cut for env in envelope.(values(model.off_site_bases))]) + return distance / _Bohr2Angstrom +end + + +@FSE function n_shells_on_species(species, model) + return Int32(length(model.basis_definition[species])) +end + +@FSE function shells_on_species!(array, species, model) + shells = model.basis_definition[species] + if length(array) ≠ length(shells) + println("shells_on_species!: Provided array is of incorrect length.") + throw(BoundsError("shells_on_species!: Provided array is of incorrect length.")) + end + array[:] = shells + nothing +end + +@FSE function shell_occupancies!(array, species, model) + if !haskey(model.meta_data, "occupancy") + throw(KeyError( + "shell_occupancies!: an \"occupancy\" key must be present in the model's + `meta_data` which provides the occupancies for each shell of each species." + )) + end + + occupancies = model.meta_data["occupancy"][species] + + if length(array) ≠ length(occupancies) + throw(BoundsError("shell_occupancies!: Provided array is of incorrect length.")) + end + + array[:] = occupancies + +end + + +@FSE function load_model(path::String) + if endswith(lowercase(path), ".json") + return read_dict(load_json(path)) + elseif endswith(lowercase(path), ".bin") + return deserialize(path) + else + error("Unknown file extension used; only \"json\" & \"bin\" are supported") + end +end + + +@FSE function _build_atom_state(coordinates, cutoff) + + # Todo: + # - The distance filter can likely be removed as atoms beyond the cutoff will be + # ignored by ACE. Tests will need to be performed to identify which is more + # performant; culling here or letting ACE handle the culling. + + # Build a list of static coordinate vectors, excluding the origin. + positions = map(_F64SV, eachcol(coordinates[:, 2:end])) + + # Exclude positions that lie outside of the cutoff allowed by the model. + positions_in_range = positions[norm.(positions) .≤ cutoff] + + # Construct the associated state object vector + return map(AtomState, positions_in_range) +end + +# Method for when bond origin is the midpoint +# function _build_bond_state(coordinates, envelope) +# # Build a list of static coordinate vectors, excluding the two bonding +# # atoms. +# positions = map(_F64SV, eachcol(coordinates[:, 3:end])) + +# # Coordinates must be rounded to prevent stability issues associated with +# # noise. This mostly only effects situations where atoms lie near the mid- +# # point of a bond. +# positions = [round.(i, digits=8) for i in positions] + + +# # The rest of this function copies code directly from `states.get_state`. +# # Here, rr0 is multiplied by two as vectors provided by DFTB+ point to +# # the midpoint of the bond. ACEhamiltonians expects the bond vector and +# # to be inverted, hence the second position is taken. +# rr0 = _F64SV(round.(coordinates[:, 2], digits=8) * 2) +# states = Vector{BondState{_F64SV, Bool}}(undef, length(positions) + 1) +# states[1] = BondState(_F64SV(round.(coordinates[:, 2], digits=8)), rr0, true) + +# for k=1:length(positions) +# states[k+1] = BondState{_F64SV, Bool}(positions[k], rr0, false) +# end + +# @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 +# @views n = sum(mask) + 1 +# @views states[2:n] = states[2:end][mask] +# return states[1:n] +# end + +# Method for when bond origin is the first atoms position +function _build_bond_state(coordinates, envelope) + # Build a list of static coordinate vectors, excluding the two bonding + # atoms. + positions = map(_F64SV, eachcol(coordinates[:, 3:end])) + + # The rest of this function copies code directly from `states.get_state`. + # Here, rr0 is multiplied by two as vectors provided by DFTB+ point to + # the midpoint of the bond. ACEhamiltonians expects the bond vector and + # to be inverted, hence the second position is taken. + rr0 = _F64SV(coordinates[:, 2] * 2) + offset = rr0 / 2 + states = Vector{BondState{_F64SV, Bool}}(undef, length(positions) + 1) + states[1] = BondState(_F64SV(coordinates[:, 2] * 2), rr0, true) + + for k=1:length(positions) + states[k+1] = BondState{_F64SV, Bool}(positions[k] + offset, rr0, false) + end + + @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 + @views n = sum(mask) + 1 + @views states[2:n] = states[2:end][mask] + return states[1:n] +end + + +function build_on_site_atom_block!(block::Vector{Float64}, coordinates::Vector{Float64}, species, model) + basis_def = model.basis_definition + n_shells = length(basis_def[species[1]]) + + # Unflatten the coordinates array + coordinates = reshape(coordinates, 3, :) * _Bohr2Angstrom + + # Unflatten the atom-block array and convert it into a PseudoBlockMatrix + block = _reshape_to_block(block, species[1], species[1], basis_def) + + # On-site atom block of the overlap matrix are just an identify matrix + if model.label == "S" + block .= 0.0 + block[diagind(block)] .= 1.0 + return nothing + end + + # Loop over all shell pairs + for i_shell=1:n_shells + for j_shell=i_shell:n_shells + + # Pull out the associated sub-block as a view + @views sub_block = block[Block(i_shell, j_shell)] + + # Select the appropriate model + basis = model.on_site_bases[(species[1], i_shell, j_shell)] + + # Construct the on-site state taking into account the required cutoff + state = _build_atom_state(coordinates, radial(basis).R.ru) + + # Make the prediction + predict!(sub_block, basis, state) + + # Set the symmetrically equivalent block when appropriate + if i_shell ≠ j_shell + @views block[Block(j_shell, i_shell)] = sub_block' + end + end + end + +end + + +@FSE function build_off_site_atom_block!(block::Vector{Float64}, coordinates::Vector{Float64}, species, model) + # Need to deal with situation where Z₁ > Z₂ + basis_def = model.basis_definition + species_i, species_j = species[1:2] + n_shells_i = number_of_shells(species_i, basis_def) + n_shells_j = number_of_shells(species_j, basis_def) + + # Unflatten the coordinates array + coordinates = reshape(coordinates, 3, :) * _Bohr2Angstrom + + # Unflatten the atom-block array and convert it into a PseudoBlockMatrix + block = _reshape_to_block(block, species_i, species_j, basis_def) + + # By default only interactions where species-i ≥ species-j are defined as + # adding interactions for species-i < species-j would be redundant. + if species_i > species_j + block = block' + n_shells_i, n_shells_j = n_shells_j, n_shells_i + reflect_state = true + else + reflect_state = false + end + + # Loop over all shell pairs + for i_shell=1:n_shells_i + for j_shell=1:n_shells_j + + # Skip over i_shell > j_shell homo-atomic interactions + species_i ≡ species_j && i_shell > j_shell && continue + + # Pull out the associated sub-block as a view + @views sub_block = block[Block(i_shell, j_shell)] + + # Select the appropriate model + basis = model.off_site_bases[(species[1], species[2], i_shell, j_shell)] + + # Construct the on-site state taking into account the required cutoff + state = _build_bond_state(coordinates, envelope(basis)) + + if reflect_state + state = reflect.(state) + end + + # Make the prediction + predict!(sub_block, basis, state) + + if species_i ≡ species_j + @views predict!(block[Block(j_shell, i_shell)]', basis, reflect.(state)) + end + + end + end + +end + + +# if model.label == "H" +# basis = model.off_site_bases[(species[1], species[2], 1, 1)] +# state = _build_bond_state(coordinates, envelope(basis)) +# dump("states.bin", state) +# dump("atom_blocks.bin", block) +# end + +# function dump(path, data::T) where T +# data_set = isfile(path) ? deserialize(path) : T[] +# append!(data_set, (data,)) +# serialize(path, data_set) +# nothing +# end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/small/src/basis.jl b/examples/H2O/python_interface/small/src/basis.jl new file mode 100644 index 0000000..553e859 --- /dev/null +++ b/examples/H2O/python_interface/small/src/basis.jl @@ -0,0 +1,579 @@ +module Bases + +using ACEhamiltonians, ACE, ACEbase, SparseArrays, LinearAlgebra, ACEatoms, JuLIP + +using ACEhamiltonians.Parameters: OnSiteParaSet, OffSiteParaSet +using ACE: SymmetricBasis, SphericalMatrix, Utils.RnYlm_1pbasis, SimpleSparseBasis, + CylindricalBondEnvelope, Categorical1pBasis, cutoff_radialbasis, cutoff_env, + get_spec, coco_dot + +using ACEbase.ObjectPools: VectorPool +using ACEhamiltonians: BOND_ORIGIN_AT_MIDPOINT, SYMMETRY_FIX_ENABLED + + +import ACEbase: read_dict, write_dict +import LinearAlgebra.adjoint, LinearAlgebra.transpose, Base./ +import Base +import ACE: SphericalMatrix + + +export AHSubModel, radial, angular, categorical, envelope, on_site_ace_basis, off_site_ace_basis, filter_offsite_be, is_fitted, SubModel, AnisoSubModel +""" +TODO: + - A warning should perhaps be given if no filter function is given when one is + expected; such as off-site functions. If no-filter function is desired than + a dummy filter should be required. + - Improve typing for the Model structure. + - Replace e_cutₒᵤₜ, e_cutᵢₙ, etc. with more "typeable" names. +""" +###################### +# SubModel Structure # +###################### +# Todo: +# - Document +# - Give type information +# - Serialization routines + + + +# ╔══════════╗ +# ║ SubModel ║ +# ╚══════════╝ + +abstract type AHSubModel end + + + +""" + + +A Linear ACE model for modelling symmetry invariant interactions. +In the context of Hamiltonian, this is just a (sub-)model for some specific +blocks and is hence called SubModel + +# Fields +- `basis::SymmetricBasis`: +- `id::Tuple`: +- `coefficients::Vector`: +- `mean::Matrix`: + +""" +struct SubModel{T₁<:SymmetricBasis, T₂, T₃, T₄} <: AHSubModel + basis::T₁ + id::T₂ + coefficients::T₃ + mean::T₄ + + function SubModel(basis, id) + t = ACE.valtype(basis) + F = real(t.parameters[5]) + SubModel(basis, id, zeros(F, length(basis)), zeros(F, size(zero(t)))) + end + + function SubModel(basis::T₁, id::T₂, coefficients::T₃, mean::T₄) where {T₁, T₂, T₃, T₄} + new{T₁, T₂, T₃, T₄}(basis, id, coefficients, mean) + end + +end + +""" + +Another linear ACE model for modelling symmetry variant interactions. + + +- `basis::SymmetricBasis`: +- `basis_i::SymmetricBasis`: +- `id::Tuple`: +- `coefficients::Vector`: +- `coefficients_i::Vector`: +- `mean::Matrix`: +- `mean_i::Matrix`: + +""" +struct AnisoSubModel{T₁<:SymmetricBasis, T₂<:SymmetricBasis, T₃, T₄, T₅, T₆, T₇} <: AHSubModel + basis::T₁ + basis_i::T₂ + id::T₃ + coefficients::T₄ + coefficients_i::T₅ + mean::T₆ + mean_i::T₇ + + function AnisoSubModel(basis, basis_i, id) + t₁, t₂ = ACE.valtype(basis), ACE.valtype(basis_i) + F = real(t₁.parameters[5]) + AnisoSubModel( + basis, basis_i, id, zeros(F, length(basis)), zeros(F, length(basis_i)), + zeros(F, size(zero(t₁))), zeros(F, size(zero(t₂)))) + end + + function AnisoSubModel(basis::T₁, basis_i::T₂, id::T₃, coefficients::T₄, coefficients_i::T₅, mean::T₆, mean_i::T₇) where {T₁, T₂, T₃, T₄, T₅, T₆, T₇} + new{T₁, T₂, T₃, T₄, T₅, T₆, T₇}(basis, basis_i, id, coefficients, coefficients_i, mean, mean_i) + end +end + +AHSubModel(basis, id) = SubModel(basis, id) +AHSubModel(basis, basis_i, id) = AnisoSubModel(basis, basis_i, id) + + +# ╭──────────┬───────────────────────╮ +# │ SubModel │ General Functionality │ +# ╰──────────┴───────────────────────╯ +"""Boolean indicating whether a `SubModel` instance is fitted; i.e. has non-zero coefficients""" +is_fitted(submodel::AHSubModel) = !all(submodel.coefficients .≈ 0.0) || !all(submodel.mean .≈ 0.0) + + +"""Check if two `SubModel` instances are equivalent""" +function Base.:(==)(x::T₁, y::T₂) where {T₁<:AHSubModel, T₂<:AHSubModel} + + # Check that the ID's, coefficients and means match up first + check = x.id == y.id && size(x.mean) == size(y.mean) && x.mean == y.mean + + # If they don't then return false. Otherwise perform a check of the basis object + # itself. A try/catch block must be used when comparing the bases as this can + # result in a DimensionMismatch. + if !check + return check + else + try + return x.basis == y.basis + catch y + if isa(y, DimensionMismatch) + return false + else + rethrow(y) + end + + end + end +end + + +"""Expected shape of the sub-block associated with the `SubModel`; 3×3 for a pp basis etc.""" +Base.size(submodel::AHSubModel) = (ACE.valtype(submodel.basis).parameters[3:4]...,) + +# """Expected type of resulting sub-blocks.""" +# Base.valtype(::Basis{T}) where T = T + +"""Expected type of resulting sub-blocks.""" +function Base.valtype(::AHSubModel) + throw("AHSubModel structure type has been changed this function must be updated.") +end + + +"""Azimuthal quantum numbers associated with the `SubModel`.""" +azimuthals(submodel::AHSubModel) = (ACE.valtype(submodel.basis).parameters[1:2]...,) + +"""Returns a boolean indicating if the submodel instance represents an on-site interaction.""" +Parameters.ison(x::AHSubModel) = length(x.id) ≡ 3 + + +""" + _filter_bases(submodel, type) + +Helper function to retrieve specific submodel function information out of a `AHSubModel` instance. +This is an internal function which is not expected to be used outside of this module. + +Arguments: +- `submodel::AHSubModel`: submodel instance from which function is to be extracted. +- `type::DataType`: type of the submodel functions to extract; e.g. `CylindricalBondEnvelope`. +""" +function _filter_bases(submodel::AHSubModel, T) + functions = filter(i->i isa T, submodel.basis.pibasis.basis1p.bases) + if length(functions) == 0 + error("Could not locate submodel function matching the supplied type") + elseif length(functions) ≥ 2 + @warn "Multiple matching submodel functions found, only the first will be returned" + end + return functions[1] +end + +"""Extract and return the radial component of a `AHSubModel` instance.""" +radial(submodel::AHSubModel) = _filter_bases(submodel, ACE.Rn1pBasis) + +"""Extract and return the angular component of a `AHSubModel` instance.""" +angular(submodel::AHSubModel) = _filter_bases(submodel, ACE.Ylm1pBasis) + +"""Extract and return the categorical component of a `AHSubModel` instance.""" +categorical(submodel::AHSubModel) = _filter_bases(submodel, ACE.Categorical1pBasis) + +"""Extract and return the bond envelope component of a `AHSubModel` instance.""" +envelope(submodel::AHSubModel) = _filter_bases(submodel, ACE.BondEnvelope) + +#TODO: add extract function for species1p basis, if needed + + +# ╭──────────┬──────────────────╮ +# │ SubModel │ IO Functionality │ +# ╰──────────┴──────────────────╯ +""" + write_dict(submodel[,hash_basis]) + +Convert a `SubModel` structure instance into a representative dictionary. + +# Arguments +- `submodel::SubModel`: the `SubModel` instance to parsed. +- `hash_basis::Bool`: ff `true` then hash values will be stored in place of + the `SymmetricBasis` objects. +""" +function write_dict(submodel::T, hash_basis=false) where T<:SubModel + return Dict( + "__id__"=>"SubModel", + "basis"=>hash_basis ? string(hash(submodel.basis)) : write_dict(submodel.basis), + "id"=>submodel.id, + "coefficients"=>write_dict(submodel.coefficients), + "mean"=>write_dict(submodel.mean)) + +end + + +"""Instantiate a `SubModel` instance from a representative dictionary.""" +function ACEbase.read_dict(::Val{:SubModel}, dict::Dict) + return SubModel( + read_dict(dict["basis"]), + Tuple(dict["id"]), + read_dict(dict["coefficients"]), + read_dict(dict["mean"])) +end + + +function Base.show(io::IO, submodel::T) where T<:AHSubModel + print(io, "$(nameof(T))(id: $(submodel.id), fitted: $(is_fitted(submodel)))") +end + + +# ╔════════════════════════╗ +# ║ ACE Basis Constructors ║ +# ╚════════════════════════╝ + +# Codes hacked to proceed to enable multispecies basis construction +ACE.get_spec(basis::Species1PBasis, i::Integer) = (μ = basis.zlist.list[i],) +ACE.get_spec(basis::Species1PBasis) = ACE.get_spec.(Ref(basis), 1:length(basis)) +# Base.length(a::Nothing) = 0 + +@doc raw""" + + on_site_ace_basis(ℓ₁, ℓ₂, ν, deg, e_cutₒᵤₜ[, r0]) + +Initialise a simple on-site `SymmetricBasis` instance with sensible default parameters. + +The on-site `SymmetricBasis` entities are produced by applying a `SimpleSparseBasis` +selector to a `Rn1pBasis` instance. The latter of which is initialised via the `Rn_basis` +method, using all the defaults associated therein except `e_cutₒᵤₜ` and `e_cutᵢₙ` which are +provided by this function. This facilitates quick construction of simple on-site `Bases` +instances; if more fine-grain control over over the initialisation process is required +then bases must be instantiated manually. + +# Arguments +- `(ℓ₁,ℓ₂)::Integer`: azimuthal numbers of the basis function. +- `ν::Integer`: maximum correlation order = body order - 1. +- `deg::Integer`: maximum polynomial degree. +- `e_cutₒᵤₜ::AbstractFloat`: only atoms within the specified cutoff radius will contribute + to the local environment. +- `r0::AbstractFloat`: scaling parameter (typically set to the nearest neighbour distances). +- `species::Union{Nothing, Vector{AtomicNumber}}`: A set of species of a system + +# Returns +- `basis::SymmetricBasis`: ACE basis entity for modelling the specified interaction. + +""" +function on_site_ace_basis(ℓ₁::I, ℓ₂::I, ν::I, deg::I, e_cutₒᵤₜ::F, r0::F=2.5; species::Union{Nothing, Vector{AtomicNumber}}=nothing + ) where {I<:Integer, F<:AbstractFloat} + # Build i) a matrix indicating the desired sub-block shape, ii) the one + # particle Rₙ·Yₗᵐ basis describing the environment, & iii) the basis selector. + # Then instantiate the SymmetricBasis required by the Basis structure. + if !isnothing(species) + return SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + Species1PBasis(species) * RnYlm_1pbasis(maxdeg=deg, r0=r0, rcut=e_cutₒᵤₜ), + SimpleSparseBasis(ν, deg)) + else + return SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + RnYlm_1pbasis(maxdeg=deg, r0=r0, rcut=e_cutₒᵤₜ), + SimpleSparseBasis(ν, deg)) + end +end + + + +function _off_site_ace_basis_no_sym(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + + + # Bond envelope which controls which atoms are seen by the bond. + @static if BOND_ORIGIN_AT_MIDPOINT + env = CylindricalBondEnvelope(b_cut, e_cutₒᵤₜ, e_cutₒᵤₜ, floppy=false, λ=0.0) + else + env = CylindricalBondEnvelope(b_cut, e_cutₒᵤₜ, e_cutₒᵤₜ, floppy=false, λ=0.5) + end + + # Categorical1pBasis is applied to the basis to allow atoms which are part of the + # bond to be treated differently to those that are just part of the environment. + discriminator = Categorical1pBasis([true, false]; varsym=:bond, idxsym=:bond) + + # The basis upon which the above entities act. Note that the internal cutoff "rin" must + # be set to 0.0 to allow for atoms a the bond's midpoint to be observed. + RnYlm = RnYlm_1pbasis( + maxdeg=deg, rcut=cutoff_env(env), + trans=IdTransform(), rin=0.0) + + B1p = RnYlm * env * discriminator + if !isnothing(species) + B1p = Species1PBasis(species) * B1p + end + + # Finally, construct and return the SymmetricBasis entity + basis = SymmetricBasis( + SphericalMatrix(ℓ₁, ℓ₂; T=ComplexF64), + B1p, SimpleSparseBasis(ν + 1, deg), + filterfun = indices -> _filter_offsite_be(indices, deg, λₙ, λₗ)) + + + return basis +end + + +function _off_site_ace_basis_sym(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + + # TODO: For now the symmetrised basis only works for single species case but it can be easily extended + @assert species == nothing || length(species) <= 1 + + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + + if ℓ₁ == ℓ₂ + Uᵢ = let A = get_spec(basis.pibasis) + U = adjoint.(basis.A2Bmap) * (_perm(A) * _dpar(A)) + (basis.A2Bmap + U) .* 0.5 + end + + # Purge system of linear dependence + svdC = svd(_build_G(Uᵢ)) + rk = rank(Diagonal(svdC.S), rtol = 1e-7) + Uⱼ = sparse(Diagonal(sqrt.(svdC.S[1:rk])) * svdC.U[:, 1:rk]') + U_new = Uⱼ * Uᵢ + + # construct symmetric offsite basis + basis = SymmetricBasis(basis.pibasis, U_new, basis.symgrp, basis.real) + + elseif ℓ₁ > ℓ₂ + U_new = let A = get_spec(basis.pibasis) + adjoint.(_off_site_ace_basis_no_sym(ℓ₂, ℓ₁, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ).A2Bmap) * _perm(A) * _dpar(A) + end + + basis = SymmetricBasis(basis.pibasis, U_new, basis.symgrp, basis.real) + end + + + return basis +end + +@doc raw""" + + off_site_ace_basis(ℓ₁, ℓ₂, ν, deg, b_cut[,e_cutₒᵤₜ, λₙ, λₗ]) + + +Initialise a simple off-site `SymmetricBasis` instance with sensible default parameters. + +Operates similarly to [`on_site_ace_basis`](@ref) but applies a `CylindricalBondEnvelope` to +the `Rn1pBasis` basis instance. The length and radius of the cylinder are defined as +maths: ``b_{cut}+2e_{cut\_out}`` and maths: ``e_{cut\_out}`` respectively; all other +parameters resolve to their defaults as defined by their constructors. Again, instances +must be manually instantiated if more fine-grained control is desired. + +# Arguments +- `(ℓ₁,ℓ₂)::Integer`: azimuthal numbers of the basis function. +- `ν::Integer`: maximum correlation order. +- `deg::Integer`: maximum polynomial degree. +- `b_cut::AbstractFloat`: cutoff distance for bonded interactions. +- `e_cutₒᵤₜ::AbstractFloat`: radius and axial-padding of the cylindrical bond envelope that + is used to determine which atoms impact to the bond's environment. +- `λₙ::AbstractFloat`: +- `λₗ::AbstractFloat`: +- `species::Union{Nothing, Vector{AtomicNumber}}`: A set of species of a system + +# Returns +- `basis::SymmetricBasis`: ACE basis entity for modelling the specified interaction. + +""" +function off_site_ace_basis(ℓ₁::I, ℓ₂::I, ν::I, deg::I, b_cut::F, e_cutₒᵤₜ::F=5.; + λₙ::F=.5, λₗ::F=.5, symfix=true, species::Union{Nothing, Vector{AtomicNumber}}=nothing) where {I<:Integer, F<:AbstractFloat} + # WARNING symfix might cause issues when applied to interactions between different species. + # It is still not clear how appropriate this non homo-shell interactions. + @static if SYMMETRY_FIX_ENABLED + if symfix && ( isnothing(species) || length(species) <= 1) + basis = _off_site_ace_basis_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + else + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + end + else + basis = _off_site_ace_basis_no_sym(ℓ₁, ℓ₂, ν, deg, b_cut, e_cutₒᵤₜ; λₙ=λₙ, λₗ=λₗ, species = species) + end + + return basis + +end + + + + +""" + _filter_offsite_be(states, max_degree[, λ_n=0.5, λ_l=0.5]) + +Cap the the maximum polynomial components. + +This filter function should be passed, via the keyword `filterfun`, to `SymmetricBasis` +when instantiating. + + +# Arguments +- `indices::Tuple`: set of index that defines 1pbasis, e.g., (n,l,m,:be). +- `max_degree::Integer`: maximum polynomial degree. +- `λ_n::AbstractFloat`: +- `λ_l::AbstractFloat`: + +# Developers Notes +This function and its doc-string will be rewritten once its function and arguments have +been identified satisfactorily. + +# Examples +This is primarily intended to act as a filter function for off site bases like so: +``` +julia> off_site_sym_basis = SymmetricBasis( + φ, basis, selector, + filterfun = states -> filter_offsite_be(states, max_degree) +``` + +# Todo + - This should be inspected and documented. + - Refactoring may improve performance. + +""" +function _filter_offsite_be(indices, max_degree, λ_n=.5, λ_l=.5) + if length(indices) == 0; return false; end + deg_n, deg_l = ceil(Int, max_degree * λ_n), ceil(Int, max_degree * λ_l) + for idx in indices + if !idx.bond && (idx.n>deg_n || idx.l>deg_l) + return false + end + end + return sum(idx.bond for idx in indices) == 1 +end + + +adjoint(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val', Val{ℓ₂}(), Val{ℓ₁}()) +transpose(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(transpose(A.val), Val{ℓ₂}(), Val{ℓ₁}()) +/(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}, b::Number) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val / b, Val{ℓ₁}(), Val{ℓ₂}()) +Base.:(*)(A::SphericalMatrix{ℓ₁, ℓ₂, LEN1, LEN2, T, LL}, b::Number) where {ℓ₁, ℓ₂, LEN1, LEN2, T, LL} = SphericalMatrix(A.val * b, Val{ℓ₁}(), Val{ℓ₂}()) + + +function _dpar(A) + parity = spzeros(Int, length(A), length(A)) + + for i=1:length(A) + for j=1:length(A[i]) + if A[i][j].bond + parity[i,i] = (-1)^A[i][j].l + break + end + end + end + + return parity + end + +""" +This function that takes an arbitrary named tuple and returns an identical copy with the +value of its "m" field inverted, if present. +""" +@generated function _invert_m(i) + + filed_names = Meta.parse( + join( + [field_name ≠ :m ? "i.$field_name" : "-i.$field_name" + for field_name in fieldnames(i)], + ", ") + ) + + return quote + $i(($(filed_names))) + end +end + +function _perm(A::Vector{Vector{T}}) where T + # This function could benefit from some optimisation. However it is stable for now. + + # Dictionary mapping groups of A to their column index. + D = Dict(sort(j)=>i for (i,j) in enumerate(A)) + + # Ensure that there is no double mapping going on; as there is not clear way + # to handel such an occurrence. + @assert length(D) ≡ length(A) "Double mapping ambiguity present in \"A\"" + + # Sparse zero matrix to hold the results + P = spzeros(length(A), length(A)) + + # Track which systems have been evaluated ahead of their designated loop via + # symmetric equivalence. This done purely for the sake of speed. + done_by_symmetry = zeros(Int, length(A)) + + for (i, A_group) in enumerate(A) + # Skip over groups that have already been assigned during evaluation of their + # symmetric equivalent. + i in done_by_symmetry && continue + + # Construct the "m" inverted named tuple + U = sort([_invert_m(t) for t in A_group]) + + # Identify which column the conjugate group `U` is associated. + idx = D[U] + + # Update done_by_symmetry checklist to prevent evaluating the symmetrically + # equivalent group. (if it exists) + done_by_symmetry[i] = idx + + # Compute and assign the cumulative "m" parity term for the current group and + # its symmetric equivalent (if present.) + P[i, idx] = P[idx, i] = (-1)^sum(o->o.m, A_group) + + end + + return P +end + +function _build_G(U) + # This function is highly inefficient and would benefit from a rewrite. It should be + # possible to make use of sparsity information present in U ahead of time to Identify + # when, where, and how many non-sparse points there are. + n_rows = size(U, 1) + # A sparse matrix would be more appropriate here given the high degree of sparsity + # present in `G`. However, this matrix is to be passed into the LinearAlgebra.svd + # method which is currently not able to operate upon sparse arrays. Hence a dense + # array is used here. + # G = spzeros(valtype(U[1].val), n_rows, n_rows) + G = zeros(valtype(U[1].val), n_rows, n_rows) + for row=1:n_rows, col=row:n_rows + result = sum(coco_dot.(U[row, :], U[col, :])) + if !iszero(result) + G[row, col] = G[col, row] = result + end + end + return G +end + + + +""" +At the time of writing there is an oversite present in the SparseArrays module which +prevents it from correctly identifying if an operation is sparsity preserving. That +is to say, in many cases a sparse matrix will be converted into its dense form which +can have profound impacts on performance. This function exists correct this behaviour, +and should be removed once the fixes percolate through to the stable branch. +""" +function Base.:(==)(x::SphericalMatrix, y::Integer) + return !(any(real(x.val) .!= y) || any(imag(x.val) .!= y)) +end + + +end diff --git a/examples/H2O/python_interface/small/src/common.jl b/examples/H2O/python_interface/small/src/common.jl new file mode 100644 index 0000000..7bbe630 --- /dev/null +++ b/examples/H2O/python_interface/small/src/common.jl @@ -0,0 +1,107 @@ +module Common +using ACEhamiltonians +using JuLIP: Atoms +export parse_key, with_cache, number_of_orbitals, species_pairs, shell_pairs, number_of_shells, test_function + +# Converts strings into tuples of integers or integers as appropriate. This function +# should be refactored and moved to a more appropriate location. It is mostly a hack +# at the moment. +function parse_key(key) + if key isa Integer || key isa Tuple + return key + elseif '(' in key + return Tuple(parse.(Int, split(strip(key, ['(', ')', ' ']), ", "))) + else + return parse(Int, key) + end +end + + +"""Nᵒ of orbitals present in system `atoms` based on a given basis definition.""" +function number_of_orbitals(atoms::Atoms, basis_definition::BasisDef) + # Work out the number of orbitals on each species + n_orbs = Dict(k=>sum(v * 2 .+ 1) for (k, v) in basis_definition) + # Use this to get the sum of the number of orbitals on each atom + return sum(getindex.(Ref(n_orbs), getfield.(atoms.Z, :z))) +end + +"""Nᵒ of orbitals present on a specific element, `z`, based on a given basis definition.""" +number_of_orbitals(z::I, basis_definition::BasisDef) where I<:Integer = sum(basis_definition[z] * 2 .+ 1) + +"""Nᵒ of shells present on a specific element, `z`, based on a given basis definition.""" +number_of_shells(z::I, basis_definition::BasisDef) where I<:Integer = length(basis_definition[z]) + +""" +Todo: + - Document this function correctly. + +Returns a cached guarded version of a function that stores known argument-result pairs. +This reduces the overhead associated with making repeated calls to expensive functions. +It is important to note that results for identical inputs will be the same object. + +# Warnings +Do no use this function, it only supports a very specific use-case. + +Although similar to the Memoize packages this function's use case is quite different and +is not supported by Memoize; hence the reimplementation here. This is mostly a stop-gap +measure and will be refactored at a later data. +""" +function with_cache(func::Function)::Function + cache = Dict() + function cached_function(args...; kwargs...) + k = (args..., kwargs...) + if !haskey(cache, k) + cache[k] = func(args...; kwargs...) + end + return cache[k] + end + return cached_function +end + + +_triangular_number(n::I) where I<:Integer = n*(n + 1)÷2 + +function species_pairs(atoms::Atoms) + species = sort(unique(getfield.(atoms.Z, :z))) + n = length(species) + + pairs = Vector{NTuple{2, valtype(species)}}(undef, _triangular_number(n)) + + c = 0 + for i=1:n, j=i:n + c += 1 + pairs[c] = (species[i], species[j]) + end + + return pairs +end + + +function shell_pairs(species_1, species_2, basis_def) + n₁, n₂ = length(basis_def[species_1]), length(basis_def[species_2]) + c = 0 + + if species_1 ≡ species_2 + pairs = Vector{NTuple{2, Int}}(undef, _triangular_number(n₁)) + + for i=1:n₁, j=i:n₁ + c += 1 + pairs[c] = (i, j) + end + + return pairs + else + pairs = Vector{NTuple{2, Int}}(undef, n₁ * n₂) + + for i=1:n₁, j=1:n₂ + c += 1 + pairs[c] = (i, j) + end + + return pairs + + end +end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/small/src/data.jl b/examples/H2O/python_interface/small/src/data.jl new file mode 100644 index 0000000..8cb8e71 --- /dev/null +++ b/examples/H2O/python_interface/small/src/data.jl @@ -0,0 +1,779 @@ +module MatrixManipulation + +using LinearAlgebra: norm, pinv +using ACEhamiltonians +using JuLIP: Atoms + +export BlkIdx, atomic_block_idxs, repeat_atomic_block_idxs, filter_on_site_idxs, + filter_off_site_idxs, filter_upper_idxs, filter_lower_idxs, get_sub_blocks, + filter_idxs_by_bond_distance, set_sub_blocks!, get_blocks, set_blocks!, + locate_and_get_sub_blocks + +# ╔═════════════════════╗ +# ║ Matrix Manipulation ║ +# ╚═════════════════════╝ + +# ╭─────────────────────┬────────╮ +# │ Matrix Manipulation │ BlkIdx │ +# ╰─────────────────────┴────────╯ +""" +An alias for `AbstractMatrix` used to signify a block index matrix. Given the frequency & +significance of block index matrices it became prudent to create an alias for it. This +helps to i) make it clear when and where a block index is used and ii) prevent having to +repeatedly explained what a block index matrix was each time one was used. + +As the name suggests, these are matrices which specifies the indices of a series of atomic +blocks. The 1ˢᵗ row specifies the atomic indices of the 1ˢᵗ atom in each block and the 2ⁿᵈ +row the indices of the 2ⁿᵈ atom. That is to say `block_index_matrix[:,i]` would yield the +atomic indices of the atoms associated with the iᵗʰ block listed in `block_index_matrix`. +The 3ʳᵈ row, if present, specifies the index of the cell in in which the second atom lies; +i.e. it indexes the cell translation vector list. + +For example; `BlkIdx([1 3; 2 4])` specifies two atomic blocks, the first being between +atoms 1&2, and the second between atoms 3&4; `BlkIdx([5; 6; 10])` represents the atomic +block between atoms 5&6, however in this case there is a third number 10 which give the +cell number. The cell number is can be used to help 3D real-space matrices or indicate +which cell translation vector should be applied. + +It is important to note that the majority of functions that take `BlkIdx` as an argument +will assume the first and second species are consistent across all atomic-blocks. +""" +BlkIdx = AbstractMatrix + + +# ╭─────────────────────┬─────────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Constructors │ +# ╰─────────────────────┴─────────────────────╯ +# These are the main methods by which `BlkIdx` instances are are constructed & expanded. + +""" + atomic_block_idxs(z_1, z_2, z_s[; order_invariant=false]) + +Construct a block index matrix listing all atomic-blocks present in the supplied system +where the first and second species are `z_1` and `z_2` respectively. + +# Arguments +- `z_1::Int`: first interacting species +- `z_2::Int`: second interacting species +- `z_s::Vector`: atomic numbers present in system +- `order_invariant::Bool`: by default, `block_idxs` only indexes atomic blocks in which the + 1ˢᵗ & 2ⁿᵈ species are `z_1` & `z_2` respectively. However, if `order_invariant` is enabled + then `block_idxs` will index all atomic blocks between the two species irrespective of + which comes first. + +# Returns +- `block_idxs::BlkIdx`: a 2×N matrix in which each column represents the index of an atomic block. + With `block_idxs[:, i]` yielding the atomic indices associated with the iᵗʰ atomic-block. + +# Notes +Enabling `order_invariant` is analogous to the following compound call: +`hcat(atomic_block_idxs(z_1, z_2, z_s), atomic_block_idxs(z_2, z_1, z_s))` +Furthermore, only indices associated with the origin cell are returned; if extra-cellular +blocks are required then `repeat_atomic_block_idxs` should be used. + +# Examples +``` +julia> atomic_numbers = [1, 1, 8] +julia> atomic_block_idxs(1, 8, atomic_numbers) +2×2 Matrix{Int64}: + 1 2 + 3 3 + +julia> atomic_block_idxs(8, 1, atomic_numbers) +2×2 Matrix{Int64}: + 3 3 + 1 2 + +julia> atomic_block_idxs(8, 1, atomic_numbers; order_invariant=true) +2×4 Matrix{Int64}: + 3 3 1 2 + 1 2 3 3 +``` +""" +function atomic_block_idxs(z_1::I, z_2::I, z_s::Vector; order_invariant::Bool=false) where I<:Integer + # This function uses views, slices and reshape operations to construct the block index + # list rather than an explicitly nested for-loop to reduce speed. + z_1_idx, z_2_idx = findall(==(z_1), z_s), findall(==(z_2), z_s) + n, m = length(z_1_idx), length(z_2_idx) + if z_1 ≠ z_2 && order_invariant + res = Matrix{I}(undef, 2, n * m * 2) + @views let res = res[:, 1:end ÷ 2] + @views reshape(res[1, :], (m, n)) .= z_1_idx' + @views reshape(res[2, :], (m, n)) .= z_2_idx + end + + @views let res = res[:, 1 + end ÷ 2:end] + @views reshape(res[1, :], (n, m)) .= z_2_idx' + @views reshape(res[2, :], (n, m)) .= z_1_idx + end + else + res = Matrix{I}(undef, 2, n * m) + @views reshape(res[1, :], (m, n)) .= z_1_idx' + @views reshape(res[2, :], (m, n)) .= z_2_idx + end + return res +end + +function atomic_block_idxs(z_1, z_2, z_s::Atoms; kwargs...) + return atomic_block_idxs(z_1, z_2, convert(Vector{Int}, z_s.Z); kwargs...) +end + + +""" + repeat_atomic_block_idxs(block_idxs, n) + +Repeat the atomic blocks indices `n` times and adds a new row specifying image number. +This is primarily intended to be used as a way to extend an atom block index list to +account for periodic images as they present in real-space matrices. + +# Arguments +- `block_idxs::BlkIdx`: the block indices which are to be expanded. + +# Returns +- `block_indxs_expanded::BlkIdx`: expanded block indices. + +# Examples +``` +julia> block_idxs = [10 10 20 20; 10 20 10 20] +julia> repeat_atomic_block_idxs(block_idxs, 2) +3×8 Matrix{Int64}: + 10 10 20 20 10 10 20 20 + 10 20 10 20 10 20 10 20 + 1 1 1 1 2 2 2 2 +``` +""" +function repeat_atomic_block_idxs(block_idxs::BlkIdx, n::T) where T<:Integer + @assert size(block_idxs, 1) != 3 "`block_idxs` has already been expanded." + m = size(block_idxs, 2) + res = Matrix{T}(undef, 3, m * n) + @views reshape(res[3, :], (m, n)) .= (1:n)' + @views reshape(res[1:2, :], (2, m, n)) .= block_idxs + return res +end + + +# ╭─────────────────────┬──────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Ancillary │ +# ╰─────────────────────┴──────────────────╯ +# Internal functions that operate on `BlkIdx`. + +""" + _block_starts(block_idxs, atoms, basis_def) + +This function takes a series of atomic block indices, `block_idxs`, and returns the index +of the first element in each atomic block. This is helpful when wanting to locate atomic +blocks in a Hamiltonian or overlap matrix associated with a given pair of atoms. + +# Arguments +- `block_idxs::BlkIdx`: block indices specifying the blocks whose starts are to be returned. +- `atoms::Atoms`: atoms object of the target system. +- `basis_def::BasisDef`: corresponding basis set definition. + +# Returns +- `block_starts::Matrix`: A copy of `block_idxs` where the first & second rows now provide an + index specifying where the associated block starts in the Hamiltonian/overlap matrix. The + third row, if present, is left unchanged. +""" +function _block_starts(block_idxs::BlkIdx, atoms::Atoms, basis_def::BasisDef) + n_orbs = Dict(k=>sum(2v .+ 1) for (k,v) in basis_def) # N∘ orbitals per species + n_orbs_per_atom = [n_orbs[z] for z in atoms.Z] # N∘ of orbitals on each atom + block_starts = copy(block_idxs) + @views block_starts[1:2, :] = ( + cumsum(n_orbs_per_atom) - n_orbs_per_atom .+ 1)[block_idxs[1:2, :]] + + return block_starts +end + +""" + _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) + + +Get the index of the first element of the sub-block formed between shells `s_i` and `s_j` +of species `z_1` and `z_2` respectively. The results of this method are commonly added to +those of `_block_starts` to give the first index of a desired sub-block in some arbitrary +Hamiltonian or overlap matrix. + +# Arguments +- `z_1::Int`: species on which shell `s_i` resides. +- `z_2::Int`: species on which shell `s_j` resides. +- `s_i::Int`: first shell of the sub-block. +- `s_j::Int`: second shell of the sub-block. +- `basis_def::BasisDef`: corresponding basis set definition. + +# Returns +- `sub_block_starts::Vector`: vector specifying the index in a `z_1`-`z_2` atom-block at which + the first element of the `s_i`-`s_j` sub-block is found. + +""" +function _sub_block_starts(z_1, z_2, s_i::I, s_j::I, basis_def::BasisDef) where I<:Integer + sub_block_starts = Vector{I}(undef, 2) + sub_block_starts[1] = sum(2basis_def[z_1][1:s_i-1] .+ 1) + 1 + sub_block_starts[2] = sum(2basis_def[z_2][1:s_j-1] .+ 1) + 1 + return sub_block_starts +end + + +# ╭─────────────────────┬────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Filters │ +# ╰─────────────────────┴────────────────╯ +# Filtering operators to help with differentiating between and selecting specific block +# indices or collections thereof. + +""" + filter_on_site_idxs(block_idxs) + +Filter out all but the on-site block indices. + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only on-site block indices remaining. + +""" +function filter_on_site_idxs(block_idxs::BlkIdx) + # When `block_idxs` is a 2×N matrix then the only requirement for an interaction to be + # on-site is that the two atomic indices are equal to one another. If `block_idxs` is a + # 3×N matrix then the interaction must lie within the origin cell. + if size(block_idxs, 1) == 2 + return block_idxs[:, block_idxs[1, :] .≡ block_idxs[2, :]] + else + return block_idxs[:, block_idxs[1, :] .≡ block_idxs[2, :] .&& block_idxs[3, :] .== 1] + end +end + +""" + filter_off_site_idxs(block_idxs) + +Filter out all but the off-site block indices. + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only off-site block indices remaining. + +""" +function filter_off_site_idxs(block_idxs::BlkIdx) + if size(block_idxs, 1) == 2 # Locate where atomic indices not are equal + return block_idxs[:, block_idxs[1, :] .≠ block_idxs[2, :]] + else # Find where atomic indices are not equal or the cell≠1. + return block_idxs[:, block_idxs[1, :] .≠ block_idxs[2, :] .|| block_idxs[3, :] .≠ 1] + end +end + + +""" + filter_upper_idxs(block_idxs) + +Filter out atomic-blocks that reside in the lower triangle of the matrix. This is useful +for removing duplicate data in some cases. (blocks on the diagonal are retained) + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only blocks from the upper + triangle remaining. +""" +filter_upper_idxs(block_idxs::BlkIdx) = block_idxs[:, block_idxs[1, :] .≤ block_idxs[2, :]] + + +""" + filter_lower_idxs(block_idxs) + +Filter out atomic-blocks that reside in the upper triangle of the matrix. This is useful +for removing duplicate data in some cases. (blocks on the diagonal are retained) + +# Arguments +- `block_idxs::BlkIdx`: block index matrix to be filtered. + +# Returns +- `filtered_block_idxs::BlkIdx`: copy of `block_idxs` with only blocks from the lower + triangle remaining. +""" +filter_lower_idxs(block_idxs::BlkIdx) = block_idxs[:, block_idxs[1, :] .≥ block_idxs[2, :]] + +""" + filter_idxs_by_bond_distance(block_idxs, distance, atoms[, images]) + +Filters out atomic-blocks associated with interactions between paris of atoms that are +separated by a distance greater than some specified cutoff. + + +# Arguments +- `block_idx::BlkIdx`: block index matrix holding the off-site atom-block indices that + are to be filtered. +- `distance::AbstractFloat`: maximum bond distance; atom blocks representing interactions + between atoms that are more than `distance` away from one another will be filtered out. +- `atoms::Atoms`: system in which the specified blocks are located. +- `images::Matrix{<:Integer}`: cell translation index lookup list, this is only relevant + when `block_idxs` supplies and cell index value. The cell translation index for the iᵗʰ + state will be taken to be `images[block_indxs[i, 3]]`. + +# Returns +- `filtered_block_idx::BlkIdx`: a copy of `block_idx` in which all interactions associated + with interactions separated by a distance greater than `distance` have been filtered out. + +# Notes +It is only appropriate to use this method to filter `block_idx` instance in which all +indices pertain to off-site atom-blocks. + +""" +function filter_idxs_by_bond_distance( + block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, + images::Union{Nothing, AbstractMatrix{<:Integer}}=nothing) + + let mask = _distance_mask(block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + return block_idxs[:, mask] + end +end + + +function _filter_block_indices(block_indices, focus::AbstractVector) + mask = ∈(focus).(block_indices[1, :]) .& ∈(focus).(block_indices[2, :]) + return block_indices[:, mask] + +end + +function _filter_block_indices(block_indices, focus::AbstractMatrix) + mask = ∈(collect(eachcol(focus))).(collect(eachcol(block_indices[1:2, :]))) + return block_indices[:, mask] +end + +# ╭─────────────────────┬─────────────────╮ +# │ Matrix Manipulation │ Data Assignment │ +# ╰─────────────────────┴─────────────────╯ + +# The `_get_blocks!` methods are used to collect either atomic-blocks or sub-blocks from +# a Hamiltonian or overlap matrix. + +""" + _get_blocks!(src, target, starts) + +Gather blocks from a `src` matrix and store them in the array `target`. This method is +to be used when gathering data from two-dimensional, single k-point, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _get_blocks!(src::Matrix{T}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[:, :, i] = src[ + starts[1, i]:starts[1, i] + size(target, 1) - 1, + starts[2, i]:starts[2, i] + size(target, 2) - 1] + end +end + +""" + _get_blocks!(src, target, starts) + +Gather blocks from a `src` matrix and store them in the array `target`. This method is +to be used when gathering data from three-dimensional, real-space, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. Note that in + this case the cell index, i.e. the third row, specifies the cell index. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _get_blocks!(src::AbstractArray{T, 3}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[:, :, i] = src[ + starts[1, i]:starts[1, i] + size(target, 1) - 1, + starts[2, i]:starts[2, i] + size(target, 2) - 1, + starts[3, i]] + end +end + +# The `_set_blocks!` methods perform the inverted operation of their `_get_blocks!` +# counterparts as they place data **into** the Hamiltonian or overlap matrix. + +""" + _set_blocks!(src, target, starts) + +Scatter blocks from the `src` matrix into the `target`. This method is to be used when +assigning data to two-dimensional, single k-point, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _set_blocks!(src::AbstractArray{T, 3}, target::Matrix{T}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[ + starts[1, i]:starts[1, i] + size(src, 1) - 1, + starts[2, i]:starts[2, i] + size(src, 2) - 1, + ] = src[:, :, i] + end +end + +""" + _set_blocks!(src, target, starts) + +Scatter blocks from the `src` matrix into the `target`. This method is to be used when +assigning data to three-dimensional, real-space, matrices. + +# Arguments +- `src::Matrix`: matrix from which data is to be drawn. +- `target::Array`: array in which data should be stored. +- `starts::BlkIdx`: a matrix specifying where each target block starts. Note that in + this case the cell index, i.e. the third row, specifies the cell index. + +# Notes +The size of each block to ge gathered is worked out form the size of `target`. + +""" +function _set_blocks!(src::AbstractArray{T, 3}, target::AbstractArray{T, 3}, starts::BlkIdx) where T + for i in 1:size(starts, 2) + @views target[ + starts[1, i]:starts[1, i] + size(src, 1) - 1, + starts[2, i]:starts[2, i] + size(src, 2) - 1, + starts[3, i] + ] = src[:, :, i] + end +end + + + + +""" + get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def) + +Collect sub-blocks of a given type from select atom-blocks in a provided matrix. + +This method will collect, from `matrix`, the `s_i`-`s_j` sub-block of each atom-block +listed in `block_idxs`. It is assumed that all atom-blocks are between identical pairs +of species. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an array containing the collected sub-blocks. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function get_sub_blocks(matrix::AbstractArray{T}, block_idxs::BlkIdx, s_i, s_j, atoms::Atoms, basis_def) where T + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Shift `starts` so it points to the start of the **sub-blocks** rather than the block + starts[1:2, :] .+= _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) .- 1 + + data = Array{T, 3}( # Array in which the resulting sub-blocks are to be collected + undef, 2basis_def[z_1][s_i] + 1, 2basis_def[z_2][s_j] + 1, size(block_idxs, 2)) + + # Carry out the assignment operation. + _get_blocks!(matrix, data, starts) + + return data +end + + +""" + set_sub_blocks(matrix, values, block_idxs, s_i, s_j, atoms, basis_def) + +Place sub-block data from `values` representing the interaction between shells `s_i` & +`s_j` into the matrix at the atom-blocks listed in `block_idxs`. This is this performs +the inverse operation to `set_sub_blocks`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `values::Array`: sub-block values. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function set_sub_blocks!(matrix::AbstractArray, values, block_idxs::BlkIdx, s_i, s_j, atoms::Atoms, basis_def) + + if size(values, 3) != size(block_idxs, 2) + throw(DimensionMismatch( + "The last dimensions of `values` & `block_idxs` must be of the same length.")) + end + + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Shift `starts` so it points to the start of the **sub-blocks** rather than the block + starts[1:2, :] .+= _sub_block_starts(z_1, z_2, s_i, s_j, basis_def) .- 1 + + # Carry out the scatter operation. + _set_blocks!(values, matrix, starts) +end + + +""" + get_blocks(matrix, block_idxs, atoms, basis_def) + +Collect, from `matrix`, the blocks listed in `block_idxs`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `block_idxs::BlkIdx`: the atomic-blocks to gathered. +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an array containing the collected sub-blocks. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function get_blocks(matrix::AbstractArray{T}, block_idxs::BlkIdx, atoms::Atoms, basis_def) where T + z_1, z_2 = atoms.Z[block_idxs[1:2, 1]] + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + data = Array{T, 3}( # Array in which the resulting blocks are to be collected + undef, sum(2basis_def[z_1].+ 1), sum(2basis_def[z_2].+ 1), size(block_idxs, 2)) + + # Carry out the assignment operation. + _get_blocks!(matrix, data, starts) + + return data +end + + +""" + set_sub_blocks(matrix, values, block_idxs, s_i, s_j, atoms, basis_def) + +Place atom-block data from `values` into the matrix at the atom-blocks listed in `block_idxs`. +This is this performs the inverse operation to `set_blocks`. + +# Arguments +- `matrix::Array`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `values::Array`: sub-block values. +- `block_idxs::BlkIdx`: atomic-blocks from which sub-blocks are to be gathered. +- `s_i::Int`: first shell +- `s_j::Int`: second shell +- `atoms::Atoms`: target system's `JuLIP.Atoms` objects +- `basis_def:BasisDef`: corresponding basis set definition object (`BasisDef`) + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +function set_blocks!(matrix::AbstractArray, values, block_idxs::BlkIdx, atoms::Atoms, basis_def) + + if size(values, 3) != size(block_idxs, 2) + throw(DimensionMismatch( + "The last dimensions of `values` & `block_idxs` must be of the same length.")) + end + + # Identify where each target block starts (first column and row) + starts = _block_starts(block_idxs, atoms, basis_def) + + # Carry out the scatter operation. + _set_blocks!(values, matrix, starts) +end + + +""" + locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms, basis_def) + +Collects sub-blocks from the supplied matrix that correspond to off-site interactions +between the `s_i`'th shell on species `z_1` and the `s_j`'th shell on species `z_2`. + +# Arguments +- `matrix`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `z_1`: 1ˢᵗ species (atomic number) +- `z_2`: 2ⁿᵈ species (atomic number) +- `s_i`: shell on 1ˢᵗ species +- `s_j`: shell on 2ⁿᵈ species +- `atoms`: target system's `JuLIP.Atoms` objects +- `basis_def`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an Nᵢ×Nⱼ×M array containing the collected sub-blocks; where Nᵢ & Nⱼ are + the number of orbitals on the `s_i`'th & `s_j`'th shells of species `z_1` & `z_2` + respectively, and M is the N∘ of sub-blocks found. +- `block_idxs`: A matrix specifying which atomic block each sub-block in `sub_blocks` + was taken from. If `matrix` is a 3D real space matrix then `block_idxs` will also + include the cell index. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; focus=nothing, no_reduce=false) = _locate_and_get_sub_blocks(matrix, z_1, z_2, s_i, s_j, atoms, basis_def; focus=focus, no_reduce=no_reduce) + +""" + locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms, basis_def) + +Collects sub-blocks from the supplied matrix that correspond to on-site interactions +between the `s_i`'th & `s_j`'th shells on species `z`. + +# Arguments +- `matrix`: matrix from which to draw. This may be in either the 3D real-space N×N×C form + or the single k-point N×N form; where N & C are the N∘ of orbitals & images respectively. +- `z_1`: target species (atomic number) +- `s_i`: 1ˢᵗ shell +- `s_j`: 2ⁿᵈ shell +- `atoms`: target system's `JuLIP.Atoms` objects +- `basis_def`: corresponding basis set definition object (`BasisDef`) + +# Returns +- `sub_blocks`: an Nᵢ×Nⱼ×M array containing the collected sub-blocks; where Nᵢ & Nⱼ are + the number of orbitals on the `s_i`'th & `s_j`'th shells of species `z_1` & `z_2` + respectively, and M is the N∘ of sub-blocks found. +- `block_idxs`: A matrix specifying which atomic block each sub-block in `sub_blocks` + was taken from. If `matrix` is a 3D real space matrix then `block_idxs` will also + include the cell index. + +# Notes +If `matrix` is supplied in its 3D real-space form then it is imperative to ensure that +the origin cell is first. +""" +locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing, kwargs...) = _locate_and_get_sub_blocks(matrix, z, s_i, s_j, atoms, basis_def; focus=focus) + +# Multiple dispatch is used to avoid the type instability in `locate_and_get_sub_blocks` +# associated with the creation of the `block_idxs` variable. It is also used to help +# distinguish between on-site and off-site collection operations. The following +# `_locate_and_get_sub_blocks` functions differ only in how they construct `block_idxs`. + +# Off site _locate_and_get_sub_blocks functions +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 2}, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; + focus=nothing, no_reduce=false) where T + block_idxs = atomic_block_idxs(z_1, z_2, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_off_site_idxs(block_idxs) + + # Duplicate blocks present when gathering off-site homo-atomic homo-orbital interactions + # must be purged. + if (z_1 == z_2) && (s_i == s_j) && !no_reduce + block_idxs = filter_upper_idxs(block_idxs) + end + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 3}, z_1, z_2, s_i, s_j, atoms::Atoms, basis_def; + focus=nothing, no_reduce=false) where T + block_idxs = atomic_block_idxs(z_1, z_2, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = repeat_atomic_block_idxs(block_idxs, size(matrix, 3)) + block_idxs = filter_off_site_idxs(block_idxs) + + ############################# Chen ########################## + if size(block_idxs, 2) == 0 + return zeros(basis_def[z_1][s_i] * 2 + 1, basis_def[z_2][s_j] * 2 + 1, 0), block_idxs + end + ############################################################# + + if (z_1 == z_2) && (s_i == s_j) && !no_reduce + block_idxs = filter_upper_idxs(block_idxs) + end + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +# On site _locate_and_get_sub_blocks functions +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 2}, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing) where T + block_idxs = atomic_block_idxs(z, z, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_on_site_idxs(block_idxs) + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + +function _locate_and_get_sub_blocks(matrix::AbstractArray{T, 3}, z, s_i, s_j, atoms::Atoms, basis_def; focus=nothing) where T + block_idxs = atomic_block_idxs(z, z, atoms.Z) + + if !isnothing(focus) + block_idxs = _filter_block_indices(block_idxs, focus) + end + + block_idxs = filter_on_site_idxs(block_idxs) + block_idxs = repeat_atomic_block_idxs(block_idxs, 1) + + return get_sub_blocks(matrix, block_idxs, s_i, s_j, atoms, basis_def), block_idxs +end + + + +# ╭─────────────────────┬─────────────────────────────────────────╮ +# │ Matrix Manipulation │ BlkIdx:Miscellaneous Internal Functions │ +# ╰─────────────────────┴─────────────────────────────────────────╯ + +# This function is tasked with constructing the boolean mask used to filter out atom-blocks +# associated with interactions between pairs of atoms separated by a distance greater than +# some specified cutoff. Note that this is intended for internal use only and is primarily +# used by the `filter_idxs_by_bond_distance` method. +function _distance_mask( + block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, + images::Union{Nothing, AbstractMatrix{<:Integer}}=nothing) + + if isnothing(images) + l = atoms.cell' + l_inv = pinv(l) + mask = Vector{Bool}(undef, size(block_idxs, 2)) + for i=1:size(block_idxs, 2) + mask[i] = norm(_wrap(atoms.X[block_idxs[2, i]] - atoms.X[block_idxs[1, i]], l, l_inv)) <= distance + end + else + shift_vectors = collect(eachrow(images' * atoms.cell)) + mask = norm.(atoms.X[block_idxs[2, :]] - atoms.X[block_idxs[1, :]] + shift_vectors[block_idxs[3, :]]) .<= distance + end + return mask +end + +# Internal method used exclusively by _distance_mask. +function _wrap(x_vec, l, l_inv) + x_vec_frac = l_inv * x_vec + return l * (x_vec_frac .- round.(x_vec_frac)) +end + + +end diff --git a/examples/H2O/python_interface/small/src/datastructs.jl b/examples/H2O/python_interface/small/src/datastructs.jl new file mode 100644 index 0000000..c3620e5 --- /dev/null +++ b/examples/H2O/python_interface/small/src/datastructs.jl @@ -0,0 +1,478 @@ +module DataSets +using ACEhamiltonians +using Random: shuffle +using LinearAlgebra: norm +using JuLIP: Atoms +using ACEhamiltonians.MatrixManipulation: BlkIdx, _distance_mask +using ACEhamiltonians.States: reflect, _get_states, _neighbours, _locate_minimum_image, _locate_target_image + +import ACEhamiltonians.Parameters: ison + +# ╔══════════╗ +# ║ DataSets ║ +# ╚══════════╝ + +# `AbstractFittingDataSet` based structures contain all data necessary to perform a fit. +abstract type AbstractFittingDataSet end + +export DataSet, filter_sparse, filter_bond_distance, get_dataset, AbstractFittingDataSet, random_split, random_sample, random_distance_sample + +""" + DataSet(values, blk_idxs, states) + +A structure for storing collections of sub-blocks & states representing the environments +from which they came. These are intended to be used during the model fitting process. +While the block index matrix is not strictly necessary to the fitting process it is useful +enough to merit inclusion. + +# Fields +- `values::AbstractArray`: an i×j×n matrix containing extracted the sub-block values; + where i & j are the number of orbitals associate with the two shells, and n the number + of sub-blocks. +- `blk_idxs::BlkIdx`: a block index matrix specifying from which block each sub-block in + `values` was taken from. This acts mostly as meta-data. +- `states`: states representing the atomic-block from which each sub-block was taken. + +# Notes +Some useful properties of `DataSet` instances have been highlighted below: +- Addition can be used to combine one or more datasets. +- Indexing can be used to take a sub-set of a dataset. +- `size` acts upon `DataSet.values`. +- `length` returns `size(DataSet.values, 3)`, i.e. the number of sub-blocks. +- The adjoint of a `DataSet` will return a copy where: + - `values` is the hermitian conjugate of its parent. + - atomic indices `blk_idxs` have been exchanged, i.e. rows 1 & 2 are swapped. + - `states` are reflected (where appropriate) . + +""" +struct DataSet{V<:AbstractArray{<:Any, 3}, B<:BlkIdx, S<:AbstractVector} <: AbstractFittingDataSet + values::V + blk_idxs::B + states::S +end + +# ╭──────────┬───────────────────────╮ +# │ DataSets │ General Functionality │ +# ╰──────────┴───────────────────────╯ + +function Base.show(io::IO, data_set::T) where T<:AbstractFittingDataSet + F = valtype(data_set.values) + mat_shape = join(size(data_set), '×') + print(io, "$(nameof(T)){$F}($mat_shape)") +end + +function Base.:(==)(x::T, y::T) where T<:DataSet + return x.blk_idxs == y.blk_idxs && x.values == y.values && x.states == y.states +end + +# Two or more AbstractFittingDataSet entities can be added together via the `+` operator +Base.:(+)(x::T, y::T) where T<:AbstractFittingDataSet = T( + (cat(getfield(x, fn), getfield(y, fn), dims=ndims(getfield(x, fn))) + for fn in fieldnames(T))...) + +# Allow AbstractFittingDataSet objects to be indexed so that a subset may be selected. This is +# mostly used when filtering data. +function Base.getindex(data_set::T, idx::UnitRange) where T<:AbstractFittingDataSet + return T((_getindex_helper(data_set, fn, idx) for fn in fieldnames(T))...) +end + +function Base.getindex(data_set::T, idx::Vararg{<:Integer}) where T<:AbstractFittingDataSet + return data_set[collect(idx)] +end + +function Base.getindex(data_set::T, idx) where T<:AbstractFittingDataSet + return T((_getindex_helper(data_set, fn, idx) for fn in fieldnames(T))...) +end + +function _getindex_helper(data_set, fn, idx) + # This abstraction helps to speed up calls to Base.getindex. + return let data = getfield(data_set, fn) + collect(selectdim(data, ndims(data), idx)) + end +end + +Base.lastindex(data_set::AbstractFittingDataSet) = length(data_set) +Base.length(data_set::AbstractFittingDataSet) = size(data_set, 3) +Base.size(data_set::AbstractFittingDataSet, dim::Integer) = size(data_set.values, dim) +Base.size(data_set::AbstractFittingDataSet) = size(data_set.values) + +""" +Return a copy of the provided `DataSet` in which i) the sub-blocks (i.e. the `values` +field) have set to their adjoint, ii) the atomic indices in the `blk_idxs` field have +been exchanged, and iii) all `BondStates` in the `states` field have been reflected. +""" +function Base.adjoint(data_set::T) where T<:DataSet + swapped_blk_idxs = copy(data_set.blk_idxs); _swaprows!(swapped_blk_idxs, 1, 2) + return T( + # Transpose and take the complex conjugate of the sub-blocks + conj(permutedims(data_set.values, (2, 1, 3))), + # Swap the atomic indices in `blk_idxs` i.e. atom block 1-2 is now block 2-1 + swapped_blk_idxs, + # Reflect bond states across the mid-point of the bond. + [reflect.(i) for i in data_set.states]) +end + +function _swaprows!(matrix::AbstractMatrix, i::Integer, j::Integer) + @inbounds for k = 1:size(matrix, 2) + matrix[i, k], matrix[j, k] = matrix[j, k], matrix[i, k] + end +end + + +""" + ison(dataset) + +Return a boolean indicating whether the `DataSet` entity contains on-site data. +""" +ison(x::T) where T<:DataSet = ison(x.states[1][1]) + + + +# ╭──────────┬─────────╮ +# │ DataSets │ Filters │ +# ╰──────────┴─────────╯ + + +""" + filter_sparse(dataset[, threshold=1E-8]) + +Filter out data-points with fully sparse sub-blocks. Only data-points with sub-blocks +containing at least one element whose absolute value is greater than the specified +threshold will be retained. + +# Arguments +- `dataset::AbstractFittingDataSet`: the dataset that is to be filtered. +- `threshold::AbstractFloat`: value below which an element will be considered to be + sparse. This will defaudfgflt to 1E-8 if omitted. + +# Returns +- `filtered_dataset`: a copy of the, now filtered, dataset. + +""" +function filter_sparse(dataset::AbstractFittingDataSet, threshold::AbstractFloat=1E-8) + return dataset[vec(any(abs.(dataset.values) .>= threshold, dims=(1,2)))] +end + + +""" +filter_bond_distance(dataset, distance) + +Filter out data-points whose bond-vectors exceed the supplied cutoff distance. This allows +for states that will not be used to be removed during the data selection process rather +than during evaluation. Note that this is only applicable to off-site datasets. + +# Arguments +- `dataset::AbstractFittingDataSet`: tje dataset that is to be filtered. +- `distance::AbstractFloat`: data-points with bond distances exceeding this value will be + filtered out. + +# Returns +- `filtered_dataset`: a copy of the, now filtered, dataset. + +""" +function filter_bond_distance(dataset::AbstractFittingDataSet, distance::AbstractFloat) + if length(dataset) != 0 # Don't try and check the state unless one exists. + # Throw an error if the user tries to apply the a bond filter to an on-site dataset + # where there is no bond to be filtered. + @assert !ison(dataset) "Only applicable to off-site datasets" + end + return dataset[[norm(i[1].rr0) <= distance for i in dataset.states]] +end + +""" + random_sample(dataset, n[; with_replacement]) + +Select a random subset of size `n` from the supplied `dataset`. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be sampled. +- `n::Integer`: number of sample points +- `with_replacement::Bool`: if true (default) then duplicate samples will not be drawn. + +# Returns +- `sample_dataset::AbstractFittingDataSet`: a randomly selected subset of `dataset` + of size `n`. +""" +function random_sample(dataset::AbstractFittingDataSet, n::Integer; with_replacement=true) + if with_replacement + @assert n ≤ length(dataset) "Sample size cannot exceed dataset size" + return dataset[shuffle(1:length(dataset))[1:n]] + else + return dataset[rand(1:length(dataset), n)] + end +end + +""" + random_split(dataset, x) + +Split the `dataset` into a pair of randomly selected subsets. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be partitioned. +- `x::AbstractFloat`: partition ratio; the fraction of samples to be + placed into the first subset. + +""" +function random_split(dataset::DataSet, x::AbstractFloat) + split_index = Int(round(length(dataset)x)) + idxs = shuffle(1:length(dataset)) + return dataset[idxs[1:split_index]], dataset[idxs[split_index + 1:end]] + +end + + + +""" + random_distance_sample(dataset, n[; with_replacement=true, rng=rand]) + +Select a random subset of size `n` from the supplied `dataset` via distances. + +This functions in a similar manner to `random_sample` but selects points based on their +bond length. This is intended to ensure a more even sample. + +# Arguments +- `dataset::AbstractFittingDataSet`: dataset to be sampled. +- `n::Integer`: number of sample points +- `with_replacement::Bool`: if true (default) then duplicate samples will not be drawn. +- `rng::Function`: function to generate random numbers. + +""" +function random_distance_sample(dataset, n; with_replacement=true, rng=rand) + + @assert length(dataset) ≥ n + # Construct an array storing the bond lengths of each state, sort it and generate + # the sort permutation array to allow elements in r̄ to be mapped back to their + # corresponding state in the dataset. + r̄ = [norm(i[1].rr0) for i in dataset.states] + r̄_perm = sortperm(r̄) + r̄[:] = r̄[r̄_perm] + m = length(dataset) + + # Work out the maximum & minimum bond distance as well as the range + r_min = minimum(r̄) + r_max = maximum(r̄) + r_range = r_max - r_min + + # Preallocate transient index storage array + selected_idxs = zeros(Int, n) + + for i=1:n + # Select a random distance r ∈ [min(r̄), max(r̄)] + r = rng() * r_range + r_min + + # Identify the first element of r̄ ≥ r and the last element ≤ r + idxs = searchsorted(r̄, r) + idx_i, idx_j = minmax(first(idxs), last(idxs)) + + # Expand the window by one each side, but don't exceed the array's bounds + idx_i = max(idx_i-1, 1) + idx_j = min(idx_j+1, m) + + # Identify which element is closest to r and add the associated index + # to the selected index array. + idx = last(findmin(j->abs(r-j), r̄[idx_i:idx_j])) + idx_i - 1 + + # If this state has already been selected then replace it with the next + # closest one. + + if with_replacement && r̄_perm[idx] ∈ selected_idxs + + # Identify the indices corresponding to the first states with longer and shorter + # bond lengths than the current, duplicate, state. + lb = max(idx-1, 1) + ub = min(idx+1, m) + + while lb >= 1 && r̄_perm[lb] ∈ selected_idxs + lb -= 1 + end + + while ub <= m && r̄_perm[ub] ∈ selected_idxs + ub += 1 + end + + # Select the closets valid state + new_idx = 0 + dx = Inf + + if lb != 0 && lb != idx + new_idx = lb + dx = abs(r̄[lb] - r) + end + + if ub != m+1 && (abs(r̄[ub] - r) < dx) && ub != idx + new_idx = ub + end + + idx = new_idx + + end + + + selected_idxs[i] = r̄_perm[idx] + end + + return dataset[selected_idxs] + +end + +# ╭──────────┬───────────╮ +# │ DataSets │ Factories │ +# ╰──────────┴───────────╯ +# This section will hold the factory methods responsible for automating the construction +# of `DataSet` entities. The `get_dataset` methods will be implemented once the `AHSubModel` +# structures have been implemented. + + +# This is just a reimplementation of `filter_idxs_by_bond_distance` that allows for `blocks` +# to get filtered as well +function _filter_bond_idxs(blocks, block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + let mask = _distance_mask(block_idxs::BlkIdx, distance::AbstractFloat, atoms::Atoms, images) + return blocks[:, :, mask], block_idxs[:, mask] + end +end + + + +""" +# Todo + - This could be made more performant. +""" +function _filter_sparse(values, block_idxs, tolerance) + mask = vec(any(abs.(values) .>= tolerance, dims=(1,2))) + return values[:, :, mask], block_idxs[:, mask] +end + + +""" +get_dataset(matrix, atoms, submodel, basis_def[, images; tolerance, filter_bonds, focus]) + +Construct and return a `DataSet` entity containing the minimal data required to fit a +`AHSubModel` entity. + +# Arguments +- `matrix`: matrix from which sub-blocks are to be gathered. +- `atoms`: atoms object representing the system to which the matrix pertains. +- `submodel`: `AHSubModel` entity for the desired sub-block; the `id` field is used to identify + which sub-blocks should be gathered and how they should be gathered. +- `basis_def`: a basis definition specifying what orbitals are present on each species. +- `images`: cell translation vectors associated with the matrix, this is only required + when the `matrix` is in the three-dimensional real-space form. + +# Keyword Arguments +- `tolerance`: specifying a float value will enact sparse culling in which only sub-blocks + with at least one element greater than the permitted tolerance will be included. This + is used to remove all zero, or near zero, sub-blocks. This is disabled by default. +- `filter_bonds`: if set to `true` then only interactions within the permitted cutoff + distance will be returned. This is only valid off-site interactions and is disabled + by default. The cut-off distance is extracted from the bond envelope contained within + the `basis` object. This defaults to `true` for off-site interactions. +- `focus`: the `focus` argument allows the `get_dataset` call to return only a sub-set + of possible data-points. If a vector of atomic indices is provided then only on/off- + site sub-blocks for/between those atoms will be returned; i.e. [1, 2] would return + on-sites 1-on, 2-on and off-sites 1-1-off, 1-2-off, 2-1-off, & 2-2-off. If a matrix + is provided, like so [1 2; 3 4] then only the associated off-site sub-blocks will be + returned, i.e. 1-2-off and 3-4-off. Note that the matrix form is only valid when + retrieving off-site sub-blocks. +- `no_reduce`: by default symmetrically redundant sub-blocks will not be gathered; this + equivalent blocks from be extracted from the upper and lower triangles of the Hamiltonian + and overlap matrices. This will default to `false`, however it is sometimes useful to + disable this when debugging. + +# Todo: + - Warn that only the upper triangle is returned and discuss how this effects "focus". + +""" +function get_dataset( + matrix::AbstractArray, atoms::Atoms, submodel::AHSubModel, basis_def, + images::Union{Matrix, Nothing}=nothing; + tolerance::Union{Nothing, <:AbstractFloat}=nothing, filter_bonds::Bool=true, + focus::Union{Vector{<:Integer}, Matrix{<:Integer}, Nothing}=nothing, + no_reduce=false) + + if ndims(matrix) == 3 && isnothing(images) + throw("`images` must be provided when provided with a real space `matrix`.") + end + + # Locate and gather the sub-blocks correspond the interaction associated with `basis` + blocks, block_idxs = locate_and_get_sub_blocks(matrix, submodel.id..., atoms, basis_def; focus=focus, no_reduce=no_reduce) + + if !isnothing(focus) + mask = ∈(focus).(block_idxs[1, :]) .& ∈(focus).(block_idxs[2, :]) + block_idxs = block_idxs[:, mask] + blocks = blocks[:, :, mask] + end + + # If gathering off-site data and `filter_bonds` is `true` then remove data-points + # associated with interactions between atom pairs whose bond-distance exceeds the + # cutoff as specified by the bond envelope. This prevents having to construct states + # (which is an expensive process) for interactions which will just be deleted later + # on. Enabling this can save a non-trivial amount of time and memory. + if !ison(submodel) && filter_bonds + blocks, block_idxs = _filter_bond_idxs( + blocks, block_idxs, envelope(submodel).r0cut, atoms, images) + end + + if !isnothing(tolerance) # Filter out sparse sub-blocks; but only if instructed to + blocks, block_idxs = _filter_sparse(blocks, block_idxs, tolerance) + end + + # Construct states for each of the sub-blocks. + if ison(submodel) + # For on-site states the cutoff radius is provided; this results in redundant + # information being culled here rather than later on; thus saving on memory. + states = _get_states(block_idxs, atoms; r=radial(submodel).R.ru) + else + + ########################## Chen ########################## + if size(block_idxs, 2) == 0 + states = zeros(0) + else + states = _get_states(block_idxs, atoms, envelope(submodel), images) + end + ########################################################## + + # # For off-site states the basis' bond envelope must be provided. + # states = _get_states(block_idxs, atoms, envelope(submodel), images) + end + + # Construct and return the requested DataSet object + dataset = DataSet(blocks, block_idxs, states) + + return dataset +end + + + +""" +Construct a collection of `DataSet` instances storing the information required to fit +their associated `AHSubModel` entities. This convenience function will call the original +`get_dataset` method for each and every basis in the supplied model and return a +dictionary storing once dataset for each basis in the model. + +""" +function get_dataset( + matrix::AbstractArray, atoms::Atoms, model::Model, + images::Union{Matrix, Nothing}=nothing; kwargs...) + + basis_def = model.basis_definition + on_site_data = Dict( + basis.id => get_dataset(matrix, atoms, basis, basis_def, images; kwargs...) + for basis in values(model.on_site_submodels)) + + off_site_data = Dict( + basis.id => get_dataset(matrix, atoms, basis, basis_def, images; kwargs...) + for basis in values(model.off_site_submodels)) + + return on_site_data, off_site_data +end + + + +end + +# Notes +# - The matrix and array versions of `get_dataset` could easily be combined. +# - The `get_dataset` method is likely to suffer from type instability issues as it is +# unlikely that Julia will know ahead of time whether the `DataSet` structure returned +# will contain on or off-states states; each having different associated structures. +# Thus type ambiguities in the `AHSubModel` structures should be alleviated. diff --git a/examples/H2O/python_interface/small/src/fitting.jl b/examples/H2O/python_interface/small/src/fitting.jl new file mode 100644 index 0000000..65a48fe --- /dev/null +++ b/examples/H2O/python_interface/small/src/fitting.jl @@ -0,0 +1,421 @@ +module Fitting +using HDF5, ACE, ACEbase, ACEhamiltonians, StaticArrays, Statistics, LinearAlgebra, SparseArrays, IterativeSolvers +using ACEfit: linear_solve, SKLEARN_ARD, SKLEARN_BRR +using HDF5: Group +using JuLIP: Atoms +using ACE: ACEConfig, evaluate, scaling, AbstractState, SymmetricBasis +using ACEhamiltonians.Common: number_of_orbitals +using ACEhamiltonians.Bases: envelope +using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma +using ACEatoms:AtomicNumber +using LowRankApprox: pqrfact + +using ACEhamiltonians: DUAL_BASIS_MODEL + + +export fit! + +# set abs(a::AtomicNumber) = 0 as it is called in the `scaling` function but should not change the output +Base.abs(a::AtomicNumber) = 0 # a.z + +# Once the bond inversion issue has been resolved the the redundant models will no longer +# be required. The changes needed to be made in this file to remove the redundant model +# are as follows: +# - Remove inverted state condition in single model `fit!` method. +# - `_assemble_ls` should take `AHSubModel` entities. +# - Remove inverted state condition from the various `predict` methods. + +# Todo: +# - Need to make sure that the acquire_B! function used by ACE does not actually modify the +# basis function. Otherwise there may be some issues with sharing basis functions. +# - ACE should be modified so that `valtype` inherits from Base. This way there should be +# no errors caused when importing it. +# - Remove hard coded matrix type from the predict function. + +# The _assemble_ls and _evaluate_real methods should be rewritten to use PseudoBlockArrays +# this will prevent redundant allocations from being made and will mean that _preprocessA +# and _preprocessY can also be removed. This should help speed up the code and should +# significantly improve memory usage. +function _preprocessA(A) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + # S1: number of sites; S2: number of basis, SS1: 2L1+1, SS2: 2L2+1 + S1,S2 = size(A) + SS1,SS2 = size(A[1]) + A_temp = zeros(ComplexF64, S1*SS1*SS2, S2) + for i = 1:S1, j = 1:S2 + A_temp[SS1*SS2*(i-1)+1:SS1*SS2*i,j] = reshape(A[i,j],SS1*SS2,1) + end + return real(A_temp) +end + +function _preprocessY(Y) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + Len = length(Y) + SS1,SS2 = size(Y[1]) + Y_temp = zeros(ComplexF64,Len*SS1*SS2) + for i = 1:Len + Y_temp[SS1*SS2*(i-1)+1:SS1*SS2*i] = reshape(Y[i],SS1*SS2,1) + end + return real(Y_temp) +end + + +function solve_ls(A, Y, λ, Γ, solver = "LSQR"; niter = 10, inner_tol = 1e-3) + # Note; this function was copied over from the original ACEhamiltonians/fit.jl file. + + A = _preprocessA(A) + Y = _preprocessY(Y) + + num = size(A)[2] + A = [A; λ*Γ] + Y = [Y; zeros(num)] + if solver == "QR" + return real(qr(A) \ Y) + elseif solver == "LSQR" + # The use of distributed arrays is still causing a memory leak. As such the following + # code has been disabled until further notice. + # Ad, Yd = distribute(A), distribute(Y) + # res = real(IterativeSolvers.lsqr(Ad, Yd; atol = 1e-6, btol = 1e-6)) + # close(Ad), close(Yd) + res = real(IterativeSolvers.lsqr(A, Y; atol = 1e-6, btol = 1e-6)) + return res + elseif solver == "ARD" + return linear_solve(SKLEARN_ARD(;n_iter = niter, tol = inner_tol), A, Y)["C"] + elseif solver == "BRR" + return linear_solve(SKLEARN_BRR(;n_iter = niter, tol = inner_tol), A, Y)["C"] + elseif solver == "RRQR" + AP = A / I + θP = pqrfact(A, rtol = inner_tol) \ Y + return I \ θP + elseif solver == "NaiveSolver" + return real((A'*A) \ (A'*Y)) + end + + end + + +function _ctran(l::Int64,m::Int64,μ::Int64) + if abs(m) ≠ abs(μ) + return 0 + elseif abs(m) == 0 + return 1 + elseif m > 0 && μ > 0 + return 1/sqrt(2) + elseif m > 0 && μ < 0 + return (-1)^m/sqrt(2) + elseif m < 0 && μ > 0 + return - im * (-1)^m/sqrt(2) + else + return im/sqrt(2) + end +end + +_ctran(l::Int64) = sparse(Matrix{ComplexF64}([ _ctran(l,m,μ) for m = -l:l, μ = -l:l ])) + +function _evaluate_real(Aval) + L1,L2 = size(Aval[1]) + L1 = Int((L1-1)/2) + L2 = Int((L2-1)/2) + C1 = _ctran(L1) + C2 = _ctran(L2) + return real([ C1 * Aval[i].val * C2' for i = 1:length(Aval)]) +end + +""" +""" +function _assemble_ls(basis::SymmetricBasis, data::T, enable_mean::Bool=false) where T<:AbstractFittingDataSet + # This will be rewritten once the other code has been refactored. + + # Should `A` not be constructed using `acquire_B!`? + + n₁, n₂, n₃ = size(data) + # Currently the code desires "A" to be an X×Y matrix of Nᵢ×Nⱼ matrices, where X is + # the number of sub-block samples, Y is equal to `size(bos.basis.A2Bmap)[1]`, and + # Nᵢ×Nⱼ is the sub-block shape; i.e. 3×3 for pp interactions. This may be refactored + # at a later data if this layout is not found to be strictly necessary. + cfg = ACEConfig.(data.states) + Aval = evaluate.(Ref(basis), cfg) + A = permutedims(reduce(hcat, _evaluate_real.(Aval)), (2, 1)) + + Y = [data.values[:, :, i] for i in 1:n₃] + + # Calculate the mean value x̄ + if enable_mean && n₁ ≡ n₂ && ison(data) + x̄ = mean(diag(mean(Y)))*I(n₁) + else + x̄ = zeros(n₁, n₂) + end + + Y .-= Ref(x̄) + return A, Y, x̄ + +end + + +################### +# Fitting Methods # +################### + +""" + fit!(submodel, data;[ enable_mean]) + +Fits a specified model with the supplied data. + +# Arguments +- `submodel`: a specified submodel that is to be fitted. +- `data`: data that the basis is to be fitted to. +- `enable_mean::Bool`: setting this flag to true enables a non-zero mean to be + used. +- `λ::AbstractFloat`: regularisation term to be used (default=1E-7). +- `solver::String`: solver to be used (default="LSQR") +""" +function fit!(submodel::T₁, data::T₂; enable_mean::Bool=false, λ=1E-7, solver="LSQR") where {T₁<:AHSubModel, T₂<:AbstractFittingDataSet} + + # Get the basis function's scaling factor + Γ = Diagonal(scaling(submodel.basis, 2)) + + # Setup the least squares problem + Φ, Y, x̄ = _assemble_ls(submodel.basis, data, enable_mean) + + # Assign the mean value to the basis set + submodel.mean .= x̄ + + # Solve the least squares problem and get the coefficients + + submodel.coefficients .= collect(solve_ls(Φ, Y, λ, Γ, solver)) + + @static if DUAL_BASIS_MODEL + if T₁<:AnisoSubModel + Γ = Diagonal(scaling(submodel.basis_i, 2)) + Φ, Y, x̄ = _assemble_ls(submodel.basis_i, data', enable_mean) + submodel.mean_i .= x̄ + submodel.coefficients_i .= collect(solve_ls(Φ, Y, λ, Γ, solver)) + end + end + + nothing +end + + +# Convenience function for appending data to a dictionary +function _append_data!(dict, key, value) + if haskey(dict, key) + dict[key] = dict[key] + value + else + dict[key] = value + end +end + + + +""" + fit!(model, systems;[ on_site_filter, off_site_filter, tolerance, recentre, refit, target]) + +Fits a specified model to the supplied data. + +# Arguments +- `model::Model`: Model to be fitted. +- `systems::Vector{Group}`: HDF5 groups storing data with which the model should + be fitted. +- `on_site_filter::Function`: the on-site `DataSet` entities will be passed through this + filter function prior to fitting; defaults `identity`. +- `off_site_filter::Function`: the off-site `DataSet` entities will be passed through this + filter function prior to fitting; defaults `identity`. +- `tolerance::AbstractFloat`: only sub-blocks where at least one value is greater than + or equal to `tolerance` will be fitted. This argument permits sparse blocks to be + ignored. +- `recentre::Bool`: Enabling this will re-wrap atomic coordinates to be consistent with + the geometry layout used internally by FHI-aims. This should be used whenever loading + real-space matrices generated by FHI-aims. +- `refit::Bool`: By default already fitted bases will not be refitted, but this behaviour + can be suppressed by setting `refit=true`. +- `target::String`: a string indicating which matrix should be fitted. This may be either + `H` or `S`. If unspecified then the model's `.label` field will be read and used. +""" +function fit!( + model::Model, systems::Vector{Group}; + on_site_filter::Function = identity, + off_site_filter::Function = identity, + tolerance::Union{F, Nothing}=nothing, + recentre::Bool=false, + target::Union{String, Nothing}=nothing, + refit::Bool=false, solver = "LSQR") where F<:AbstractFloat + + # Todo: + # - Add fitting parameters options which uses a `Params` instance to define fitting + # specific parameters such as regularisation, solver method, whether or not mean is + # used when fitting, and so on. + # - Modify so that redundant data is not extracted; i.e. both A[0,0,0] -> A[1,0,0] and + # A[0,0,0] -> A[-1,0,0] + # - The approach currently taken limits io overhead by reducing redundant operations. + # However, this will likely use considerably more memory. + + # Section 1: Gather the data + + # If no target has been specified; then default to that given by the model's label. + + target = isnothing(target) ? model.label : target + + get_matrix = Dict( # Select an appropriate function to load the target matrix + "H"=>load_hamiltonian, "S"=>load_overlap, + "Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma)[target] + + fitting_data = Dict{Any, DataSet}() + + # Loop over the specified systems + for system in systems + + # Load the required data from the database entry + matrix, atoms = get_matrix(system), load_atoms(system; recentre=recentre) + images = ndims(matrix) == 2 ? nothing : load_cell_translations(system) + + # Loop over the on site bases and collect the appropriate data + for basis in values(model.on_site_submodels) + data_set = get_dataset(matrix, atoms, basis, model.basis_definition, images; + tolerance=tolerance) + + # Don't bother filtering and adding empty datasets + if length(data_set) != 0 + # Apply the on-site data filter function + data_set = on_site_filter(data_set) + # Add the selected data to the fitting data-set. + _append_data!(fitting_data, basis.id, data_set) + end + end + + # Repeat for the off-site models + for basis in values(model.off_site_submodels) + data_set = get_dataset(matrix, atoms, basis, model.basis_definition, images; + tolerance=tolerance, filter_bonds=true) + + if length(data_set) != 0 + data_set = off_site_filter(data_set) + _append_data!(fitting_data, basis.id, data_set) + end + + end + end + + # Fit the on/off-site models + fit!(model, fitting_data; refit=refit, solver = solver) + +end + + +""" + fit!(model, fitting_data[; refit]) + + +Fit the specified model using the provided data. + +# Arguments +- `model::Model`: the model that should be fitted. +- `fitting_data`: dictionary providing the data to which the supplied model should be + fitted. This should hold one entry for each submodel that is to be fitted and should take + the form `{SubModel.id, DataSet}`. +- `refit::Bool`: By default, already fitted bases will not be refitted, but this behaviour + can be suppressed by setting `refit=true`. +""" +function fit!( + model::Model, fitting_data; refit::Bool=false, solver="LSQR") + + @debug "Fitting off site bases:" + for (id, basis) in model.off_site_submodels + if !haskey(fitting_data, id) + @debug "Skipping $(id): no fitting data provided" + elseif is_fitted(basis) && !refit + @debug "Skipping $(id): submodel already fitted" + elseif length(fitting_data) ≡ 0 + @debug "Skipping $(id): fitting dataset is empty" + else + @debug "Fitting $(id): using $(length(fitting_data[id])) fitting points" + fit!(basis, fitting_data[id]; solver = solver) + end + end + + @debug "Fitting on site bases:" + for (id, basis) in model.on_site_submodels + if !haskey(fitting_data, id) + @debug "Skipping $(id): no fitting data provided" + elseif is_fitted(basis) && !refit + @debug "Skipping $(id): submodel already fitted" + elseif length(fitting_data) ≡ 0 + @debug "Skipping $(id): fitting dataset is empty" + else + @debug "Fitting $(id): using $(length(fitting_data[id])) fitting points" + fit!(basis, fitting_data[id]; enable_mean=ison(basis), solver = solver) + end + end +end + + +# The following code was added to `fitting.jl` to allow data to be fitted on databases +# structured using the original database format. +using ACEhamiltonians.DatabaseIO: _load_old_atoms, _load_old_hamiltonian, _load_old_overlap +using Serialization + +function old_fit!( + model::Model, systems, target::Symbol; + tolerance::F=0.0, filter_bonds::Bool=true, recentre::Bool=false, + refit::Bool=false) where F<:AbstractFloat + + # Todo: + # - Check that the relevant data exists before trying to extract it; i.e. don't bother + # trying to gather carbon on-site data from an H2 system. + # - Currently the basis set definition is loaded from the first system under the + # assumption that it is constant across all systems. However, this will break down + # if different species are present in each system. + # - The approach currently taken limits io overhead by reducing redundant operations. + # However, this will likely use considerably more memory. + + # Section 1: Gather the data + + get_matrix = Dict( # Select an appropriate function to load the target matrix + :H=>_load_old_hamiltonian, :S=>_load_old_overlap)[target] + + fitting_data = IdDict{AHSubModel, DataSet}() + + # Loop over the specified systems + for (database_path, index_data) in systems + + # Load the required data from the database entry + matrix, atoms = get_matrix(database_path), _load_old_atoms(database_path) + + println("Loading: $database_path") + + # Loop over the on site bases and collect the appropriate data + if haskey(index_data, "atomic_indices") + println("Gathering on-site data:") + for basis in values(model.on_site_submodels) + println("\t- $basis") + data_set = get_dataset( + matrix, atoms, basis, model.basis_definition; + tolerance=tolerance, focus=index_data["atomic_indices"]) + _append_data!(fitting_data, basis, data_set) + + end + println("Finished gathering on-site data") + end + + # Repeat for the off-site models + if haskey(index_data, "atom_block_indices") + println("Gathering off-site data:") + for basis in values(model.off_site_submodels) + println("\t- $basis") + data_set = get_dataset( + matrix, atoms, basis, model.basis_definition; + tolerance=tolerance, filter_bonds=filter_bonds, focus=index_data["atom_block_indices"]) + _append_data!(fitting_data, basis, data_set) + end + println("Finished gathering off-site data") + end + end + + # Fit the on/off-site models + fit!(model, fitting_data; refit=refit) + +end + +end diff --git a/examples/H2O/python_interface/small/src/io.jl b/examples/H2O/python_interface/small/src/io.jl new file mode 100644 index 0000000..51fc9f7 --- /dev/null +++ b/examples/H2O/python_interface/small/src/io.jl @@ -0,0 +1,385 @@ +# All general io functionality should be placed in this file. With the exception of the +# `read_dict` and `write_dict` methods. + + + +""" +This module contains a series of functions that are intended to aid in the loading of data +from HDF5 structured databases. The code within this module is primarily intended to be +used only during the fitting of new models. Therefore, i) only loading methods are +currently supported, ii) the load target for each function is always the top level group +for each system. + +A brief outline of the expected HDF5 database structure is provided below. Note that +arrays **must** be stored in column major format! + +Database +├─System-1 +│ ├─Structure +│ │ ├─atomic_numbers: +│ │ │ > A vector of integers specifying the atomic number of each atom present in the +│ │ │ > target system. Read by the function `load_atoms`. +│ │ │ +│ │ ├─positions: +│ │ │ > A 3×N matrix, were N is the number of atoms present in the system, specifying +│ │ │ > cartesian coordinates for each atom. Read by the function `load_atoms`. +│ │ │ +│ │ ├─lattice: +│ │ │ > A 3×3 matrix specifying the lattice systems lattice vector in column-wise +│ │ │ > format; i.e. columns loop over vectors. Read by the function `load_atoms` +│ │ │ > if and when present. +│ │ └─pbc: +│ │ > A boolean, or a vector of booleans, indicating if, or along which dimensions, +│ │ > periodic conditions are enforced. This is only read when the lattice is given. +│ │ > This defaults to true for non-molecular/cluster cases. Read by `load_atoms`. +│ │ +│ ├─Info +│ │ ├─Basis: +│ │ │ > This group contains one dataset for each species that specifies the +│ │ │ > azimuthal quantum numbers of each shell present on that species. Read +│ │ │ > by the `load_basis_set_definition` method. +│ │ │ +│ │ │ +│ │ ├─Translations: +│ │ │ > A 3×N matrix specifying the cell translation vectors associated with the real +│ │ │ > space Hamiltonian & overlap matrices. Only present when Hamiltonian & overlap +│ │ │ > matrices are given in their M×M×N real from. Read by `load_cell_translations`. +│ │ │ > Should be integers specifying the cell indices, rather than cartesian vectors. +│ │ │ > The origin cell, [0, 0, 0], must always be first! +│ │ │ +│ │ └─k-points: +│ │ > A 4×N matrix where N is the number of k-points. The first three rows specify +│ │ > k-points themselves with the final row specifying their associated weights. +│ │ > Read by `load_k_points_and_weights`. Only present for multi-k-point calculations. +│ │ +│ └─Data +│ ├─H: +│ │ > The Hamiltonian. Either an M×M matrix or an M×M×N real-space tensor; where M is +│ │ > is the number of orbitals and N then number of primitive cell equivalents. Read +│ │ > in by the `load_hamiltonian` function. +│ │ +│ ├─S: +│ │ > The Overlap matrix; identical in format to the Hamiltonian matrix. This is read +│ │ > in by the `load_overlap` function. +│ │ +│ ├─total_energy: +│ │ > A single float value specifying the total system energy. +│ │ +│ ├─fermi_level: +│ │ > A single float value specifying the total fermi level. +│ │ +│ ├─forces: +│ │ > A 3×N matrix specifying the force vectors on each atom. +│ │ +│ ├─H_gamma: +│ │ > This can be used to store the gamma point only Hamiltonian matrix when 'H' is +│ │ > used to store the real space matrix. This is mostly for debugging & testing. +│ │ +│ └─S_gamma: +| > Overlap equivalent of `H_gamma` +│ +├─System-2 +│ ├─Structure +│ │ └─ ... +│ │ +│ ├─Info +│ │ └─ ... +│ │ +│ └─Data +│ └─ ... +... +└─System-n + └─ ... + + +Datasets and groups should provide information about what units the data they contain are +given in. This can be done through the of the HDF5 metadata `attributes`. When calling +the various load methods within `DatabaseIO` the `src` Group argument must always point +to the target systems top level Group. In the example structure tree given above these +would be 'System-1', 'System-2', and 'System-n'. +""" +module DatabaseIO +using ACEhamiltonians +using HDF5: Group, h5open +using JuLIP: Atoms +using HDF5 # ← Can be removed once support for old HDF5 formats is dropped +using LinearAlgebra: pinv + +# Developers Notes: +# - The functions within this module are mostly just convenience wrappers for the HDF5 +# `read` method and as such very little logic is contained within. Thus no unit tests +# are provided for this module at this time. However, this will change when and if +# unit and write functionality are added. +# +# Todo: +# - A version number flag should be added to each group when written to prevent the +# breaking compatibility with existing databases every time an update is made. +# - The unit information provided in the HDF5 databases should be made use of once +# a grand consensus as to what internal units should be used. + +export load_atoms, load_hamiltonian, load_overlap, gamma_only, load_k_points_and_weights, load_cell_translations, load_basis_set_definition, load_density_of_states, load_fermi_level + + +# Booleans stored by python are interpreted as Int8 by Julia rather than as booleans. Thus +# a cleaner is required. +_clean_bool(bool::I) where I<:Integer = Bool(bool) +_clean_bool(bool::Vector{<:Integer}) = convert(Vector{Bool}, bool) +_clean_bool(bool) = bool + + +function _recentre!(x, l, l_inv) + x[:] = l_inv' * x .- 1E-8 + x[:] = l' * (x - round.(x) .+ 1E-8) + nothing +end + +""" + load_atoms(src) + +Instantiate a `JuLIP.Atoms` object from an HDF5 `Group`. + + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose `Atoms` object is to be + returned. +- `recentre:Bool`: By default, atoms are assumed to span the fractional coordinate domain + [0.0, 1.0). Setting `recentre` to `true` will remap atomic positions to the fractional + coordinate domain of [-0.5, 0.5). This is primarily used when interacting with real-space + matrices produced by the FHI-aims code base. + +# Returns +- `atoms::Atoms`: atoms object representing the structure of the target system. + +""" +function load_atoms(src::Group; recentre=false) + # Developers Notes: + # - Currently non-molecular/cluster systems are assumed to be fully periodic + # along each axis if no `pbc` condition is explicitly specified. + # Todo: + # - Use unit information provided by the "positions" & "lattice" datasets. + + # All system specific structural data should be contained within the "Structure" + # sub-group. Extract the group to a variable for ease of access. + src = src["Structure"] + + # Species and positions are always present to read them in + species, positions = read(src["atomic_numbers"]), read(src["positions"]) + + if haskey(src, "lattice") # If periodic + l = collect(read(src["lattice"])') + if recentre + l_inv = pinv(l) + for x in eachcol(positions) + _recentre!(x, l, l_inv) + end + end + + pbc = haskey(src, "pbc") ? _clean_bool(read(src["pbc"])) : true + return Atoms(; Z=species, X=positions, cell=l, pbc=pbc) + else # If molecular/cluster + return Atoms(; Z=species, X=positions) + end +end + + +""" + load_basis_set_definition(src) + +Load the basis definition of the target system. + +This returns a `BasisDef` dictionary which specifies the azimuthal quantum number of each +shell for each species. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose basis set definition is + to be read. + +# Returns +- `basis_def::BasisDef`: a dictionary keyed by species and valued by a vector specifying + the azimuthal quantum number for each shell on said species. + +""" +function load_basis_set_definition(src::Group) + # Basis set definition is stored in "/Info/Basis" relative to the system's top level + # group. + src = src["Info/Basis"] + # Basis set definition is stored as a series of vector datasets with names which + # correspond the associated atomic number. + return BasisDef{Int}(parse(Int, k) => read(v)[2, :] for (k, v) in zip(keys(src), src)) +end + +""" + load_k_points_and_weights(src) + +Parse k-points and their weights. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose k-points & k-weights + are to be returned. + +# Returns +- `k_points::Matrix`: a 3×n matrix where n is the number of k-points. +- `k_weights::Vector`: a vector with a weight for each k-point. + +# Warnings +This will error out for gamma-point only calculations. + +""" +function load_k_points_and_weights(src::Group) + # Read in the k-points and weights from the "/Info/k-points" matrix. + knw = read(src["Info/k-points"]) + return knw[1:3, :], knw[4, :] +end + + +""" + load_cell_translations(src) + +Load the cell translation vectors associated with the real Hamiltonian & overlap matrices. +Relevant when Hamiltonian & overlap matrices are stored in their real N×N×M form, where N +is the number of orbitals per primitive cell and M is the number of cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose cell translation vectors + are to be returned. + +# Returns +- `T::Matrix`: a 3×M matrix where M is the number of primitive cell equivalents. + +# Notes +There is one translation vector for each translated cell; i.e. if the Hamiltonian matrix +is N×N×M then there will be M cell translation vectors. Here, the first translation +vector is always that of the origin cell, i.e. [0, 0, 0]. + +# Warnings +This will error out for gamma point only calculations or datasets in which the real +matrices are not stored. +""" +load_cell_translations(src::Group) = read(src["Info/Translations"]) + +""" + load_hamiltonian(src) + +Load the Hamiltonian matrix stored for the target system. This may be either an N×N single +k-point (commonly the gamma point) matrix, or an N×N×M real space matrix; where N is the +number or orbitals and M the number of unit cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose Hamiltonian matrix is + to be returned. + +# Returns +- `H::Array`: Hamiltonian matrix. This may be either an N×N matrix, as per the single + k-point case, or an N×N×M array for the real space case. +""" +load_hamiltonian(src::Group) = read(src["Data/H"]) +# Todo: add unit conversion to `load_hamiltonian` + +""" + load_overlap(src) + +Load the overlap matrix stored for the target system. This may be either an N×N single +k-point (commonly the gamma point) matrix, or an N×N×M real space matrix; where N is the +number or orbitals and M the number of unit cell equivalents. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system whose overlap matrix is to be + returned. + +# Returns +- `H::Array`: overlap matrix. This may be either an N×N matrix, as per the single + k-point case, or an N×N×M array for the real space case. +""" +load_overlap(src::Group) = read(src["Data/S"]) +# Todo: add unit conversion to `load_overlap` + + +""" + gamma_only(src) + +Returns `true` if the stored Hamiltonian & overlap matrices are for a single k-point only. +Useful for determining whether or not one should attempt to read cell translations or +k-points, etc. +""" +gamma_only(src::Group) = !haskey(src, "Info/Translations") + +# Get the gamma point only matrix; these are for debugging and are will be removed later. +load_hamiltonian_gamma(src::Group) = read(src["Data/H_gamma"]) +# Todo: add unit conversion to `load_hamiltonian_gamma` +load_overlap_gamma(src::Group) = read(src["Data/S_gamma"]) +# Todo: add unit conversion to `load_overlap_gamma` + +""" + load_density_of_states(src) + +Load the density of states associated with the target system. + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system for which the density of + states is to be returned. + +# Returns +- `values::Vector`: density of states. +- `energies::Vector`: energies at which densities of states were evaluated relative to + the fermi-level. +- `broadening::AbstractFloat`: broadening factor used by the smearing function. + +""" +function load_density_of_states(src::Group) + # Todo: + # - This currently returns units of eV for energy and 1/(eV.V_unit_cell) for DoS. + return ( + read(src, "Data/DoS/values"), + read(src, "Data/DoS/energies"), + read(src, "Data/DoS/broadening")) +end + + +""" + load_fermi_level(src) + +Load the calculated Fermi level (chemical potential). + +# Arguments +- `src::Group`: top level HDF5 `Group` of the target system for which the fermi level is + to be returned. + +# Returns +- `fermi_level::AbstractFloat`: the fermi level. +""" +function load_fermi_level(src) + # Todo: + # - This really should make use of unit attribute that is provided. + return read(src, "Data/fermi_level") +end + + +# These functions exist to support backwards compatibility with previous database structures. +# They are not intended to be called by general users as they will eventually be excised. +function _load_old_hamiltonian(path::String) + return h5open(path) do database + read(database, "aitb/H")[:, :] + end +end + +function _load_old_overlap(path::String) + return h5open(path) do database + read(database, "aitb/S")[:, :] + end +end + +function _load_old_atoms(path::String; groupname=nothing) + h5open(path, "r") do fd + groupname === nothing && (groupname = HDF5.name(first(fd))) + positions = HDF5.read(fd, string(groupname,"/positions")) + unitcell = HDF5.read(fd, string(groupname,"/unitcell")) + species = HDF5.read(fd, string(groupname,"/species")) + atoms = Atoms(; X = positions, Z = species, + cell = unitcell, + pbc = [true, true, true]) + return atoms + end + +end + +end diff --git a/examples/H2O/python_interface/small/src/models.jl b/examples/H2O/python_interface/small/src/models.jl new file mode 100644 index 0000000..9ba500f --- /dev/null +++ b/examples/H2O/python_interface/small/src/models.jl @@ -0,0 +1,241 @@ +module Models + +using ACEhamiltonians, ACE, ACEbase +import ACEbase: read_dict, write_dict +using ACEhamiltonians.Parameters: OnSiteParaSet, OffSiteParaSet +using ACEhamiltonians.Bases: AHSubModel, is_fitted +using ACEhamiltonians: DUAL_BASIS_MODEL +# Once we change the keys of basis_def from Integer to AtomicNumber, we will no +# longer need JuLIP here +using JuLIP + + +export Model + + +# ╔═══════╗ +# ║ Model ║ +# ╚═══════╝ + + +# Todo: +# - On-site and off-site components should be optional. +# - Document +# - Clean up +struct Model + + on_site_submodels + off_site_submodels + on_site_parameters + off_site_parameters + basis_definition + + label::String + + meta_data::Dict{String, Any} + + function Model( + on_site_submodels, off_site_submodels, on_site_parameters::OnSiteParaSet, + off_site_parameters::OffSiteParaSet, basis_definition, label::String, + meta_data::Union{Dict, Nothing}=nothing) + + # If no meta-data is supplied then just default to a blank dictionary + meta_data = isnothing(meta_data) ? Dict{String, Any}() : meta_data + + new(on_site_submodels, off_site_submodels, on_site_parameters, off_site_parameters, + basis_definition, label, meta_data) + end + + function Model( + basis_definition::BasisDef, on_site_parameters::OnSiteParaSet, + off_site_parameters::OffSiteParaSet, label::String, + meta_data::Union{Dict, Nothing}=nothing) + + # Developers Notes + # This makes the assumption that all z₁-z₂-ℓ₁-ℓ₂ interactions are represented + # by the same model. + + # get a species list from basis_definition + species = AtomicNumber.([ keys(basis_definition)... ]) + + # Discuss use of the on/off_site_cache entities + + on_sites = Dict{NTuple{3, keytype(basis_definition)}, AHSubModel}() + off_sites = Dict{NTuple{4, keytype(basis_definition)}, AHSubModel}() + + # Caching the basis functions of the functions is faster and allows us to reuse + # the same basis function for similar interactions. + ace_basis_on = with_cache(on_site_ace_basis) + ace_basis_off = with_cache(off_site_ace_basis) + + # Sorting the basis definition makes avoiding interaction doubling easier. + # That is to say, we don't create models for both H-C and C-H interactions + # as they represent the same thing. + basis_definition_sorted = sort(collect(basis_definition), by=first) + + @debug "Building model" + # Loop over all unique species pairs then over all combinations of their shells. + for (zₙ, (zᵢ, shellsᵢ)) in enumerate(basis_definition_sorted) + for (zⱼ, shellsⱼ) in basis_definition_sorted[zₙ:end] + homo_atomic = zᵢ == zⱼ + for (n₁, ℓ₁) in enumerate(shellsᵢ), (n₂, ℓ₂) in enumerate(shellsⱼ) + + # Skip symmetrically equivalent interactions. + homo_atomic && n₁ > n₂ && continue + + if homo_atomic + id = (zᵢ, n₁, n₂) + @debug "Building on-site model : $id" + ace_basis = ace_basis_on( # On-site bases + ℓ₁, ℓ₂, on_site_parameters[id]...; species = species) + + on_sites[(zᵢ, n₁, n₂)] = AHSubModel(ace_basis, id) + end + + id = (zᵢ, zⱼ, n₁, n₂) + @debug "Building off-site model: $id" + + ace_basis = ace_basis_off( # Off-site bases + ℓ₁, ℓ₂, off_site_parameters[id]...; species = species) + + @static if DUAL_BASIS_MODEL + if homo_atomic && n₁ == n₂ + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, id) + else + ace_basis_i = ace_basis_off( + ℓ₂, ℓ₁, off_site_parameters[(zⱼ, zᵢ, n₂, n₁)]...) + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, ace_basis_i, id) + end + else + off_sites[(zᵢ, zⱼ, n₁, n₂)] = AHSubModel(ace_basis, id) + end + end + end + end + + # If no meta-data is supplied then just default to a blank dictionary + meta_data = isnothing(meta_data) ? Dict{String, Any}() : meta_data + new(on_sites, off_sites, on_site_parameters, off_site_parameters, basis_definition, label, meta_data) + end + +end + +# Associated methods + +Base.:(==)(x::Model, y::Model) = ( + x.on_site_submodels == y.on_site_submodels && x.off_site_submodels == y.off_site_submodels + && x.on_site_parameters == y.on_site_parameters && x.off_site_parameters == y.off_site_parameters) + + +# ╭───────┬──────────────────╮ +# │ Model │ IO Functionality │ +# ╰───────┴──────────────────╯ + +function ACEbase.write_dict(m::Model) + # ACE bases are stored as hash values which are checked against the "bases_hashes" + # dictionary during reading. This avoids saving multiple copies of the same object; + # which is common as `AHSubModel` objects tend to share basis functions. + + + bases_hashes = Dict{String, Any}() + + function add_basis(basis) + # Store the hash/basis pair in the bases_hashes dictionary. As the `write_dict` + # method can be quite costly to evaluate it is best to only call it when strictly + # necessary; hence this function exists. + basis_hash = string(hash(basis)) + if !haskey(bases_hashes, basis_hash) + bases_hashes[basis_hash] = write_dict(basis) + end + end + + for basis in union(values(m.on_site_submodels), values(m.off_site_submodels)) + add_basis(basis.basis) + end + + # Serialise the meta-data + meta_data = Dict{String, Any}( + # Invoke the `read_dict` method on values as and where appropriate + k => hasmethod(write_dict, (typeof(v),)) ? write_dict(v) : v + for (k, v) in m.meta_data + ) + + dict = Dict( + "__id__"=>"HModel", + "on_site_submodels"=>Dict(k=>write_dict(v, true) for (k, v) in m.on_site_submodels), + "off_site_submodels"=>Dict(k=>write_dict(v, true) for (k, v) in m.off_site_submodels), + "on_site_parameters"=>write_dict(m.on_site_parameters), + "off_site_parameters"=>write_dict(m.off_site_parameters), + "basis_definition"=>Dict(k=>write_dict(v) for (k, v) in m.basis_definition), + "bases_hashes"=>bases_hashes, + "label"=>m.label, + "meta_data"=>meta_data) + + return dict +end + + +function ACEbase.read_dict(::Val{:HModel}, dict::Dict)::Model + + function set_bases(target, basis_functions) + for v in values(target) + v["basis"] = basis_functions[v["basis"]] + end + end + + # Replace basis object hashs with the appropriate object. + set_bases(dict["on_site_submodels"], dict["bases_hashes"]) + set_bases(dict["off_site_submodels"], dict["bases_hashes"]) + + ensure_int(v) = v isa String ? parse(Int, v) : v + + # Parse meta-data + if haskey(dict, "meta_data") + meta_data = Dict{String, Any}() + for (k, v) in dict["meta_data"] + if typeof(v) <: Dict && haskey(v, "__id__") + meta_data[k] = read_dict(v) + else + meta_data[k] = v + end + end + else + meta_data = nothing + end + + # One of the important entries present in the meta-data dictionary is the `occupancy` + # data. This should be keyed by integers; however the serialisation/de-serialisation + # process converts this into a string. A hard-coded fix is implemented here, but it + # would be better to create a more general way of handling this later on. + if !isnothing(meta_data) && haskey(meta_data, "occupancy") && (keytype(meta_data["occupancy"]) ≡ String) + meta_data["occupancy"] = Dict(parse(Int, k)=>v for (k, v) in meta_data["occupancy"]) + end + + return Model( + Dict(parse_key(k)=>read_dict(v) for (k, v) in dict["on_site_submodels"]), + Dict(parse_key(k)=>read_dict(v) for (k, v) in dict["off_site_submodels"]), + read_dict(dict["on_site_parameters"]), + read_dict(dict["off_site_parameters"]), + Dict(ensure_int(k)=>read_dict(v) for (k, v) in dict["basis_definition"]), + dict["label"], + meta_data) +end + + +# Todo: this is mostly to stop terminal spam and should be updated +# with more meaningful information later on. +function Base.show(io::IO, model::Model) + + # Work out if the on/off site bases are fully, partially or un-fitted. + f = b -> if all(b) "no" elseif all(!, b) "yes" else "partially" end + on = f([!is_fitted(i) for i in values(model.on_site_submodels)]) + off = f([!is_fitted(i) for i in values(model.off_site_submodels)]) + + # Identify the species present + species = join(sort(unique(getindex.(collect(keys(model.on_site_submodels)), 1))), ", ", " & ") + + print(io, "Model(fitted=(on: $on, off: $off), species: ($species))") +end + + +end diff --git a/examples/H2O/python_interface/small/src/parameters.jl b/examples/H2O/python_interface/small/src/parameters.jl new file mode 100644 index 0000000..6a54de3 --- /dev/null +++ b/examples/H2O/python_interface/small/src/parameters.jl @@ -0,0 +1,902 @@ +module Parameters +using Base, ACEbase +export NewParams, GlobalParams, AtomicParams, AzimuthalParams, ShellParams, ParaSet, OnSiteParaSet, OffSiteParaSet, ison + +# The `Params` structure has been temporarily renamed to `NewParams` to avoid conflicts +# with the old code. However, this will be rectified when the old code is overridden. + +# ╔════════════╗ +# ║ Parameters ║ +# ╚════════════╝ +# Parameter related code. + + +# +# Todo: +# - Parameters: +# - Need to enforce limits on key values, shells must be larger than zero and +# azimuthal numbers must be non-negative. +# - All Params should be combinable, with compound classes generated when combining +# different Params types. Compound types should always check the more refined +# struct first (i.e. ShellParamsLabel conversion + t2l(val::Pair{NTuple{N, I}, V}) where {N, I<:Integer, V} = convert(Pair{Label{N, I}, V}, val) + + if with_basis + return quote + function $(esc(name))(b_def, arg::$T1) where {K<:Label{$N, <:Integer}, V, N} + $(Expr(:call, Expr(:curly, esc(:new), :K, :V), Expr(:call, Dict, :arg), :b_def)) + end + + function $(esc(name))(b_def, arg::$T1) where {K<:NTuple{$N, <:Integer}, V, N} + $(Expr(:call, esc(name), :b_def, Expr(:(...), Expr(:call, :map, esc(t2l), :arg)))) + end + end + else + return quote + function $(esc(name))(arg::$T1) where {K<:Label{$N, <:Integer}, V, N} + $(Expr(:call, Expr(:curly, esc(:new), :K, :V), Expr(:call, Dict, :arg))) + end + + function $(esc(name))(arg::$T1) where {K<:NTuple{$N, <:Integer}, V, N} + $(Expr(:call, esc(name), Expr(:(...), Expr(:call, :map, esc(t2l), :arg)))) + end + end + end +end + + +# ╭────────┬────────────╮ +# │ Params │ Definition │ +# ╰────────┴────────────╯ +""" +Dictionary-like structures for specifying model parameters. + +These are used to provide the parameters needed when constructing models within the +`ACEhamiltonians` framework. There are currently four `Params` type structures, namely +`GlobalParams`, `AtomicParams`, `AzimuthalParams`, and `ShellParams`, each offering +varying levels of specificity. + +Each parameter, correlation order, maximum polynomial degree, environmental cutoff +distance, etc. may be specified using any of the available `Params` based structures. +However, i) each `Params` instance may represent one, and only one, parameter, and ii) +on/off-site parameters must not be mixed. +""" +abstract type NewParams{K, V} end + +""" + GlobalParams(val) + +A `GlobalParams` instance indicates that a single value should be used for all relevant +interactions. Querying such instances will always return the value `val`; so long as the +query is valid. For example: +``` +julia> p = GlobalParams(10.) +GlobalParams{Float64} with 1 entries: + () => 10.0 + +julia> p[1] # <- query parameter associated with H +10. +julia> p[(1, 6)] # <- query parameter associated with H-C interaction +10. +julia> p[(1, 6, 1, 2)] # <- interaction between 1ˢᵗ shell on H and 2ⁿᵈ shell on C +10. +``` +As can be seen the specified value `10.` will always be returned so long as the query is +valid. These instances are useful when specifying parameters that are constant across all +bases, such as the internal cutoff distance, as it avoids having to repeatedly specify it +for each and every interaction. + +# Arguments + - `val::Any`: value of the parameter + +""" +struct GlobalParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build GlobalParams 0 false + # Catch for special case where a single value passed + GlobalParams(arg) = GlobalParams(Label()=>arg) +end + + +""" + AtomicParams(k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +These instances allow for parameters to be specified on a species by species basis. This +equates to one parameter per species for on-site interactions and one parameter per species +pair for off-site interactions. This will then result in all associated bases associated +with a specific species/species-pair all using a common value, like so: +``` +julia> p_on = AtomicParams(1=>9., 6=>11.) +AtomicParams{Float64} with 2 entries: + 6 => 11.0 + 1 => 9.0 + +julia> p_off = AtomicParams((1, 1)=>9., (1, 6)=>10., (6, 6)=>11.) +AtomicParams{Float64} with 3 entries: + (6, 6) => 11.0 + (1, 6) => 10.0 + (1, 1) => 9.0 + +# The value 11. is returned for all on-site C interaction queries +julia> p_on[(6, 1, 1)] == p_on[(6, 1, 2)] == p_on[(6, 2, 2)] == 11. +true +# The value 10. is returned for all off-site H-C interaction queries +julia> p_off[(1, 6, 1, 1)] == p_off[(6, 1, 2, 1)] == p_off[(6, 1, 2, 2)] == 10. +true +``` +These instances are instantiated in a similar manner to dictionaries and offer a finer +degree of control over the parameters than `GlobalParams` structures but are not as +granular as `AzimuthalParams` structures. + +# Arguments +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each species + or species-pair. Valid parameter forms are: + + - on-site: `z₁=>v` or `(z,)=>v` for on-sites + - off-site: `(z₁, z₂)=>v` + + where `zᵢ` represents the atomic number of species `i` and `v` the parameter valued + associated with this species or specie pair. + + +# Notes +It is important to note that atom pair keys are permutationally invariant, i.e. the keys +`(1, 6)` and `(6, 1)` are redundant and will overwrite one another like so: +``` +julia> test = AtomicParams((1, 6)=>10., (6, 1)=>1000.) +AtomicParams{Float64} with 1 entries: + (1, 6) => 1000.0 + +julia> test[(1, 6)] == test[(6, 1)] == 1000.0 +true +``` +Finally atomic numbers will be sorted so that the lowest atomic number comes first. However, +this is only a superficial visual change and queries will still be invariant to permutation. +""" +struct AtomicParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build AtomicParams 1 false + @build AtomicParams 2 false + # Catch for special case where keys are integers rather than tuples + AtomicParams(arg::Vararg{Pair{I, V}, N}) where {I<:Integer, V, N} = AtomicParams( + map(i->((first(i),)=>last(i)), arg)...) + +end + + +""" + AzimuthalParams(basis_definition, k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +Parameters specified for each azimuthal quantum number of each species. This allows for a +finer degree of control and is a logical extension of the `AtomicParams` structure. It is +important to note that `AzimuthalParams` instances must be supplied with a basis definition. +This allows it to work out the azimuthal quantum number associated with each shell during +lookup. + +``` +# Basis definition describing a H_1s C_2s1p basis set +julia> basis_def = Dict(1=>[0], 6=>[0, 0, 1]) +julia> p_on = AzimuthalParams( + basis_def, (1, 0, 0)=>1, (6, 0, 0)=>2, (6, 0, 1)=>3, (6, 1, 1)=>4) +AzimuthalParams{Int64} with 4 entries: + (6, 0, 0) => 2 + (1, 0, 0) => 1 + (6, 1, 1) => 4 + (6, 0, 1) => 3 + +julia> p_off = AzimuthalParams( + basis_def, (1, 1, 0, 0)=>1, (6, 6, 0, 0)=>2, (6, 6, 0, 1)=>3, (6, 6, 1, 1)=>4, + (1, 6, 0, 0)=>6, (1, 6, 0, 1)=>6) + +AzimuthalParams{Int64} with 6 entries: + (1, 6, 0, 1) => 6 + (6, 6, 0, 1) => 3 + (1, 6, 0, 0) => 6 + (1, 1, 0, 0) => 1 + (6, 6, 1, 1) => 4 + (6, 6, 0, 0) => 2 + +# on-site interactions involving shells 1 % 2 will return 2 as they're both s-shells. +julia> p_on[(6, 1, 1)] == p_on[(6, 1, 2)] == p_on[(6, 2, 2)] == 2 +true + +``` + +# Arguments +- `basis_definition::BasisDef`: basis definition specifying the bases present on each + species. This is used to work out the azimuthal quantum number associated with each + shell when queried. +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each unique + atomic-number/azimuthal-number pair. Valid forms are: + + - on-site: `(z, ℓ₁, ℓ₂)=>v` + - off-site: `(z₁, z₂, ℓ₁, ℓ₂)=>v` + + where `zᵢ` and `ℓᵢ` represents the atomic and azimuthal numbers of species `i` to which + the parameter `v` is associated. + +# Notes +While keys are agnostic to the ordering of the azimuthal numbers; the first atomic number +`z₁` will always correspond to the first azimuthal number `ℓ₁`, i.e.: + - `(z₁, ℓ₁, ℓ₂) == (z₁, ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) == (z₂, z₁, ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) ≠ (z₁, z₂ ℓ₂, ℓ₁)` + - `(z₁, z₂, ℓ₁, ℓ₂) ≠ (z₂, z₁ ℓ₁, ℓ₂)` + +""" +struct AzimuthalParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + _basis_def + + @build AzimuthalParams 3 true + @build AzimuthalParams 4 true +end + +""" + ShellParams(k₁=>v₁, k₂=>v₂, ..., kₙ=>vₙ) + +`ShellParams` structures allow for individual values to be provided for each and every +unique interaction. While this proved the finest degree of control it can quickly become +untenable for systems with large basis sets or multiple species due the shear number of +variable required. +``` +# For H1s C2s1p basis set. +julia> p_on = ShellParams( + (1, 1, 1)=>1, (6, 1, 1)=>2, (6, 1, 2)=>3, (6, 1, 3)=>4, + (6, 2, 2)=>5, (6, 2, 3)=>6, (6, 3, 3)=>7) + +ShellParams{Int64} with 7 entries: + (6, 3, 3) => 7 + (1, 1, 1) => 1 + (6, 1, 3) => 4 + (6, 2, 2) => 5 + (6, 1, 1) => 2 + (6, 1, 2) => 3 + (6, 2, 3) => 6 + +julia> p_off = ShellParams( + (1, 1, 1, 1)=>1, (1, 6, 1, 1)=>2, (1, 6, 1, 2)=>3, (1, 6, 1, 3)=>4, + (6, 6, 1, 1)=>5, (6, 6, 1, 2)=>6, (6, 6, 1, 3)=>74, (6, 6, 2, 2)=>8, + (6, 6, 2, 3)=>9, (6, 6, 3, 3)=>10) + +ShellParams{Int64} with 10 entries: + (6, 6, 2, 2) => 8 + (6, 6, 3, 3) => 10 + (6, 6, 1, 3) => 74 + (1, 1, 1, 1) => 1 + (1, 6, 1, 2) => 3 + (1, 6, 1, 1) => 2 + (1, 6, 1, 3) => 4 + (6, 6, 1, 1) => 5 + (6, 6, 1, 2) => 6 + (6, 6, 2, 3) => 9 + +``` + +# Arguments +- `pairs::Pair`: a sequence of pair arguments specifying the parameters for each unique + shell pair: + - on-site: `(z, s₁, s₂)=>v`, interaction between shell numbers `s₁` & `s₂` on species `z` + - off-site: `(z₁, z₂, s₁, s₂)=>v`, interaction between shell number `s₁` on species + `zᵢ` and shell number `s₂` on species `z₂`. + +""" +struct ShellParams{K, V} <: NewParams{K, V} + _vals::Dict{K, V} + + @build ShellParams 3 false + @build ShellParams 4 false +end + +# ╭────────┬───────────────────────╮ +# │ Params │ General Functionality │ +# ╰────────┴───────────────────────╯ +# Return the key and value types of the internal dictionary. +Base.valtype(::NewParams{K, V}) where {K, V} = V +Base.keytype(::NewParams{K, V}) where {K, V} = K +Base.valtype(::Type{NewParams{K, V}}) where {K, V} = V +Base.keytype(::Type{NewParams{K, V}}) where {K, V} = K + +# Extract keys and values from the internal dictionary (and number of elements) +Base.keys(x::NewParams) = keys(x._vals) +Base.values(x::NewParams) = values(x._vals) +Base.length(x::T) where T<:NewParams = length(x._vals) + +# Equality check, mostly use during testing +function Base.:(==)(x::T₁, y::T₂) where {T₁<:NewParams, T₂<:NewParams} + dx, dy = x._vals, y._vals + # Different type Params are not comparable + if T₁ ≠ T₂ + return false + # Different key sets means x & y are different + elseif keys(dx) ≠ keys(dy) + return false + # If any key yields a different value in x from x then x & y are different + else + for key in keys(dx) + if dx[key] ≠ dy[key] + return false + end + end + # Otherwise there is no difference between x and y, thus return true + return true + end +end + + + +# ╭────────┬────────────────────╮ +# │ Params │ Indexing Functions │ +# ╰────────┴────────────────────╯ +""" + params_object[key] + +This function makes `Params` structures indexable in the same way that dictionaries are. +This will not only check the `Params` object `params` for the specified key `key` but will +also check for i) permutationally equivalent matches, i.e. (1, 6)≡(6, 1), and ii) keys +that `key` is a subtype of i.e. (1, 6, 1, 1) ⊆ (1, 6). + +Valid key types are: + - z/(z,): single atomic number + - (z₁, z₂): pair of atomic numbers + - (z, s₁, s₂): single atomic number with pair of shell numbers + - (z₁, z₂, s₁, s₂): pair of atomic numbers with pair of shell numbers + +This is primarily intended to be used by the code internally, but is left accessible to the +user. +""" +function Base.getindex(x::T, key::K) where {T<:NewParams, K} + # This will not only match the specified key but also any superset it is a part of; + # i.e. the key (z₁, z₂, s₁, s₂) will match (z₁, z₂). + + # Block 1: convert shell numbers to azimuthal numbers for the AzimuthalParams case. + if T<:AzimuthalParams && !(K<:Integer) + if length(key) == 3 + key = (key[1], [x._basis_def[key[1]][i] for i in key[2:3]]...) + else + key = (key[1:2]..., x._basis_def[key[1]][key[3]], x._basis_def[key[2]][key[4]]) + end + end + + # Block 2: identify closest viable key. + super_key = filter(k->(key ⊆ k), keys(x)) + + # Block 3: retrieve the selected key. + if length(super_key) ≡ 0 + throw(KeyError(key)) + else + return x._vals[first(super_key)] + end +end + + +# ╭────────┬──────────────────╮ +# │ Params │ IO Functionality │ +# ╰────────┴──────────────────╯ +"""Full, multi-line string representation of a `Param` type objected""" +function _multi_line(io, x::T) where T<:NewParams + i = length(keytype(x._vals).types[1].types) ≡ 1 ? 1 : Base.:(:) + indent = repeat(" ", get(io, :indent, 0)+2) + v_string = join(["$(indent)$(k[i]) => $v" for (k, v) in x._vals], "\n") + # Make convert "()" to "(All)" for to make GlobalParams more readable + v_string = replace(v_string, "()" => "All") + return "$(nameof(T)){$(valtype(x))} with $(length(x._vals)) entries:\n$(v_string)" +end + + +function Base.show(io::O, x::T) where {T<:NewParams, O<:IO} + # If printing an isolated Params instance, just use the standard multi-line format + # if !haskey(io.dict, :SHOWN_SET) + # print(io, _multi_line(x)) + if !get(io, :compact, false) && !haskey(io.dict, :SHOWN_SET) + print(io, _multi_line(io, x)) + # If the Params is being printed as part of a group then a more compact + # representation is needed. + else + # Create a slicer remove braces from tuples of length 1 if needed + s = length(keytype(x)) == 1 ? 1 : Base.:(:) + # Sort the keys to ensure consistency + keys_s = sort([j.id for j in keys(x._vals)]) + # Only show first and last keys (or just the first if there is only one) + targets = length(x) != 1 ? [[1, lastindex(keys_s)]] : [1:1] + # Build the key list and print the message out + k_string = join([k[s] for k in keys_s[targets...]], " … ") + # Make convert "()" to "(All)" for to make GlobalParams more readable + k_string = replace(k_string, "()" => "All") + indent = repeat(" ", get(io, :indent, 0)) + print(io, "$(indent)$(nameof(T))($(k_string))") + end +end + +# Special show case: Needed as Base.TTY has no information dictionary +Base.show(io::Base.TTY, x::T) where T<:NewParams = print(io, _multi_line(x)) + + +function ACEbase.write_dict(p::T) where T<:NewParams{K, V} where {K, V} + # Recursive and arbitrary value type storage to be implemented later + # value_parsable = hasmethod(ACEbase.write_dict, (V)) + + dict = Dict( + "__id__"=>"NewParams", + "vals"=>Dict(string(k)=>v for (k, v) in p._vals)) + + if T<:AzimuthalParams + dict["basis_def"] = p._basis_def + end + + return dict +end + +function ACEbase.read_dict(::Val{:NewParams}, dict::Dict) + vals = Dict(Label(k)=>v for (k,v) in dict["vals"]) + n = length(keytype(vals)) + + if n ≡ 0 + return GlobalParams(vals...) + elseif n ≤ 2 + return AtomicParams(vals...) + elseif haskey(dict, "basis_def") + return AzimuthalParams(dict["basis_def"], vals...) + else + return ShellParams(vals...) + end + +end + + +# ╔═════════╗ +# ║ ParaSet ║ +# ╚═════════╝ +# Containers for collections of `Params` instances. These exist mostly to ensure that +# all the required parameters are specified and provide a single location where user +# specified parameters can be collected and checked. + +# ╭─────────┬────────────╮ +# │ ParaSet │ Definition │ +# ╰─────────┴────────────╯ +""" +`ParaSet` instances are structures which collect all the required parameter definitions +for a given interaction type in once place. Once instantiated, the `OnSiteParaSet` and +`OffSiteParaSet` structures should contain all parameters required to construct all of +the desired on/off-site bases. +""" +abstract type ParaSet end + + +""" + OnSiteParaSet(ν, deg, e_cut_out, r0) + +This structure holds all the `Params` instances required to construct the on-site +bases. + + +# Arguments +- `ν::Params{K, Int}`: correlation order, for on-site interactions the body order is one + more than the correlation order. +- `deg::Params{K, Int}`: maximum polynomial degree. +- `e_cut_out::Parameters{K, Float}`: environment's external cutoff distance. +- `r0::Parameters{K, Float}`: scaling parameter (typically set to the nearest neighbour distances). + + +# Todo + - check that r0 is still relevant +""" +struct OnSiteParaSet <: ParaSet + ν + deg + e_cut_out + r0 + + function OnSiteParaSet(ν::T₁, deg::T₂, e_cut_out::T₃, r0::T₄ + ) where {T₁<:NewParams, T₂<:NewParams, T₃<:NewParams, T₄<:NewParams} + ν::NewParams{<:Label, <:Integer} + deg::NewParams{<:Label, <:Integer} + e_cut_out::NewParams{<:Label, <:AbstractFloat} + r0::NewParams{<:Label, <:AbstractFloat} + new(ν, deg, e_cut_out, r0) + end + +end + +""" + OffSiteParaSet(ν, deg, b_cut, e_cut_out, r0) + +This structure holds all the `Params` instances required to construct the off-site +bases. + +# Arguments +- `ν::Params{K, Int}`: correlation order, for off-site interactions the body order is two + more than the correlation order. +- `deg::Params{K, Int}`: maximum polynomial degree. +- `b_cut::Params{K, Float}`: cutoff distance for off-site interactions. +- `e_cut_out::Params{K, Float}`: environment's external cutoff distance. + +# Todo: +- add λₙ & λₗ as parameters +- generate constructor to allow for arbitrary fields +""" +struct OffSiteParaSet <: ParaSet + ν + deg + b_cut + e_cut_out + + function OffSiteParaSet(ν::T₁, deg::T₂, b_cut::T₃, e_cut_out::T₄ + ) where {T₁<:NewParams, T₂<:NewParams, T₃<:NewParams, T₄<:NewParams} + ν::NewParams{<:Label, <:Integer} + deg::NewParams{<:Label, <:Integer} + b_cut::NewParams{<:Label, <:AbstractFloat} + e_cut_out::NewParams{<:Label, <:AbstractFloat} + new(ν, deg, b_cut, e_cut_out) + end + +end + +# ╭─────────┬───────────────────────╮ +# │ ParaSet │ General Functionality │ +# ╰─────────┴───────────────────────╯ +function Base.:(==)(x::T, y::T) where T<:ParaSet + # Check that all fields are equal to one another + for field in fieldnames(T) + # If any do not match then return false + if getfield(x, field) ≠ getfield(y, field) + return false + end + end + + # If all files match then return true + return true +end + + +# ╭─────────┬────────────────────────────────╮ +# │ ParaSet │ Miscellaneous Helper Functions │ +# ╰─────────┴────────────────────────────────╯ +# Returns true if a `ParaSet` corresponds to an on-site interaction. +ison(::OnSiteParaSet) = true +ison(::OffSiteParaSet) = false + + +# ╭─────────┬──────────────────╮ +# │ ParaSet │ IO Functionality │ +# ╰─────────┴──────────────────╯ +function ACEbase.write_dict(p::T) where T<:ParaSet + dict = Dict( + "__id__"=>"ParaSet", + (string(fn)=>write_dict(getfield(p, fn)) for fn in fieldnames(T))...) + return dict +end + + +function ACEbase.read_dict(::Val{:ParaSet}, dict::Dict) + if haskey(dict, "b_cut") + return OffSiteParaSet(( + ACEbase.read_dict(dict[i]) for i in + ["ν", "deg", "b_cut", "e_cut_out"])...) + else + return OnSiteParaSet(( + ACEbase.read_dict(dict[i]) for i in + ["ν", "deg", "e_cut_out", "r0"])...) + end +end + + +function Base.show(io::O, x::T) where {T<:ParaSet, O<:IO} + print(io, "$(nameof(T))") + if !get(io, :compact, false) && !haskey(io.dict, :SHOWN_SET) + new_io = IOContext(io.io, :indent=>get(io, :indent, 0)+4, :compact=>get(io, :compact, false)) + for f in fieldnames(T) + print(new_io, join(["\n", repeat(" ", get(io, :indent, 0)+2)]), "$f: ", getfield(x, f)) + end + + else + for (i, f) in enumerate(fieldnames(T)) + print(io, "$f: ") + show(io, getfield(x, f)) + if i != length(fieldnames(T)) + print(io, ", ") + end + end + print(io, ")") + + end + nothing +end + + + +# ╭─────────┬────────────────────╮ +# │ ParaSet │ Indexing Functions │ +# ╰─────────┴────────────────────╯ +""" + on_site_para_set[key] + +Indexing an `OnSiteParaSet` instance will index each of the internal fields and return +their results in a tuple, i.e. calling `res = on_site_para_set[key]` equates to calling +``` +res = ( + on_site_para_set.ν[key], on_site_para_set.deg[key], + on_site_para_set.e_cut_out[key], on_site_para_set.e_cut_in[key]) +``` + +This is mostly intended as a convenience function. +""" +function Base.getindex(para::OnSiteParaSet, key) + return ( + para.ν[key], para.deg[key], + para.e_cut_out[key], para.r0[key]) +end + + +""" + off_site_para_set[key] + +Indexing an `OffSiteParaSet` instance will index each of the internal fields and return +their results in a tuple, i.e. calling `res = off_site_para_set[key]` equates to calling +``` +res = ( + off_site_para_set.ν[key], off_site_para_set.deg[key], off_site_para_set.b_cut[key], + off_site_para_set.e_cut_out[key], off_site_para_set.e_cut_in[key]) +``` + +This is mostly intended as a convenience function. +""" +function Base.getindex(para::OffSiteParaSet, key) + return ( + para.ν[key], para.deg[key], para.b_cut[key], + para.e_cut_out[key]) +end + + + +# ╔═══════════════════════════╗ +# ║ Internal Helper Functions ║ +# ╚═══════════════════════════╝ + +""" +Sort `Label` tuples so that the lowest atomic-number/shell-number comes first for the +two/one atom interaction labels. If more than four integers are specified then an error +is raised. +""" + +""" + _process_ctuple(tuple) + +Preprocess tuples prior to their conversion into `Label` instances. This ensures that +tuples are ordered so that: + 1. the lowest atomic number comes first, but only if multiple atomic numbers are specified. + 2. the lowest shell number comes first, but only where this does not conflict with point 1. + +An error is then raised if the tuple is of an unexpected length. permitted lengths are: + - 1/(z) single atomic number. + - 2/(z₁, z₂) pair of atomic numbers + - 3/(z, s₁, s₂) single atomic number and pair of shell numbers + - 4/(z₁, z₂, s₁, s₂) pair of atomic numbers and a pair of shell numbers. + +Note that in the latter case s₁ & s₂ correspond to shells on z₁ & z₂ respectively thus +if z₁ and z₂ are flipped due to z₁>z₂ then s₁ & s₂ must also be shuffled. + +This os intended only to be used internally and only during the construction of `Label` +instances. +""" +function _process_tuple(x::NTuple{N, I}) where {N, I<:Integer} + if N <= 1; x + elseif N ≡ 2; x[1] ≤ x[2] ? x : reverse(x) + elseif N ≡ 3; x[2] ≤ x[3] ? x : x[[1, 3, 2]] + elseif N ≡ 4 + if x[1] > x[2] || ((x[1] ≡ x[2]) && (x[3] > x[4])); x[[2, 1, 4, 3]] + else; x + end + else + error( + "Label may contain no more than four integers, valid formats are:\n"* + " ()\n (z₁,)\n (z₁, s₁, s₂)\n (z₁, z₂)\n (z₁, z₂, s₁, s₂)") + end +end + + +# # Guards type conversion of dictionaries keyed with `Label` entities. This is done to +# # ensure that a meaningful message is given to the user when a key-collision occurs. +# function _guarded_convert(t::Type{Dict{Label{N, I}, V}}, x::Dict{NTuple{N, I}, V}) where {N, I<:Integer, V} +# try +# return convert(t, x) +# catch e +# if e.msg == "key collision during dictionary conversion" +# r_keys = _redundant_keys([k for k in keys(x)]) +# error("Redundant keys found:\n$(join([" - $(join(i, ", "))" for i in r_keys], "\n"))") +# else +# rethrow(e) +# end +# end +# end + +# # Collisions cannot occur when input dictionary is keyed by integers not tuples +# _guarded_convert(t::Type{Dict{Label{1, I}, V}}, x::Dict{I, V}) where {N, I<:Integer, V} = convert(t, x) + + +# function _redundant_keys(keys_in::Vector{NTuple{N, I}}) where {I<:Integer, N} +# duplicates = [] +# while length(keys_in) ≥ 1 +# key = Label(pop!(keys_in)) +# matches = [popat!(keys_in, i) for i in findall(i -> i == key, keys_in)] +# if length(matches) ≠ 0 +# append!(duplicates, Ref((key, matches...))) +# end +# end +# return duplicates +# end + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/small/src/predicting.jl b/examples/H2O/python_interface/small/src/predicting.jl new file mode 100644 index 0000000..726d7e0 --- /dev/null +++ b/examples/H2O/python_interface/small/src/predicting.jl @@ -0,0 +1,492 @@ +module Predicting + +using ACE, ACEbase, ACEhamiltonians, LinearAlgebra +using JuLIP: Atoms, neighbourlist +using ACE: ACEConfig, AbstractState, evaluate + +using ACEhamiltonians.States: _get_states +using ACEhamiltonians.Fitting: _evaluate_real + +using ACEhamiltonians: DUAL_BASIS_MODEL + +export predict, predict!, cell_translations + + +""" + cell_translations(atoms, cutoff) + +Translation indices of all cells in within range of the origin. Multiplying any translation +index by the lattice vector will return the cell translation vector associated with said +cell. The results of this function are most commonly used in constructing the real space +matrix. + +# Arguments +- `atoms::Atoms`: system for which the cell translation index vectors are to be constructed. +- `cutoff::AbstractFloat`: cutoff distance for diatomic interactions. + +# Returns +- `cell_indices::Matrix{Int}`: a 3×N matrix specifying all cell images that are within + the cutoff distance of the origin cell. + +# Notes +The first index provided in `cell_indices` is always that of the origin cell; i.e. [0 0 0]. +A cell is included if, and only if, at least one atom within it is within range of at least +one atom in the origin cell. Mirror image cell are always included, that is to say if the +cell [i, j, k] is present then the cell [-i, -j, -k] will also be present. +""" +function cell_translations(atoms::Atoms{T}, cutoff) where T<:AbstractFloat + + l⃗, x⃗ = atoms.cell, atoms.X + # n_atoms::Int = size(x⃗, 2) + n_atoms::Int = size(x⃗, 1) + + # Identify how many cell images can fit within the cutoff distance. + aₙ, bₙ, cₙ = convert.(Int, cld.(cutoff, norm.(eachrow(l⃗)))) + + # Matrix in which the resulting translation indices are to be stored + cellᵢ = Matrix{Int}(undef, 3, (2aₙ + 1) * (2bₙ + 1) * (2cₙ + 1)) + + # The first cell is always the origin cell + cellᵢ[:, 1] .= 0 + i = 1 + + # Loop over all possible cells within the cutoff range. + for n₁=-aₙ:aₙ, n₂=-bₙ:bₙ, n₃=-cₙ:cₙ + + # Origin cell is skipped over when encountered as it is already defined. + if n₁ ≠ 0 || n₂ ≠ 0 || n₃ ≠ 0 + + # Construct the translation vector + t⃗ = l⃗[1, :]n₁ + l⃗[2, :]n₂ + l⃗[3, :]n₃ + + # Check if any atom in the shifted cell, n⃗, is within the cutoff distance of + # any other atom in the origin cell, [0,0,0]. + min_distance = 2cutoff + for atomᵢ=1:n_atoms, atomⱼ=1:n_atoms + min_distance = min(min_distance, norm(x⃗[atomᵢ] - x⃗[atomⱼ] + t⃗)) + end + + # If an atom in the shifted cell is within the cutoff distance of another in + # the origin cell then the cell should be included. + if min_distance ≤ cutoff + i += 1 + cellᵢ[:, i] .= n₁, n₂, n₃ + end + end + end + + # Return a truncated view of the cell translation index matrix. + return cellᵢ[:, 1:i] + +end + +""" +cell_translations(atoms, model) + +Translation indices of all cells in within range of the origin. Note, this is a wrapper +for the base `cell_translations(atoms, cutoff)` method which automatically selects an +appropriate cutoff distance. See the base method for more info. + + +# Argument +- `atoms::Atoms`: system for which the cell translation index vectors are to be constructed. +- `model::Model`: model instance from which an appropriate cutoff distance is to be derived. + + +# Returns +- `cell_indices::Matrix{Int}`: a 3×N matrix specifying all cell images that are within + the cutoff distance of the origin cell. + +""" +function cell_translations(atoms::Atoms, model::Model) + # Loop over the interaction cutoff distances and identify the maximum recognise + # interaction distance and use that as the cutoff. + return cell_translations( + atoms, maximum(values(model.off_site_parameters.b_cut))) +end + + +""" + predict!(values, basis, state) + +Predict the values for a given sub-block by evaluating the provided basis on the specified +state; or more accurately the descriptor that is to be constructed from said state. Results +are placed directly into the supplied matrix `values.` + +# Arguments + - `values::AbstractMatrix`: matrix into which the results should be placed. + - `basis::AHSubModel`: basis to be evaluated. + - `state::Vector{States}`: state upon which the `basis` should be evaluated. +""" +function predict!(values::AbstractMatrix, submodel::T, state::Vector{S}) where {T<:AHSubModel, S<:AbstractState} + # If the model has been fitted then use it to predict the results; otherwise just + # assume the results are zero. + if is_fitted(submodel) + # Construct a descriptor representing the supplied state and evaluate the + # basis on it to predict the associated sub-block. + A = evaluate(submodel.basis, ACEConfig(state)) + B = _evaluate_real(A) + values .= (submodel.coefficients' * B) + submodel.mean + + @static if DUAL_BASIS_MODEL + if T<: AnisoSubModel + A = evaluate(submodel.basis_i, ACEConfig(reflect.(state))) + B = _evaluate_real(A) + values .= (values + ((submodel.coefficients_i' * B) + submodel.mean_i)') / 2.0 + elseif !ison(submodel) && (submodel.id[1] == submodel.id[2]) && (submodel.id[3] == submodel.id[4]) + # If the dual basis model is being used then it is assumed that the symmetry + # issue has not been resolved thus an additional symmetrisation operation is + # required. + A = evaluate(submodel.basis, ACEConfig(reflect.(state))) + B = _evaluate_real(A) + values .= (values + ((submodel.coefficients' * B) + submodel.mean)') / 2.0 + end + end + + else + fill!(values, 0.0) + end +end + + +# Construct and fill a matrix with the results of a single state + +""" +""" +function predict(submodel::AHSubModel, states::Vector{<:AbstractState}) + # Create a results matrix to hold the predicted values. The shape & type information + # is extracted from the basis. However, complex types will be converted to their real + # equivalents as results in ACEhamiltonians are always real. With the current version + # of ACE this is the the easiest way to reliably identify the shape and float type of + # the sub-blocks; at least that Julia is happy with. + n, m, type = ACE.valtype(submodel.basis).parameters[3:5] + values = Matrix{real(type)}(undef, n, m) + predict!(values, submodel, states) + return values +end + + +""" +Predict the values for a collection of sub-blocks by evaluating the provided basis on the +specified states. This is a the batch operable variant of the primary `predict!` method. + +""" +function predict!(values::AbstractArray{<:Any, 3}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + for i=1:length(states) + @views predict!(values[:, :, i], submodel, states[i]) + end +end + + +""" +""" +function predict(submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + # Construct and fill a matrix with the results from multiple states + n, m, type = ACE.valtype(submodel.basis).parameters[3:5] + values = Array{real(type), 3}(undef, n, m, length(states)) + predict!(values, submodel, states) + return values +end + + +# Special version of the batch operable `predict!` method that is used when scattering data +# into a Vector of AbstractMatrix types rather than into a three dimensional tensor. This +# is implemented to facilitate the scattering of data into collection of sub-view arrays. +function predict!(values::Vector{<:Any}, submodel::AHSubModel, states::Vector{<:Vector{<:AbstractState}}) + for i=1:length(states) + @views predict!(values[i], submodel, states[i]) + end +end + + + + +""" +""" +function predict(model::Model, atoms::Atoms, cell_indices::Union{Nothing, AbstractMatrix}=nothing; kwargs...) + # Pre-build neighbour list to avoid edge case which can degrade performance + _preinitialise_neighbour_list(atoms, model) + + if isnothing(cell_indices) + return _predict(model, atoms; kwargs...) + else + return _predict(model, atoms, cell_indices; kwargs...) + end +end + +function _predict(model, atoms, cell_indices) + + # Todo:- + # - use symmetry to prevent having to compute data for cells reflected + # cell pairs; i.e. [ 0, 0, 1] & [ 0, 0, -1] + # - Setting the on-sites to an identity should be determined by the model + # rather than just assuming that the user always wants on-site overlap + # blocks to be identity matrices. + + basis_def = model.basis_definition + n_orbs = number_of_orbitals(atoms, basis_def) + + # Matrix into which the final results will be placed + matrix = zeros(n_orbs, n_orbs, size(cell_indices, 2)) + + # Mirror index map array required by `_reflect_block_idxs!` + mirror_idxs = _mirror_idxs(cell_indices) + + # The on-site blocks of overlap matrices are approximated as identity matrix. + if model.label ≡ "S" + matrix[1:n_orbs+1:n_orbs^2] .= 1.0 + end + + for (species₁, species₂) in species_pairs(atoms::Atoms) + + # Matrix containing the block indices of all species₁-species₂ atom-blocks + blockᵢ = repeat_atomic_block_idxs( + atomic_block_idxs(species₁, species₂, atoms), size(cell_indices, 2)) + + # Identify on-site sub-blocks now as they as static over the shell pair loop. + # Note that when `species₁≠species₂` `length(on_blockᵢ)≡0`. + on_blockᵢ = filter_on_site_idxs(blockᵢ) + + for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def) + + + # Get the off-site basis associated with this interaction + basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)] + + # Identify off-site sub-blocks with bond-distances less than the specified cutoff + off_blockᵢ = filter_idxs_by_bond_distance( + filter_off_site_idxs(blockᵢ), + envelope(basis_off).r0cut, atoms, cell_indices) + + # Blocks in the lower triangle are redundant in the homo-orbital interactions + if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ + off_blockᵢ = filter_upper_idxs(off_blockᵢ) + end + + # off_site_states = _get_states( # Build states for the off-site atom-blocks + # off_blockᵢ, atoms, envelope(basis_off), cell_indices) + + # # Don't try to compute off-site interactions if none exist + # if length(off_site_states) > 0 + # let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks + # set_sub_blocks!( # Assign off-site sub-blocks to the matrix + # matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + # _reflect_block_idxs!(off_blockᵢ, mirror_idxs) + # values = permutedims(values, (2, 1, 3)) + # set_sub_blocks!( # Assign data to symmetrically equivalent blocks + # matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + # end + # end + + ########################### Chen ############################ + if size(off_blockᵢ, 2) != 0 + + off_site_states = _get_states( # Build states for the off-site atom-blocks + off_blockᵢ, atoms, envelope(basis_off), cell_indices) + + # Don't try to compute off-site interactions if none exist + if length(off_site_states) > 0 + let values = predict(basis_off, off_site_states) # Predict off-site sub-blocks + set_sub_blocks!( # Assign off-site sub-blocks to the matrix + matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + _reflect_block_idxs!(off_blockᵢ, mirror_idxs) + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( # Assign data to symmetrically equivalent blocks + matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + + end + ############################################################# + + # Evaluate on-site terms for homo-atomic interactions; but only if not instructed + # to approximate the on-site sub-blocks as identify matrices. + if species₁ ≡ species₂ && model.label ≠ "S" + # Get the on-site basis and construct the on-site states + basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)] + on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru) + + # Don't try to compute on-site interactions if none exist + if length(on_site_states) > 0 + let values = predict(basis_on, on_site_states) # Predict on-site sub-blocks + set_sub_blocks!( # Assign on-site sub-blocks to the matrix + matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( # Assign data to the symmetrically equivalent blocks + matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + end + + end + end + + return matrix +end + + +function _predict(model, atoms) + # Currently this method has the tendency to produce non-positive definite overlap + # matrices when working with the aluminum systems, however this is not observed in + # the silicon systems. As such this function should not be used for periodic systems + # until the cause of this issue can be identified. + @warn "This function is not to be trusted" + # TODO: + # - It seems like the `filter_idxs_by_bond_distance` method is not working as intended + # as results change based on whether this is enabled or disabled. + + # See comments in the real space matrix version of `predict` more information. + basis_def = model.basis_definition + n_orbs = number_of_orbitals(atoms, basis_def) + + matrix = zeros(n_orbs, n_orbs) + + # If constructing an overlap matrix then the on-site blocks can just be set to + # an identify matrix. + if model.label ≡ "S" + matrix[1:n_orbs+1:end] .= 1.0 + end + + for (species₁, species₂) in species_pairs(atoms::Atoms) + + blockᵢ = atomic_block_idxs(species₁, species₂, atoms) + + on_blockᵢ = filter_on_site_idxs(blockᵢ) + + for (shellᵢ, shellⱼ) in shell_pairs(species₁, species₂, basis_def) + + basis_off = model.off_site_submodels[(species₁, species₂, shellᵢ, shellⱼ)] + + off_blockᵢ = filter_idxs_by_bond_distance( + filter_off_site_idxs(blockᵢ), + envelope(basis_off).r0cut, atoms) + + if species₁ ≡ species₂ && shellᵢ ≡ shellⱼ + off_blockᵢ = filter_upper_idxs(off_blockᵢ) + end + + off_site_states = _get_states( + off_blockᵢ, atoms, envelope(basis_off)) + + if length(off_site_states) > 0 + let values = predict(basis_off, off_site_states) + set_sub_blocks!( + matrix, values, off_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + + _reflect_block_idxs!(off_blockᵢ) + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( + matrix, values, off_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + + + if species₁ ≡ species₂ && model.label ≠ "S" + basis_on = model.on_site_submodels[(species₁, shellᵢ, shellⱼ)] + on_site_states = _get_states(on_blockᵢ, atoms; r=radial(basis_on).R.ru) + + + if length(on_site_states) > 0 + let values = predict(basis_on, on_site_states) + set_sub_blocks!( + matrix, values, on_blockᵢ, shellᵢ, shellⱼ, atoms, basis_def) + + values = permutedims(values, (2, 1, 3)) + set_sub_blocks!( + matrix, values, on_blockᵢ, shellⱼ, shellᵢ, atoms, basis_def) + end + end + end + + end + end + + return matrix +end + + +# ╭───────────────────────────╮ +# │ Internal Helper Functions │ +# ╰───────────────────────────╯ + +""" +Construct the mirror index map required by `_reflect_block_idxs!`. +""" +function _mirror_idxs(cell_indices) + mirror_idxs = Vector{Int}(undef, size(cell_indices, 2)) + let cell_to_index = Dict(cell=>idx for (idx, cell) in enumerate(eachcol(cell_indices))) + for i=1:length(mirror_idxs) + mirror_idxs[i] = cell_to_index[cell_indices[:, i] * -1] + end + end + return mirror_idxs +end + + +""" +This function takes in a `BlkIdx` entity as an argument & swaps the atomic block indices; +i.e. [1, 2] → [2, 1]. +""" +function _reflect_block_idxs!(block_idxs::BlkIdx) + @inbounds for i=1:size(block_idxs, 2) + block_idxs[1, i], block_idxs[2, i] = block_idxs[2, i], block_idxs[1, i] + end + nothing +end + +""" +Inverts a `BlkIdx` instance by swapping the atomic-indices and substitutes the cell index +for its reflected counterpart; i.e. [i, j, k] → [j, i, idx_mirror[k]]. +""" +function _reflect_block_idxs!(block_idxs::BlkIdx, idx_mirror::AbstractVector) + @inbounds for i=1:size(block_idxs, 2) + block_idxs[1, i], block_idxs[2, i] = block_idxs[2, i], block_idxs[1, i] + block_idxs[3, i] = idx_mirror[block_idxs[3, i]] + end + nothing +end + + +function _maximum_distance_estimation(model::Model) + # Maximum radial distance (on-site) + max₁ = maximum(values(model.on_site_parameters.e_cut_out)) + # Maximum radial distance (off-site) + max₂ = maximum(values(model.off_site_parameters.e_cut_out)) + # Maximum effective envelope distance + max₃ = maximum( + [sqrt((env.r0cut + env.zcut)^2 + (env.rcut/2)^2) + for env in envelope.(values(model.off_site_submodels))]) + + return max(max₁, max₂, max₃) + +end + +""" +The construction of neighbour lists can be computationally intensive. As such lists are +used frequently by the lower levels of the code, they are cached & only every recomputed +recomputed if the requested cutoff distance exceeds that used when building the cached +version. It has been found that because each basis can have a different cutoff distance +it is possible that, due to the effects of evaluation order, that the neighbour list can +end up being reconstructed many times. This can be mitigated by anticipating what the +largest cutoff distance is likely to be and pre-building the neighbour list ahead of time. +Hence this function. +""" +function _preinitialise_neighbour_list(atoms::Atoms, model::Model) + # Get a very rough approximation for what the largest cutoff distance might be when + # constructing the neighbour list. + r = _maximum_distance_estimation(model) * 1.1 + + # Construct construct and cache the maximum likely neighbour list + neighbourlist(atoms, r; fixcell=false); + + nothing +end + + +end diff --git a/examples/H2O/python_interface/small/src/properties.jl b/examples/H2O/python_interface/small/src/properties.jl new file mode 100644 index 0000000..e16670a --- /dev/null +++ b/examples/H2O/python_interface/small/src/properties.jl @@ -0,0 +1,216 @@ + + +# Warning this module is not release ready +module Properties + +using ACEhamiltonians, LinearAlgebra + +export real_to_complex!, real_to_complex, band_structure, density_of_states, eigenvalue_confidence_interval + + +const _π2im = -2.0π * im + +function eigvals_at_k(H::A, S::A, T, k_point; kws...) where A<:AbstractArray{<:AbstractFloat, 3} + return real(eigvals(real_to_complex(H, T, k_point), real_to_complex(S, T, k_point); kws...)) +end + + +phase(k::AbstractVector, T::AbstractVector) = exp(_π2im * (k ⋅ T)) +phase(k::AbstractVector, T::AbstractMatrix) = exp.(_π2im * (k' * T)) + +function real_to_complex!(A_complex::AbstractMatrix{C}, A_real::AbstractArray{F, 3}, T::AbstractMatrix, k_point; sym=false) where {C<:Complex, F<:AbstractFloat} + for i=1:size(T, 2) + @views A_complex .+= A_real[:, :, i] .* phase(k_point, T[:, i]) + end + if sym + A_complex .= (A_complex + A_complex') * 0.5 + end + nothing +end + + +""" + real_to_complex(A_real, T, k_point[; sym=false]) + +Compute the complex matrix at a given k-point for a given real-space matrix. + +# Arguments + - `A_real`: real-space matrix of size N×N×T, where N is the number of atomic + orbitals and T the number of cell translation vectors. + - `T`: cell translation vector matrix of size 3×T. + - `k_point`: the k-points for which the complex matrix should be returned. + - `sym`: if true the resulting matrix will be symmetrised prior to its return. + +# Returns + - `A_complex`: the real space matrix evaluated at the requested k-point. + +""" +function real_to_complex(A_real::AbstractArray{F, 3}, T, k_point::Vector; sym=false) where F<:AbstractFloat + A_complex = zeros(Complex{F}, size(A_real, 2), size(A_real, 2)) + real_to_complex!(A_complex, A_real, T, k_point; sym=sym) + return A_complex +end + + +function eigenvalue_confidence_interval(H, H̃, S, S̃, T, k_points, posterior=false) + n = size(H, 1) + C = complex(valtype(H)) + + H_k = Matrix{C}(undef, n, n) + S_k = Matrix{C}(undef, n, n) + H̃_k = Matrix{C}(undef, n, n) + S̃_k = Matrix{C}(undef, n, n) + ΔH = Matrix{C}(undef, n, n) + ΔS = Matrix{C}(undef, n, n) + + results = Matrix{valtype(H)}(undef, n, size(k_points, 2)) + + for (i, k_point) in enumerate(eachcol(k_points)) + fill!(H_k, zero(C)) + fill!(S_k, zero(C)) + fill!(H̃_k, zero(C)) + fill!(S̃_k, zero(C)) + + real_to_complex!(H_k, H, T, k_point) + real_to_complex!(S_k, S, T, k_point) + real_to_complex!(H̃_k, H̃, T, k_point) + real_to_complex!(S̃_k, S̃, T, k_point) + + ΔH[:, :] = H̃_k - H_k + ΔS[:, :] = S̃_k - S_k + + ϵ, φ = eigen!(H_k, S_k); + + ϵₜ = let + if !posterior + ϵ + else + eigen!(H̃_k, S̃_k).values() + end + end + + for (j, (ϵᵢ, φᵢ)) in enumerate(zip(ϵₜ, eachcol(φ))) + + results[j, i] = real.(φᵢ' * ((ΔH - ϵᵢ * ΔS) * φᵢ)) + end + + end + + return results +end + + +function gaussian_broadening(E, ϵ, σ) + return exp(-((E - ϵ) / σ)^2) / (sqrt(π) * σ) +end + +# function gaussian_broadening(E, dE, ϵ, σ) +# # Broadens in an identical manner to FHI-aims; not that this wil require +# # SpecialFunctions.erf to work. While the results returned by this method +# # match with FHI-aims the final DoS is off by a factor of 0.5 for some +# # reason; so double counting is happening somewhere. +# ga = erf((E - ϵ + (dE/2)) / (sqrt(2.0)σ)) +# gb = erf((E - ϵ - (dE/2)) / (sqrt(2.0)σ)) +# return (ga - gb) / 2dE +# end + + +""" +Density of states (k-point independent) +""" +function density_of_states(E::T, ϵ::T, σ) where T<:Vector{<:AbstractFloat} + dos = T(undef, length(E)) + for i=1:length(E) + dos[i] = sum(gaussian_broadening.(E[i], ϵ, σ)) + end + return dos +end + + + +""" +Density of states (k-point dependant) +""" +function density_of_states(E::V, ϵ::Matrix{F}, k_weights::V, σ; fermi=0.0) where V<:Vector{F} where F<:AbstractFloat + # A non-zero fermi value indicates that the energies `E` are relative to the fermi + # level. The most efficient and less user intrusive way to deal with this is create + # and operate on an offset copy of `E`. + if fermi ≠ 0.0 + E = E .+ fermi + end + + dos = zeros(F, length(E)) + let temp_array = zeros(F, size(ϵ)...) + for i=1:length(E) + temp_array .= gaussian_broadening.(E[i], ϵ, σ) + temp_array .*= k_weights' + dos[i] = sum(temp_array) + end + end + + return dos +end + + +function density_of_states(E::Vector, H::M, S::M, σ) where {M<:AbstractMatrix} + return density_of_states(E, eigvals(H, S), σ) +end + +function density_of_states(E::V, H::A, S::A, k_point::V, T, σ) where {V<:Vector{<:AbstractFloat}, A<:AbstractArray{<:AbstractFloat,3}} + return density_of_states(E, eigvals_at_k(H, S, T, k_point), σ) +end + +function density_of_states(E::Vector, H::A, S::A, k_points::AbstractMatrix, T, k_weights, σ; fermi=0.0) where {A<:AbstractArray{F, 3}} where F<:AbstractFloat + return density_of_states(E, band_structure(H, S, T, k_points), k_weights, σ; fermi) +end + + + + +""" + band_structure(H_real, S_real, T, k_points) + +# Arguments + - `H_real`: real space Hamiltonian matrix of size N×N×T, where N is the number of atomic + orbitals and T the number of cell translation vectors. + - `S_real`: real space overlap matrix of size N×N×T. + - `T`: cell translation vector matrix of size 3×T. + - `k_points`: a matrix specifying the k-points at which the eigenvalues should be evaluated. + +# Returns + - `eigenvalues`: eigenvalues evaluated along the specified k-point path. The columns of + this matrix loop over k-points and rows over states. + +""" +function band_structure(H_real::A, S_real::A, T, k_points) where A<:AbstractArray{F, 3} where F<:AbstractFloat + + C = Complex{F} + nₒ, nₖ = size(H_real, 2), size(k_points, 2) + + # Final results array. + ϵ = Matrix{F}(undef, nₒ, nₖ) + + # Construct the transient storage arrays + let H_complex = Matrix{C}(undef, nₒ, nₒ), S_complex = Matrix{C}(undef, nₒ, nₒ) + + # Loop over each k-point + for i=1:nₖ + # Clear the transient storage arrays + fill!(H_complex, zero(C)) + fill!(S_complex, zero(C)) + + # Evaluate the Hamiltonian and overlap matrices at the iᵗʰ k-point + real_to_complex!(H_complex, H_real, T, k_points[:, i]) + real_to_complex!(S_complex, S_real, T, k_points[:, i]) + + # Calculate the eigenvalues + ϵ[:, i] .= real(eigvals(H_complex, S_complex)) + end + + end + + return ϵ +end + + +end \ No newline at end of file diff --git a/examples/H2O/python_interface/small/src/states.jl b/examples/H2O/python_interface/small/src/states.jl new file mode 100644 index 0000000..e42ccd0 --- /dev/null +++ b/examples/H2O/python_interface/small/src/states.jl @@ -0,0 +1,498 @@ +module States +using ACEhamiltonians, NeighbourLists, JuLIP +using ACEhamiltonians.MatrixManipulation: BlkIdx +using StaticArrays: SVector +using LinearAlgebra: norm, normalize +using ACE: AbstractState, CylindricalBondEnvelope, BondEnvelope, _evaluate_bond, _evaluate_env +using ACEhamiltonians: BOND_ORIGIN_AT_MIDPOINT + +import ACEhamiltonians.Parameters: ison +import ACE: _inner_evaluate + +export BondState, AtomState, reflect, get_state + +# ╔════════╗ +# ║ States ║ +# ╚════════╝ +""" + BondState(mu, mu_i, mu_j, rr, rr0, bond) + +State entities used when representing the environment about a bond. + +# Fields +- `mu`: AtomicNumber of the current atom +- `mu_i`: AtomicNumber of the first bonding atom +- `mu_j`: AtomicNumber of the second bonding atom +- `rr`: environmental atom's position relative to the midpoint of the bond. +- `rr0`: vector between the two "bonding" atoms, i.e. the bond vector. +- `bond`: a boolean, which if true indicates the associated `BondState` entity represents + the bond itself. If false, then the state is taken to represent an environmental atom + about the bond rather than the bond itself. + +# Notes +If `bond == true` then `rr` should be set to `rr0/2`. If an environmental atom lies too +close to bond's midpoint then ACE may crash. Thus a small offset may be required in some +cases. + +# Developers Notes +An additional field will be added at a later data to facilitate multi-species support. It +is possible that the `BondState` structure will have to be split into two sub-structures. + +# Todo + - Documentation should be updated to account for the fact that the bond origin has been + moved back to the first atoms position. + - Call for better names for the first three fields +""" +const AN_default = AtomicNumber(:X) + +struct BondState{T<:SVector{3, <:AbstractFloat}, B<:Bool} <: AbstractState + mu::AtomicNumber + mu_i::AtomicNumber + mu_j::AtomicNumber + rr::T + rr0::T + bond::B +end + +BondState{T, Bool}(rr,rr0,bond::Bool) where T<:SVector{3, <:AbstractFloat} = BondState(AN_default,AN_default,AN_default,T(rr),T(rr0),bond) +BondState(rr,rr0,bond::Bool) = BondState{SVector{3, Float64}, Bool}(rr,rr0,bond) +""" + AtomState(mu,mu_i,rr) + +State entity representing the environment about an atom. + +# Fields +- `mu`: AtomicNumber of the current atom +- `mu_i`: AtomicNumber of the centre atom +- `rr`: environmental atom's position relative to the host atom. + +""" +struct AtomState{T<:SVector{3, <:AbstractFloat}} <: AbstractState + mu::AtomicNumber + mu_i::AtomicNumber + rr::T +end + +AtomState{T}(rr) where T<:SVector{3, <:AbstractFloat} = AtomState(AN_default,AN_default,T(rr)) +AtomState(rr) = AtomState{SVector{3, Float64}}(rr) + +# ╭────────┬───────────────────────╮ +# │ States │ General Functionality │ +# ╰────────┴───────────────────────╯ + +# Display methods to help alleviate endless terminal spam. +function Base.show(io::IO, state::BondState) + mu = state.mu + mu_i = state.mu_i + mu_j = state.mu_j + rr = string([round.(state.rr, digits=5)...]) + rr0 = string([round.(state.rr0, digits=5)...]) + print(io, "BondState(mu:$mu, mu_i:$mu_i, mu_j:$mu_j, rr:$rr, rr0:$rr0, bond:$(state.bond))") +end + +function Base.show(io::IO, state::AtomState) + mu = state.mu + mu_i = state.mu_i + rr = string([round.(state.rr, digits=5)...]) + print(io, "AtomState(mu:$mu, mu_i:$mu_i, rr:$rr)") +end + +# Allow for equality checks (will otherwise default to equivalency) +Base.:(==)(x::T, y::T) where T<:BondState = x.mu == y.mu && x.mu_i == y.mu_i && x.mu_j == y.mu_j && x.rr == y.rr && y.rr0 == y.rr0 && x.bond == y.bond +Base.:(==)(x::T, y::T) where T<:AtomState = x.mu == y.mu && x.mu_i == y.mu_i && x.rr == y.rr + +# The ≈ operator is commonly of more use for State entities than the equality +Base.isapprox(x::T, y::T; kwargs...) where T<:BondState = x.mu == y.mu && x.mu_i == y.mu_i && x.mu_j == y.mu_j && isapprox(x.rr, y.rr; kwargs...) && isapprox(x.rr0, y.rr0; kwargs...) && x.bond == y.bond +Base.isapprox(x::T, y::T; kwargs...) where T<:AtomState = x.mu == y.mu && x.mu_i == y.mu_i && isapprox(x.rr, y.rr; kwargs...) +Base.isapprox(x::T, y::T; kwargs...) where T<:AbstractVector{<:BondState} = all(x .≈ y) +Base.isapprox(x::T, y::T; kwargs...) where T<:AbstractVector{<:AtomState} = all(x .≈ y) + +# ACE requires the `zero` method to be defined for states. +Base.zero(::Type{BondState{T, S}}) where {T, S} = BondState{T, S}(zero(T), zero(T), true) +Base.zero(::Type{AtomState{T}}) where T = AtomState{T}(zero(T)) +Base.zero(::B) where B<:BondState = zero(B) +Base.zero(::B) where B<:AtomState = zero(B) +# Todo: +# - Identify why `zero(BondState)` is being called and if the location where it is used +# is effected by always choosing `bond` to be false. [PRIORITY:LOW] + +""" + ison(state) + +Return a boolean indicating whether the state entity is associated with either an on-site +or off-site interaction. +""" +ison(::T) where T<:AtomState = true +ison(::T) where T<:BondState = false + + +""" + reflect(state) + +Reflect `BondState` across the bond's midpoint. Calling on a state representing the bond A→B +will return the symmetrically B→A state. For states where `bond=true` this will flip the +sign on `rr` & `rr0`; whereas only `rr0` is flipped for the `bond=false` case. + +# Arguments +- `state::BondState`: the state to be reflected. + +# Returns +- `reflected_state::BondState`: a view of `state` reflected across the midpoint. + +# Warnings +This is only valid for bond states whose atomic positions are given relative to the midpoint +of the bond; i.e. `envelope.λ≡0`. +""" +function reflect(state::T) where T<:BondState + @static if BOND_ORIGIN_AT_MIDPOINT + if state.bond + return T(state.mu, state.mu_j, state.mu_i, -state.rr, -state.rr0, true) + else + return T(state.mu, state.mu_j, state.mu_i, state.rr, -state.rr0, false) + end + else + if state.bond + return T(state.mu, state.mu_j, state.mu_i, -state.rr, -state.rr0, true) + else + return T(state.mu, state.mu_j, state.mu_i, state.rr - state.rr0, -state.rr0, false) + end + end +end + +# `reflect` is just an identify function for `AtomState` instances. This is included to +# alleviate the need for branching elsewhere in the code. +reflect(state::AtomState) = state + +# ╭───────┬───────────╮ +# │ State │ Factories │ +# ╰───────┴───────────╯ +""" + get_state(i, atoms[; r=16.0]) + +Construct a state representing the environment about atom `i`. + +# Arguments +- `i::Integer`: index of the atom whose state is to be constructed. +- `atoms::Atoms`: the `Atoms` object in which atom `i` resides. +- `r::AbstractFloat`: maximum distance up to which neighbouring atoms + should be considered. + +# Returns +- `state::Vector{::AtomState}`: state objects representing the environment about atom `i`. + +""" +function get_state(i::Integer, atoms::Atoms; r::AbstractFloat=16.0) + # Construct the neighbour list (this is cached so speed is not an issue) + pair_list = JuLIP.neighbourlist(atoms, r; fixcell=false) + + # Extract environment about each relevant atom from the pair list. These will be tuples + # of the form: (atomic-index, relative-position) + idxs, vecs, species = JuLIP.Potentials.neigsz(pair_list, atoms, i) + + # Construct the `AtomState`` vector + st = [ AtomState(species[j], atoms.Z[i], vecs[j]) for j = 1:length(species) ] + + # Return an AtomState vector without those outside the cutoff sphere. + return filter(k -> norm(k.rr) <= r, st) +end + + +""" + get_state(i, j, atoms, envelope[, image]) + +Construct a state representing the environment about the "bond" between atoms `i` & `j`. + +# Arguments +- `i::Int`: atomic index of the first bonding atom. +- `j::Int`: atomic index of the second bonding atom. +- `atoms::Atoms`: the `Atoms` object in which atoms `i` and `j` reside. +- `envelope::CylindricalBondEnvelope:` an envelope specifying the volume to consider when + constructing the state. This must be centred at the bond's midpoint; i.e. `envelope.λ≡0`. +- `image::Optional{Vector}`: a vector specifying the image in which atom `j` + should reside; i.e. the cell translation vector. This defaults to `nothing` which will + result in the closets periodic image of `j` being used. +- `r::AbstractFloat`: this can be used to manually specify the cutoff distance used when + building the neighbour list. This will override the locally computed value for `r` and + is primarily used to aid debugging. + +# Returns +- `state::Vector{::BondState}`: state objects representing the environment about the bond + between atoms `i` and `j`. + +# Notes +It is worth noting that a state will be constructed for the ij bond even when the distance +between them exceeds the bond-cutoff specified by the `envelope`. The maximum cutoff +distance for neighbour list construction is handled automatically. + +# Warnings +- The neighbour list for the bond is constructed by applying an offset to the first atom's + neighbour list. As such spurious states will be encountered when the ij distance exceeds + the bond cutoff value `envelope.r0cut`. Do not ignore this warning! +- It is vital to ensure that when an `image` is supplied that all atomic coordinates are + correctly wrapped into the unit cell. If fractional coordinates lie outside of the range + [0, 1] then the results of this function will not be correct. + +""" +function get_state( + i::I, j::I, atoms::Atoms, envelope::CylindricalBondEnvelope, + image::Union{AbstractVector{I}, Nothing}=nothing; r::Union{Nothing, <:AbstractFloat}=nothing) where {I<:Integer} + + # Todo: + # - Combine the neighbour lists of atom i and j rather than just the former. This + # will reduce the probably of spurious state construction. But will increase run + # time as culling of duplicate states and bond states will be required. + # - rr for the bond really should be halved and inverted to match up with the + # environmental coordinate system. + + # Neighbour list cutoff distance; accounting for distances being relative to atom `i` + # rather than the bond's mid-point + if isnothing(r) + r = sqrt((envelope.r0cut + envelope.zcut)^2 + envelope.rcut^2) + end + + # Neighbours list construction (about atom `i`) + idxs, vecs, cells, species = _neighbours(i, atoms, r) + + # Get the bond vector between atoms i & j; where i is in the origin cell & j resides + # in either i) closest periodic image, or ii) that specified by `image` if provided. + if isnothing(image) + # Identify the shortest i→j vector account for PBC. + idx = _locate_minimum_image(j, idxs, vecs) + rr0 = vecs[idx] + else + @assert length(image) == 3 "image must be a vector of length three" + # Find the vector between atom i in the origin cell and atom j in cell `image`. + idx = _locate_target_image(j, idxs, cells, image) + if idx != 0 + rr0 = vecs[idx] + else # Special case where the cutoff was too short to catch the desired i→j bond. + # In this case we must calculate rr0 manually. + rr0 = atoms.X[j] - atoms.X[i] + (adjoint(image .* atoms.pbc) * atoms.cell).parent + end + end + + # The i→j bond vector must be removed from `vecs` so that it does not get treated as + # an environmental atom in the for loop later on. This operation is done even if the + # `idx==0` to maintain type stability. + @views vecs_no_bond = vecs[1:end .!= idx] + @views species_no_bond = species[1:end .!= idx] + + # `BondState` entity vector + states = Vector{BondState{typeof(rr0), Bool}}(undef, length(vecs_no_bond) + 1) + + # Construct the bond vector state; i.e where `bond=true` + states[1] = BondState(atoms.Z[j],atoms.Z[i],atoms.Z[j],rr0, rr0, true) + + @static if BOND_ORIGIN_AT_MIDPOINT + # As the mid-point of the bond is used as the origin an offset is needed to shift + # vectors so they're relative to the bond's midpoint and not atom `i`. + offset = rr0 * 0.5 + end + + # Construct the environmental atom states; i.e. where `bond=false`. + for (k, v⃗) in enumerate(vecs_no_bond) + @static if BOND_ORIGIN_AT_MIDPOINT + # Offset the positions so that they are relative to the bond's midpoint. + states[k+1] = BondState{typeof(rr0), Bool}(species_no_bond[k], atoms.Z[i], atoms.Z[j], v⃗ - offset, rr0, false) + else + states[k+1] = BondState{typeof(rr0), Bool}(species_no_bond[k], atoms.Z[i], atoms.Z[j], v⃗, rr0, false) + end + + end + + # Cull states outside of the bond envelope using the envelope's filter operator. This + # task is performed manually here in an effort to reduce run time and memory usage. + @views mask = _inner_evaluate.(Ref(envelope), states[2:end]) .!= 0.0 + @views n = sum(mask) + 1 + @views states[2:n] = states[2:end][mask] + + return states[1:n] + +end + +# Commonly one will need to collect multiple states rather than single states on their +# own. Hence the `get_state[s]` functions. These functions have been tagged for internal +# use only until they can be polished up. + +""" + _get_states(block_idxs, atoms, envelope[, images]) + +Get the states describing the environments about a collection of bonds as defined by the +block index list `block_idxs`. This is effectively just a fancy wrapper for `get_state'. + +# Arguments +- `block_idxs`: atomic index matrix in which the first & second rows specify the indices of the + two "bonding" atoms. The third row, if present, is used to index `images` to collect + cell in which the second atom lies. +- `atoms`: the `Atoms` object in which that atom pair resides. +- `envelope`: an envelope specifying the volume to consider when constructing the states. +- `images`: Cell translation index lookup list, this is only relevant when `block_idxs` + supplies and cell index value. The cell translation index for the iᵗʰ state will be + taken to be `images[block_indxs[i, 3]]`. + +# Returns +- `bond_states::Vector{::Vector{::BondState}}`: a vector providing the requested bond states. + +# Developers Notes +This is currently set to private until it is cleaned up. + +""" +function _get_states(block_idxs::BlkIdx, atoms::Atoms{T}, envelope::CylindricalBondEnvelope, + images::Union{AbstractMatrix{I}, Nothing}=nothing) where {I, T} + if isnothing(images) + if size(block_idxs, 1) == 3 && any block_idxs[3, :] != 1 + throw(ArgumentError("`idxs` provides non-origin cell indices but no + `images` argument was given!")) + end + return get_state.(block_idxs[1, :], block_idxs[2, :], Ref(atoms), Ref(envelope))::Vector{Vector{BondState{SVector{3, T}, Bool}}} + else + # If size(block_idxs,1) == 2, i.e. no cell index is supplied then this will error out. + # Thus not manual error handling is required. If images are supplied then block_idxs + # must contain the image index. + return get_state.( + block_idxs[1, :], block_idxs[2, :], Ref(atoms), + Ref(envelope), eachcol(images[:, block_idxs[3, :]]))::Vector{Vector{BondState{SVector{3, T}, Bool}}} + end +end + + +""" + _get_states(block_idxs, atoms[; r=16.0]) + +Get states describing the environments around each atom block specified in `block_idxs`. +Note that `block_idxs` is assumed to contain only on-site blocks. This is just a wrapper +for `get_state'. + +# Developers Notes +This is currently set to private until it is cleaned up. + +""" +function _get_states(block_idxs::BlkIdx, atoms::Atoms{T}; r=16.0) where T + if @views block_idxs[1, :] != block_idxs[2, :] + throw(ArgumentError( + "The supplied `block_idxs` represent a hetroatomic interaction. But the function + called is for retrieving homoatomic states.")) + end + # Type ambiguities in the JuLIP.Atoms structure means that Julia cannot determine the + # function's return type; specifically the value type of the static vector. Thus some + # pseudo type hard coding must be done here. + return get_state.(block_idxs[1, :], (atoms,); r=r)::Vector{Vector{AtomState{SVector{3, T}}}} +end + + + +# ╭────────┬──────────────────────────╮ +# │ States │ Factory Helper Functions │ +# ╰────────┴──────────────────────────╯ + + +""" + _neighbours(i, atoms, r) + +Identify and return information about atoms neighbouring atom `i` in system `atoms`. + +# Arguments +- `i::Int`: index of the atom for which the neighbour list is to be constructed. +- `atoms::Atoms`: system in which atom `i` is located. +- `r::AbstractFloat`: cutoff distance to for the neighbour list. Due to the effects off + cacheing this should be treated as if it were a lower bounds for the cutoff rather than + the cutoff itself. + +# Returns +- `idxs`: atomic index of each neighbour. +- `vecs`: distance vector to each neighbour. +- `cells`: index specifying the cell in which the neighbouring atom resides. + +# Warnings +Due to the effects of caching there is a high probably that the returned neighbour list +will contain neighbours at distances greater than `r`. + +""" +function _neighbours(i::Integer, atoms::Atoms, r::AbstractFloat) + pair_list = JuLIP.neighbourlist(atoms, r; fixcell=false) + return [ NeighbourLists.neigss(pair_list, i)..., JuLIP.Potentials.neigsz(pair_list,atoms,i)[3] ] +end + + +""" + _locate_minimum_image(j, idxs, vecs) + +Index of the closest `j` neighbour accounting for periodic boundary conditions. + +# Arguments +- `j::Integer`: Index of the atom for for whom the minimum image is to be identified. +- `idxs::Vector{::Integer}`: Integers specifying the indices of the atoms to two which + the distances in `vecs` correspond. +- `vecs::Vector{SVector{3, ::AbstractFloat}}`: Vectors between the the source atom and + the target atom. + +# Returns +- `index::Integer`: an index `k` for which `vecs[k]` will yield the vector between the + source atom and the closest periodic image of atom `j`. + +# Notes +If multiple minimal vectors are found, then the first one will be returned. + +# Todo +- This will error out when the cutoff distance is lower than the bond distance. While such + an occurrence is unlikely in smaller cells it will no doubt occur in larger ones. + +""" +function _locate_minimum_image(j::Integer, idxs::AbstractVector{<:Integer}, vecs::AbstractVector{<:AbstractVector{<:AbstractFloat}}) + # Locate all entries in the neighbour list that correspond to atom `j` + js = findall(==(j), idxs) + if length(js) == 0 + # See the "Todo" section in the docstring. + error("Neighbour not in range") + end + + # Identify which image of atom `j` is closest + return js[findmin(norm, vecs[js])[2]] +end + +""" + _locate_target_image(j, idxs, images, image) + +Search through the neighbour list for atoms with the atomic index `j` that reside in +the specified `image` and return its index. If no such match is found, as can happen +when the cutoff distance is too short, then an index of zero is returned. + +# Arguments +- `j`: index of the desired neighbour. +- `idxs`: atomic indices of atoms in the neighbour list. +- `images`: image in which the neighbour list atoms reside. +- `image`: image in which the target neighbour should reside. + +# Returns +- `idx::Int`: index of the first entry in the neighbour list representing an atom with the + atom index `j` residing in the image `image`. Zero if no matches are found. + +# Notes +The `images` argument is set vector of vectors here as this is represents the type returned +by the neighbour list constructor. Blocking other types prevents any misunderstandings. + +# Todo: +- Test for type instabilities + +""" +function _locate_target_image(j::I, idxs::AbstractVector{I}, images::AbstractVector{<:AbstractVector{I}}, image::AbstractVector{I})::I where I<:Integer + js = findall(==(j), idxs) + idx = findfirst(i -> all(i .== image), images[js]) + return isnothing(idx) ? zero(I) : js[idx] +end + + +# ╭────────┬───────────╮ +# │ States │ Overrides │ +# ╰────────┴───────────╯ +# Local override to account for the `BondState` field `be::Symbol` being replaced with the +# field `bond::Bool`. +function _inner_evaluate(env::BondEnvelope, state::BondState) + if state.bond + return _evaluate_bond(env, state) + else + return _evaluate_env(env, state) + end +end + +end diff --git a/examples/H2O/test_300K_4_500K.jl b/examples/H2O/test_300K_4_500K.jl new file mode 100644 index 0000000..64455c3 --- /dev/null +++ b/examples/H2O/test_300K_4_500K.jl @@ -0,0 +1,66 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/H_H2O_2_300K_4_500K" +model_path = "./Result/H_H2O_2_300K_rcut_10/dyn-wd-300K_3.bin" +nsamples = 512 #5000 +mkpath(output_path) + +@info "Constructing Model" + +model = deserialize(model_path) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] # train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] # train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/test_500K_4_300K.jl b/examples/H2O/test_500K_4_300K.jl new file mode 100644 index 0000000..fb14c4f --- /dev/null +++ b/examples/H2O/test_500K_4_300K.jl @@ -0,0 +1,66 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/H_H2O_2_500K_4_300K" +model_path = "./Result/H_H2O_2_500K_rcut_10/dyn-wd-500K_3.bin" +nsamples = 512 #5000 +mkpath(output_path) + +@info "Constructing Model" + +model = deserialize(model_path) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] # train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] # train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_1_rcut_10.jl b/examples/H2O/train_H_H2O_1_rcut_10.jl new file mode 100644 index 0000000..f9913ae --- /dev/null +++ b/examples/H2O/train_H_H2O_1_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/H2O_H_aims.h5" +output_path = "./Result/H_H2O_1_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_1_rcut_6.jl b/examples/H2O/train_H_H2O_1_rcut_6.jl new file mode 100644 index 0000000..4439302 --- /dev/null +++ b/examples/H2O/train_H_H2O_1_rcut_6.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/H2O_H_aims.h5" +output_path = "./Result/H_H2O_1_rcut_6" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(6.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(6.), #(10.), + # Environmental cutoff radius + GlobalParams(3.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_2_300K_rcut_10.jl b/examples/H2O/train_H_H2O_2_300K_rcut_10.jl new file mode 100644 index 0000000..206f47b --- /dev/null +++ b/examples/H2O/train_H_H2O_2_300K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/H_H2O_2_300K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_2_300K_rcut_6.jl b/examples/H2O/train_H_H2O_2_300K_rcut_6.jl new file mode 100644 index 0000000..975f24c --- /dev/null +++ b/examples/H2O/train_H_H2O_2_300K_rcut_6.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/H_H2O_2_300K_rcut_6" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(6.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(6.), #(10.), + # Environmental cutoff radius + GlobalParams(3.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_2_500K_rcut_10.jl b/examples/H2O/train_H_H2O_2_500K_rcut_10.jl new file mode 100644 index 0000000..486c578 --- /dev/null +++ b/examples/H2O/train_H_H2O_2_500K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/H_H2O_2_500K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_H_H2O_2_500K_rcut_6.jl b/examples/H2O/train_H_H2O_2_500K_rcut_6.jl new file mode 100644 index 0000000..50ba15e --- /dev/null +++ b/examples/H2O/train_H_H2O_2_500K_rcut_6.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/H_H2O_2_500K_rcut_6" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "H" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(6.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(6.), #(10.), + # Environmental cutoff radius + GlobalParams(3.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_S_H2O_1_rcut_6.jl b/examples/H2O/train_S_H2O_1_rcut_6.jl new file mode 100644 index 0000000..6f85cf4 --- /dev/null +++ b/examples/H2O/train_S_H2O_1_rcut_6.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/H2O_H_aims.h5" +output_path = "./Result/S_H2O_1_rcut_6" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "S" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(1), #(2) + # Maximum polynomial degree + GlobalParams(12), #(6), + # Environmental cutoff radius + GlobalParams(6.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(12), #(6), + # Bond cutoff radius + GlobalParams(6.), #(10.), + # Environmental cutoff radius + GlobalParams(3.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_S_H2O_2_300K_rcut_10.jl b/examples/H2O/train_S_H2O_2_300K_rcut_10.jl new file mode 100644 index 0000000..4f64804 --- /dev/null +++ b/examples/H2O/train_S_H2O_2_300K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/S_H2O_2_300K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "S" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(1), #(2), + # Maximum polynomial degree + GlobalParams(12), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(12), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_S_H2O_2_500K_rcut_10.jl b/examples/H2O/train_S_H2O_2_500K_rcut_10.jl new file mode 100644 index 0000000..17b1c3e --- /dev/null +++ b/examples/H2O/train_S_H2O_2_500K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/S_H2O_2_500K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "S" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(1), #(2), + # Maximum polynomial degree + GlobalParams(12), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(12), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_300K_rcut_10.jl b/examples/H2O/train_dm_H2O_2_300K_rcut_10.jl new file mode 100644 index 0000000..cb1dc42 --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_300K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/dm_H2O_2_300K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_512.jl b/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_512.jl new file mode 100644 index 0000000..55de234 --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_512.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/dm_H2O_2_300K_rcut_14_n_512" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(14.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(14.), #(10.), + # Environmental cutoff radius + GlobalParams(7.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_768.jl b/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_768.jl new file mode 100644 index 0000000..b5dd0e4 --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_300K_rcut_14_n_768.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-300K_3.h5" +output_path = "./Result/dm_H2O_2_300K_rcut_14_n_768" +nsamples = 768 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(14.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(14.), #(10.), + # Environmental cutoff radius + GlobalParams(7.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_500K_rcut_10.jl b/examples/H2O/train_dm_H2O_2_500K_rcut_10.jl new file mode 100644 index 0000000..c626a44 --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_500K_rcut_10.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/dm_H2O_2_500K_rcut_10" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(10.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(10.), #(10.), + # Environmental cutoff radius + GlobalParams(5.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_512.jl b/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_512.jl new file mode 100644 index 0000000..358f929 --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_512.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/dm_H2O_2_500K_rcut_14_n_512" +nsamples = 512 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(14.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(14.), #(10.), + # Environmental cutoff radius + GlobalParams(7.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_768.jl b/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_768.jl new file mode 100644 index 0000000..ca4eecb --- /dev/null +++ b/examples/H2O/train_dm_H2O_2_500K_rcut_14_n_768.jl @@ -0,0 +1,139 @@ +using BenchmarkTools, Serialization, Random +using Distributed, SlurmClusterManager +addprocs(28) +@everywhere begin + using ACEhamiltonians, HDF5, Serialization + using Statistics + using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma + using JuLIP: Atoms + using PeriodicTable + using Statistics + include("./utils.jl") +end + +# ----------------------- +# |***general setting***| +# ----------------------- + +database_path = "./Data/Datasets/dyn-wd-500K_3.h5" +output_path = "./Result/dm_H2O_2_500K_rcut_14_n_768" +nsamples = 768 #5200 +mkpath(output_path) + +@info "Constructing Model" + +# Provide a label for the model (this should be H, S or dm) +model_type = "dm" + +# Define the basis definition +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) + +# On site parameter deceleration +on_site_parameters = OnSiteParaSet( + # Maximum correlation order + GlobalParams(2), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Environmental cutoff radius + GlobalParams(14.), #(10.), + # Scaling factor "r₀" + GlobalParams(0.9) #(2.5) +) + +# Off-site parameter deceleration +off_site_parameters = OffSiteParaSet( + # Maximum correlation order + GlobalParams(1), + # Maximum polynomial degree + GlobalParams(14), #(6), + # Bond cutoff radius + GlobalParams(14.), #(10.), + # Environmental cutoff radius + GlobalParams(7.), #(5.), +) + +# initialisation +model = Model(basis_definition, on_site_parameters, off_site_parameters, model_type) + +# ----------------------- +# |***fitting model***| +# ----------------------- +@info "Fitting Model" + +# Names of the systems to which the model should be fitted +target_systems = h5open(database_path) do database keys(database) end +rng = MersenneTwister(1234) +@assert nsamples <= length(target_systems) "nsample should be smaller or equal to nsample" +target_systems = shuffle(rng, target_systems)[begin:nsamples] +target_systems = [target_systems[i:5:end] for i in 1:5] +train_systems = vcat(target_systems[1:end-1]...) +test_systems = target_systems[end] + +# Open up the HDF5 database within which the target data is stored +h5open(database_path) do database + # Load the target system(s) for fitting + systems = [database[system] for system in train_systems] + + # Perform the fitting operation + fit!(model, systems; recentre=true) +end + +filename = split(basename(database_path), ".")[begin] +serialize(joinpath(output_path, filename*".bin"), model) + +# model = deserialize(joinpath(output_path, filename*".bin")) + +################################# test set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in test_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in test_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_test.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_test.jls"), "w") do file + serialize(file, data_dict) +end + + +################################# training set ######################################### +#prediction +atoms = h5open(database_path) do database + [load_atoms(database[system]) for system in train_systems] + end +images = cell_translations.(atoms, Ref(model)) +predicted = predict.(Ref(model), atoms, images) + +#groud truth data +get_matrix = Dict( # Select an appropriate function to load the target matrix +"H"=>load_hamiltonian, "S"=>load_overlap, "dm"=>load_density_matrix, +"Hg"=>load_hamiltonian_gamma, "Sg"=>load_overlap_gamma, "dmg"=>load_density_matrix_gamma)[model.label] +gt = h5open(database_path) do database + [get_matrix(database[system]) for system in train_systems] +end + +matrix_dict = Dict("predicted"=>predicted, "gt"=>gt, "atoms"=>atoms) +open(joinpath(output_path, "matrix_dict_train.jls"), "w") do file + serialize(file, matrix_dict) +end + +error = predicted-gt +data_dict = get_error_dict(error, atoms, model) +open(joinpath(output_path, "data_dict_train.jls"), "w") do file + serialize(file, data_dict) +end \ No newline at end of file diff --git a/tools/utils.jl b/examples/H2O/utils.jl similarity index 100% rename from tools/utils.jl rename to examples/H2O/utils.jl diff --git a/utils/intra_vs_inter.jl b/utils/intra_vs_inter.jl new file mode 100644 index 0000000..bfc9515 --- /dev/null +++ b/utils/intra_vs_inter.jl @@ -0,0 +1,150 @@ +using Serialization +using JuLIP +using LinearAlgebra: norm, pinv +using StatsBase +using Statistics +using Plots + +# model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/dyn-wd-500K_3.bin" + +# bond_cutoff = 1. +num_bond = 2 + +Hartree2meV=27211.4 + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) +n_1 = sum([2*i+1 for i in basis_definition[1]]) +n_8 = sum([2*i+1 for i in basis_definition[8]]) +basis_num = Dict(1=>n_1, 8=>n_8) +mol_basis_num = n_1*2 + n_8 + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_4_300K/matrix_dict.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + +predicts = data_dict["predicted"].*Hartree2meV +gts = data_dict["gt"].*Hartree2meV +atoms_list = data_dict["atoms"] + +# model = deserialize(model_path) +# images_list = cell_translations.(atoms_list, Ref(model)) +# pairs = [ for image in images for atom in atoms ] + +errors_list = [abs.(predict-gt) for (predict, gt) in zip(predicts, gts)] +error_intra = cat([cat(errors[1:mol_basis_num, 1:mol_basis_num], errors[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +error_inter = cat([cat(errors[1+mol_basis_num:end, 1:mol_basis_num], errors[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +mae_intra = mean(error_intra) +mae_inter = mean(error_inter) + +gts_intra = cat([cat(gt[1:mol_basis_num, 1:mol_basis_num], gt[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +gts_inter = cat([cat(gt[1+mol_basis_num:end, 1:mol_basis_num], gt[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +gt_intra = std(gts_intra) +gt_inter = std(gts_inter) + +mae_norm_intra = mae_intra/gt_intra +mae_norm_inter = mae_inter/gt_inter + +mae = mean(cat(errors_list..., dims=3)) +println("mae, mae_intra, mae_inter, mae_norm_intra, mae_norm_inter: $mae, $mae_intra, $mae_inter, $mae_norm_intra, $mae_norm_inter") + +mae_intra_plot = dropdims(mean(error_intra, dims=3), dims=3) +mae_inter_plot = dropdims(mean(error_inter, dims=3), dims=3) + +p=heatmap(mae_intra_plot, size=(800, 750), color=:jet, title="Intra molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae_intra_plot.png")) +display(p) + +p=heatmap(mae_inter_plot, size=(800, 750), color=:jet, title="Inter molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae_inter_plot.png")) +display(p) + + + +mae_norm_intra = dropdims(mean(error_intra, dims=3)./std(gts_intra, dims=3), dims=3) +mae_norm_inter = dropdims(mean(error_inter, dims=3)./std(gts_inter, dims=3), dims=3) + +p=heatmap(mae_norm_intra, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=13Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm_intra_plot.png")) +display(p) + +p=heatmap(mae_norm_inter, size=(800, 750), color=:jet, title="Normalized inter molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold), clims=(0, 1)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=10Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm_inter_plot.png")) +display(p) + + + + +# for (atoms, images, errors) in zip(atoms_list, images_list, errors_list) +# shift_vectors = collect(eachrow(images' * atoms.cell)) +# pos = hcat(collect.(atoms.X)...) +# dis = reshape(pos, size(pos, 1), size(pos, 2), 1) .- reshape(pos, size(pos, 1), 1, size(pos, 2)) +# dis = reshape(dis, (1, size(dis)...)) +# dis = dis .- reshape(hcat(collect.(shift_vectors)...)', length(shift_vectors), 3, 1, 1) +# dis = dropdims(mapslices(norm, dis; dims=2); dims=2) +# dis = permutedims(dis, (3,2,1)) +# for i_idx, (X₁, Z₁) in enumerate(atoms.X, atoms.Z) +# if Z₁ == 8 +# H_OO = extract_block(errors, i_idx, i_idx, findall([images[:,i]==[0, 0, 0] for i in range(1,size(images,2))])[1], basis_num, atoms) +# idx_line = partialsortperm(dis[i_idx, :, :][:], 1:3; rev=false)[2:end] +# j_idxes, image_idxes = rem.(idx_line, Ref(length(atoms.Z))), Int.(floor.(idx_line./Ref(length(atoms.Z)))) +# O_indices = i_idx +# H_indices = j_idxes +# i_idxes = [i_idx for i in 1: num_bond] +# for j_idx in j_idxes +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) +# push!(j_idx, i_idxes) +# push!(i_idx, j_idxes) +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) + + + + +# for (X₂, Z₂) in zip(atoms.X, atoms.Z) +# if Z₂==1 and + +# norm.( X₁ - X₂ + shift_vectors) + + + + + +# mask = norm.(atoms.X - atoms.X + shift_vectors[block_idxs[3, :]]) .<= distance + + + + + + + +# function extract_block(matrix::Array{Float64, 3}, i_idx::Int, j_idx::Int, image_idx::Int, basis_num::Dict, atoms::Atoms) +# idx_begin = vcat([1],cumsum([basis_num[i] for i in atoms.Z])[1:end-1].+1) +# idx_end = cumsum([basis_num[i] for i in atoms.Z]) +# return matrix[idx_begin[i_idx]: idx_end[i_idx], idx_begin[j_idx]: idx_end[j_idx], image_idx] +# end + + + + diff --git a/utils/monomer.jl b/utils/monomer.jl new file mode 100644 index 0000000..5e86cb3 --- /dev/null +++ b/utils/monomer.jl @@ -0,0 +1,180 @@ +using Serialization +using JuLIP +using LinearAlgebra: norm, pinv +using StatsBase +using Statistics +using Plots + +# model_path = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_2_500K_rcut_10/dyn-wd-500K_3.bin" + +# bond_cutoff = 1. +num_bond = 2 + +Hartree2meV=27211.4 + +basis_definition = Dict(1=>[0, 0, 0, 0, 1, 1, 2], 8=>[0, 0, 0, 1, 1, 2]) +n_1 = sum([2*i+1 for i in basis_definition[1]]) +n_8 = sum([2*i+1 for i in basis_definition[8]]) +basis_num = Dict(1=>n_1, 8=>n_8) +mol_basis_num = n_1*2 + n_8 + +data_file = "/home/c/chenqian3/ACEhamiltonians/H2O_PASHA/H2O_Pasha/Result/H_H2O_1_rcut_6/matrix_dict_test.jls" +data_dict = open(data_file, "r") do file + deserialize(file) +end + +predicts = data_dict["predicted"].*Hartree2meV +gts = data_dict["gt"].*Hartree2meV +atoms_list = data_dict["atoms"] + +# model = deserialize(model_path) +# images_list = cell_translations.(atoms_list, Ref(model)) +# pairs = [ for image in images for atom in atoms ] + +# errors_list = [abs.(predict-gt) for (predict, gt) in zip(predicts, gts)] +# error_intra = cat([cat(errors[1:mol_basis_num, 1:mol_basis_num], errors[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +# error_inter = cat([cat(errors[1+mol_basis_num:end, 1:mol_basis_num], errors[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for errors in errors_list]..., dims=3) +# mae_intra = mean(error_intra) +# mae_inter = mean(error_inter) + +# gts_intra = cat([cat(gt[1:mol_basis_num, 1:mol_basis_num], gt[1+mol_basis_num:end, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +# gts_inter = cat([cat(gt[1+mol_basis_num:end, 1:mol_basis_num], gt[1:mol_basis_num, 1+mol_basis_num:end], dims=3) for gt in gts]..., dims=3) +# gt_intra = std(gts_intra) +# gt_inter = std(gts_inter) + +errors = abs.(cat((gts.-predicts)..., dims=3)) # cat([abs.(predict-gt) for (predict, gt) in zip(predicts, gts)]..., dims=3) +gts = cat(gts..., dims=3) +mae = mean(errors) +gt = std(gts) +mae_norm = mae/gt + +mae = mean(errors) +println("mae, mae_norm: $mae, $mae_norm") + +# mae_norm_intra = mae_intra/gt_intra +# mae_norm_inter = mae_inter/gt_inter + +mae_plot = dropdims(mean(errors, dims=3), dims=3) +mae_norm_plot = dropdims(mean(errors, dims=3)./std(gts, dims=3), dims=3) + +p=heatmap(mae_plot, size=(800, 750), color=:jet, title="Monomer MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +savefig(joinpath(dirname(data_file), "mae.png")) +display(p) + +p=heatmap(mae_norm_plot, size=(800, 750), color=:jet, title="Normalized monomer MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], + ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +hline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +p1=plot(p, right_margin=13Plots.mm) +savefig(p1, joinpath(dirname(data_file), "mae_norm.png")) +display(p) + + + +# mae_intra_plot = dropdims(mean(error_intra, dims=3), dims=3) +# mae_inter_plot = dropdims(mean(error_inter, dims=3), dims=3) + +# p=heatmap(mae_intra_plot, size=(800, 750), color=:jet, title="Intra molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# savefig(joinpath(dirname(data_file), "mae_intra_plot.png")) +# display(p) + +# p=heatmap(mae_inter_plot, size=(800, 750), color=:jet, title="Inter molecule MAE (meV)", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold),) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# savefig(joinpath(dirname(data_file), "mae_inter_plot.png")) +# display(p) + + + +# mae_norm_intra = dropdims(mean(error_intra, dims=3)./std(gts_intra, dims=3), dims=3) +# mae_norm_inter = dropdims(mean(error_inter, dims=3)./std(gts_inter, dims=3), dims=3) + +# p=heatmap(mae_norm_intra, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold)) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# p1=plot(p, right_margin=13Plots.mm) +# savefig(p1, joinpath(dirname(data_file), "mae_norm_intra_plot.png")) +# display(p) + +# p=heatmap(mae_norm_inter, size=(800, 750), color=:jet, title="Normalized intra molecule MAE", titlefont=font(20, "times", :bold), xticks=([7, 21.5, 36.5], +# ["O", "H", "H"]), yticks=([7, 21.5, 36.5], ["O", "H", "H"]), tickfont=font(18, "Courier", :bold), clims=(0, 1)) +# vline!(p, [14.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# vline!(p, [29.5], color=:grey, linestyle=:dash, linewidth=2, label=false) +# hline!(p, [14.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# hline!(p, [29.5], color=:grey, linestyle=:dot, linewidth=2, label=false) +# p1=plot(p, right_margin=10Plots.mm) +# savefig(p1, joinpath(dirname(data_file), "mae_norm_inter_plot.png")) +# display(p) + + + + +# for (atoms, images, errors) in zip(atoms_list, images_list, errors_list) +# shift_vectors = collect(eachrow(images' * atoms.cell)) +# pos = hcat(collect.(atoms.X)...) +# dis = reshape(pos, size(pos, 1), size(pos, 2), 1) .- reshape(pos, size(pos, 1), 1, size(pos, 2)) +# dis = reshape(dis, (1, size(dis)...)) +# dis = dis .- reshape(hcat(collect.(shift_vectors)...)', length(shift_vectors), 3, 1, 1) +# dis = dropdims(mapslices(norm, dis; dims=2); dims=2) +# dis = permutedims(dis, (3,2,1)) +# for i_idx, (X₁, Z₁) in enumerate(atoms.X, atoms.Z) +# if Z₁ == 8 +# H_OO = extract_block(errors, i_idx, i_idx, findall([images[:,i]==[0, 0, 0] for i in range(1,size(images,2))])[1], basis_num, atoms) +# idx_line = partialsortperm(dis[i_idx, :, :][:], 1:3; rev=false)[2:end] +# j_idxes, image_idxes = rem.(idx_line, Ref(length(atoms.Z))), Int.(floor.(idx_line./Ref(length(atoms.Z)))) +# O_indices = i_idx +# H_indices = j_idxes +# i_idxes = [i_idx for i in 1: num_bond] +# for j_idx in j_idxes +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) +# push!(j_idx, i_idxes) +# push!(i_idx, j_idxes) +# push!(partialsortperm(dis[j_idx, i_idx, :], 1, rev=false), image_idxes) + + + + +# for (X₂, Z₂) in zip(atoms.X, atoms.Z) +# if Z₂==1 and + +# norm.( X₁ - X₂ + shift_vectors) + + + + + +# mask = norm.(atoms.X - atoms.X + shift_vectors[block_idxs[3, :]]) .<= distance + + + + + + + +# function extract_block(matrix::Array{Float64, 3}, i_idx::Int, j_idx::Int, image_idx::Int, basis_num::Dict, atoms::Atoms) +# idx_begin = vcat([1],cumsum([basis_num[i] for i in atoms.Z])[1:end-1].+1) +# idx_end = cumsum([basis_num[i] for i in atoms.Z]) +# return matrix[idx_begin[i_idx]: idx_end[i_idx], idx_begin[j_idx]: idx_end[j_idx], image_idx] +# end + + + + diff --git a/utils/utils.jl b/utils/utils.jl new file mode 100644 index 0000000..814dcc2 --- /dev/null +++ b/utils/utils.jl @@ -0,0 +1,315 @@ +using BenchmarkTools, Serialization +using ACEhamiltonians, HDF5, Serialization +using Statistics +using ACEhamiltonians.DatabaseIO: load_hamiltonian_gamma, load_overlap_gamma, load_density_matrix_gamma +using JuLIP: Atoms +using PeriodicTable +using Statistics +using Plots + +# ------------------------------------ +# |***functions slicing the matrix***| +# ------------------------------------ +function slice_matrix(matrix::AbstractArray{<:Any, 3}, id::Union{NTuple{3, Int}, NTuple{4, Int}}, + atoms::Atoms, basis_def::BasisDef) + + blocks, block_idxs = locate_and_get_sub_blocks(matrix, id..., atoms, basis_def) + + return blocks + +end + + +# get the species and Azimuthal number +function get_az_species(basis_definition::BasisDef, id::Union{NTuple{3, Int}, NTuple{4, Int}}) + if length(id) == 3 + return (id[begin], basis_definition[id[begin]][id[end-1]], basis_definition[id[end-2]][id[end]]) + elseif length(id) == 4 + return (id[begin], id[2], basis_definition[id[begin]][id[end-1]], basis_definition[id[end-2]][id[end]]) + end +end + + +# return the metrics +function metrics(errors::Union{Vector{Array{Float64,3}}, Vector{<:Any}}, type::String) + if isempty(errors) + return nothing + end + errors_tensor = cat(errors..., dims=3) + if type == "MAE" + return mean(abs.(errors_tensor)) + elseif type == "RMSE" + return sqrt(mean(errors_tensor.^2)) + else + throw(ArgumentError("invalid metrics type, the matrics should be either MAE or RMSE")) + end +end + + +# return the error data dict +ob_idx = Dict(0=>"s", 1=>"p", 2=>"d", 3=>"f") + +function get_error_dict(errors::Vector{Array{Float64, 3}}, atomsv::Vector{Atoms{Float64}}, model::Model) + + basis_definition = model.basis_definition + + data_dict = Dict("MAE"=>Dict(), "RMSE"=>Dict()) + data_dict["MAE"]["all"] = mean(abs.(vcat([error[:] for error in errors]...))) + data_dict["RMSE"]["all"] = sqrt(mean(vcat([error[:] for error in errors]...).^2)) + + for site in ["on", "off"] + + submodels = site == "on" ? model.on_site_submodels : model.off_site_submodels + data_dict["MAE"][site] = Dict() + data_dict["RMSE"][site] = Dict() + + # all the sumodels + error_vector = [] + for id in keys(submodels) + error_block = slice_matrix.(errors, Ref(id), atomsv, Ref(basis_definition)) + push!(error_vector, [error[:] for error in error_block]) + end + error_vector = [error_vector[i][j] for i in 1:length(error_vector) for j in 1:length(error_vector[i])] + error_vector = vcat(error_vector...) + data_dict["MAE"][site]["all"] = mean(abs.(error_vector)) + data_dict["RMSE"][site]["all"] = sqrt(mean(error_vector.^2)) + + #submodels + data_dict["MAE"][site]["submodels"] = Dict() + data_dict["RMSE"][site]["submodels"] = Dict() + for id in keys(submodels) + error_block = slice_matrix.(errors, Ref(id), atomsv, Ref(basis_definition)) + data_dict["MAE"][site]["submodels"][id] = metrics(error_block, "MAE") + data_dict["RMSE"][site]["submodels"][id] = metrics(error_block, "RMSE") + end + + ids = collect(keys(submodels)) + + #Azimuthal number + data_dict["MAE"][site]["Azimuthal"] = Dict() + data_dict["RMSE"][site]["Azimuthal"] = Dict() + az_pairs = unique([(basis_definition[i[begin]][i[end-1]], basis_definition[i[end-2]][i[end]]) for i in ids]) + for (l₁, l₂) in az_pairs + error_block_collection = [] + for id in keys(submodels) + if (basis_definition[id[begin]][id[end-1]], basis_definition[id[end-2]][id[end]]) == (l₁, l₂) + error_block = slice_matrix.(errors, Ref(id), atomsv, Ref(basis_definition)) + push!(error_block_collection, error_block) + end + end + error_block_collection = [error_block_collection[i][j] for i in 1:length(error_block_collection) + for j in 1:length(error_block_collection[i])] + data_dict["MAE"][site]["Azimuthal"][join([ob_idx[l₁],ob_idx[l₂]])] = metrics(error_block_collection, "MAE") + data_dict["RMSE"][site]["Azimuthal"][join([ob_idx[l₁],ob_idx[l₂]])] = metrics(error_block_collection, "RMSE") + end + + #species Azimuthal + data_dict["MAE"][site]["Azimuthal_Species"] = Dict() + data_dict["RMSE"][site]["Azimuthal_Species"] = Dict() + az_species_pairs = unique(get_az_species.(Ref(basis_definition), ids)) + for pair in az_species_pairs + species = join([elements[i].symbol for i in pair[begin:end-2]]) + azimuthal = join([ob_idx[pair[end-1]],ob_idx[pair[end]]]) + error_block_collection = [] + for id in keys(submodels) + if get_az_species(basis_definition, id) == pair + error_block = slice_matrix.(errors, Ref(id), atomsv, Ref(basis_definition)) + push!(error_block_collection, error_block) + end + end + error_block_collection = [error_block_collection[i][j] for i in 1:length(error_block_collection) + for j in 1:length(error_block_collection[i])] + data_dict["MAE"][site]["Azimuthal_Species"][join([species,azimuthal])] = metrics(error_block_collection, "MAE") + data_dict["RMSE"][site]["Azimuthal_Species"][join([species,azimuthal])] = metrics(error_block_collection, "RMSE") + end + + end + + return data_dict + +end + + +# ------------------------------------ +# |*************plotting*************| +# ------------------------------------ + +function plot_save_single(x::T₁, y::T₂, x_label::String, y_label::String, label::Union{String, Matrix, Nothing}, + output_path::String, basen::String) where {T₁, T₂} + title = join([x_label, y_label, basen], "_") + i = sortperm(x) + x = x[i] + if typeof(y)<:Vector{Float64} + y = y[i] + elseif typeof(y)<:Vector{Vector{Float64}} + y = [y_sub[i] for y_sub in y] + else + error("the y is not in the right form") + end + plt = plot(x, y, title=title, xlabel=x_label, ylabel=y_label, label=label, legend_background_color=RGBA(1, 1, 1, 0)) + filename = joinpath(output_path, join([x_label, y_label, basen], "_")*".png") + savefig(plt, filename) +end + + +function plot_hyperparams(data_dict::Dict, x_label::String, output_path_figs::String) + + mkpath(output_path_figs) + + x_label in ["d_max", "r_cut"] ? nothing : throw(AssertionError("the x_label should be either d_max or r_cut")) + assess_type = collect(keys(data_dict)) + x = x_label == "d_max" ? [i[1] for i in assess_type] : [i[2] for i in assess_type] + + for y_label in ["MAE", "RMSE"] + + y_all = [i[y_label]["all"] for i in values(data_dict)] + label = nothing + basen = "all" + plot_save_single(x, y_all, x_label, y_label, label, output_path_figs, basen) + + for site in ["on", "off"] + y_all = [i[y_label][site]["all"] for i in values(data_dict)] + label = nothing + basen = join([site, "all"], "_") + plot_save_single(x, y_all, x_label, y_label, label, output_path_figs, basen) + + for type in ["Azimuthal", "Azimuthal_Species"] + # This step is merely to fix the typo in former get_error_dict function + if type == "Azimuthal_Species" + type = haskey(collect(values(data_dict))[1][y_label][site], "Azimuthal_Species") ? "Azimuthal_Species" : "Azimuthal_SPecies" + end + y_type = [i[y_label][site][type] for i in values(data_dict)] + label = reshape(collect(keys(y_type[1])), (1,:)) + y_type = [[y_type_sub[label_sub] for y_type_sub in y_type] for label_sub in label][:] + basen = join([site, type], "_") + plot_save_single(x, y_type, x_label, y_label, label, output_path_figs, basen) + end + + end + + end + +end + + +function plot_save_single_cross(x::T₁, y::T₂, x_label::String, y_label::String, label::Union{String, Matrix, Nothing}, + output_path::String, basen::String) where {T₁, T₂} + title = join([y_label, basen], "_") + i = sortperm(x) + x = x[i] + if typeof(y)<:Vector{Float64} + y = y[i] + elseif typeof(y)<:Vector{Vector{Float64}} + y = [y_sub[i] for y_sub in y] + else + error("the y is not in the right form") + end + plt = plot(x, y, title=title, xlabel=x_label, ylabel=y_label, label=label, legend_background_color=RGBA(1, 1, 1, 0)) + filename = joinpath(output_path, join([y_label, basen], "_")*".png") + savefig(plt, filename) +end + + +function plot_cross(data_dict::Dict, output_path_figs::String) + + mkpath(output_path_figs) + + x_label = "number_of_water_molecules_for_prediction" + assess_type = collect(keys(data_dict)) + model_sys_sizes = unique([i[1] for i in assess_type]) #datasize used for model training + pred_sys_sizes = unique([i[2] for i in assess_type]) #system size for prediction + + for y_label in ["MAE", "RMSE"] + + y_all = [[data_dict[(model_sys_size, pred_sys_size)][y_label]["all"] for pred_sys_size in pred_sys_sizes] + for model_sys_size in model_sys_sizes] + + label = reshape(model_sys_sizes, (1,:)) + basen = "all" + plot_save_single_cross(pred_sys_sizes, y_all, x_label, y_label, label, output_path_figs, basen) + + for site in ["on", "off"] + y_all = [[data_dict[(model_sys_size, pred_sys_size)][y_label][site]["all"] + for pred_sys_size in pred_sys_sizes] for model_sys_size in model_sys_sizes] + label = reshape(model_sys_sizes, (1,:)) + basen = join([site, "all"], "_") + plot_save_single_cross(pred_sys_sizes, y_all, x_label, y_label, label, output_path_figs, basen) + + for type in ["Azimuthal", "Azimuthal_Species"] + if type == "Azimuthal_Species" + type = haskey(collect(values(data_dict))[1][y_label][site], "Azimuthal_Species") ? "Azimuthal_Species" : "Azimuthal_SPecies" + end + for model_sys_size in model_sys_sizes + y_type = [data_dict[(model_sys_size, pred_sys_size)][y_label][site][type] + for pred_sys_size in pred_sys_sizes] + label = reshape(collect(keys(y_type[1])), (1,:)) + y_type = [[y_type_sub[label_sub] for y_type_sub in y_type] for label_sub in label][:] + basen = join([site, type, model_sys_size], "_") + plot_save_single_cross(pred_sys_sizes, y_type, x_label, y_label, label, output_path_figs, basen) + end + end + + end + + end + +end + + + + +function plot_save_single_size(x::T₁, y::T₂, x_label::String, y_label::String, label::Union{String, Matrix, Nothing}, + output_path::String, basen::String) where {T₁, T₂} + title = join([y_label, basen], "_") + i = sortperm(x) + x = x[i] + if typeof(y)<:Vector{Float64} + y = y[i] + elseif typeof(y)<:Vector{Vector{Float64}} + y = [y_sub[i] for y_sub in y] + else + error("the y is not in the right form") + end + plt = plot(x, y, title=title, xlabel=x_label, ylabel=y_label, label=label, legend_background_color=RGBA(1, 1, 1, 0)) + filename = joinpath(output_path, join([y_label, basen], "_")*".png") + savefig(plt, filename) +end + + +function plot_size(data_dict::Dict, output_path_figs::String) + + mkpath(output_path_figs) + + x = collect(keys(data_dict)) + x_label = "number of samples" + + for y_label in ["MAE", "RMSE"] + + y_all = [i[y_label]["all"] for i in values(data_dict)] + label = nothing + basen = "all" + plot_save_single_size(x, y_all, x_label, y_label, label, output_path_figs, basen) + + for site in ["on", "off"] + y_all = [i[y_label][site]["all"] for i in values(data_dict)] + label = nothing + basen = join([site, "all"], "_") + plot_save_single_size(x, y_all, x_label, y_label, label, output_path_figs, basen) + + for type in ["Azimuthal", "Azimuthal_Species"] + # This step is merely to fix the typo in former get_error_dict function + if type == "Azimuthal_Species" + type = haskey(collect(values(data_dict))[1][y_label][site], "Azimuthal_Species") ? "Azimuthal_Species" : "Azimuthal_SPecies" + end + y_type = [i[y_label][site][type] for i in values(data_dict)] + label = reshape(collect(keys(y_type[1])), (1,:)) + y_type = [[y_type_sub[label_sub] for y_type_sub in y_type] for label_sub in label][:] + basen = join([site, type], "_") + plot_save_single_size(x, y_type, x_label, y_label, label, output_path_figs, basen) + end + + end + + end + +end \ No newline at end of file