Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP : Sphericart Integration #81

Merged
merged 12 commits into from
Jan 16, 2024
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.2.10"
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -19,7 +20,9 @@ ObjectPools = "658cac36-ff0f-48ad-967c-110375d98c9d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
SpheriCart = "5caf2b29-02d9-47a3-9434-5931c85ba645"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StrideArrays = "d1fa6d79-ef01-42a6-86c9-f7c551f8593b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -41,6 +44,9 @@ SpecialFunctions = "2.2"
StaticArrays = "1.5"
StrideArrays = "0.1.25"
julia = "1.8.0, 1.9.0, 1.10.0"
SpheriCart = "0.0.3"
BlockDiagonals = "0.1.42"
SparseArrays = "1.8"
LinearAlgebra = "1.8"
Printf = "1.8"
Random = "1.8"
Expand Down
74 changes: 74 additions & 0 deletions src/sphericalharmonics/scylm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using SpheriCart: compute, compute_with_gradients,
compute!, compute_with_gradients!,
SphericalHarmonics

export SCYlmBasis

struct SCYlmBasis{L, normalisation, static, T} <: SVecPoly4MLBasis
basis::SphericalHarmonics{L, normalisation, static, T}
@reqfields
end

maxL(basis::SCYlmBasis{L}) where {L} = L

Base.length(basis::SCYlmBasis) = sizeY(maxL(basis))

SCYlmBasis(maxL::Integer, T::Type=Float64) =
SCYlmBasis( SphericalHarmonics(maxL; normalisation = :L2,
static = maxL <= 15,
T = T))

SCYlmBasis(scsh::SphericalHarmonics) =
SCYlmBasis(scsh, _make_reqfields()...)

natural_indices(basis::SCYlmBasis) =
[ NamedTuple{(:l, :m)}(idx2lm(i)) for i = 1:length(basis) ]

_valtype(sh::SCYlmBasis{L, NRM, STATIC, T},
::Type{<: StaticVector{3, S}}) where {L, NRM, STATIC, T <: Real, S <: Real} =
promote_type(T, S)

_valtype(sh::SCYlmBasis{L, NRM, STATIC, T},
::Type{<: StaticVector{3, Hyper{S}}}) where {L, NRM, STATIC, T <: Real, S <: Real} =
promote_type(T, Hyper{S})

Base.show(io::IO, basis::SCYlmBasis{L, NRM, STATIC, T}) where {L, NRM, STATIC, T} =
print(io, "SCYlmBasis(L=$L)")

# ---------------------- Interfaces

function evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::SVector{3})
Y_temp = reshape(Y, 1, :)
compute!(Y_temp, basis.basis, SA[X,])
return Y
end

function evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::SVector{3})
Y_temp = reshape(Y, 1, :)
dY_temp = reshape(dY, 1, :)
compute_with_gradients!(Y_temp, dY_temp, basis.basis, SA[X,])
return Y, dY
end

evaluate!(Y::AbstractArray,
basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) =
compute!(Y, basis.basis, X)

evaluate_ed!(Y::AbstractArray, dY::AbstractArray,
basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) =
compute_with_gradients!(Y,dY,basis.basis,X)

# rrule
function ChainRulesCore.rrule(::typeof(evaluate), basis::SCYlmBasis, X)
A, dX = evaluate_ed(basis, X)
function pb(∂A)
@assert size(∂A) == (length(X), length(basis))
T∂X = promote_type(eltype(∂A), eltype(dX))
∂X = similar(X, SVector{3, T∂X})
for i = 1:length(X)
∂X[i] = sum([∂A[i,j] * dX[i,j] for j = 1:length(dX[i,:])])
end
return NoTangent(), NoTangent(), ∂X
end
return A, pb
end
2 changes: 2 additions & 0 deletions src/sphericalharmonics/sphericalharmonics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ include("rylm.jl")
include("crlm.jl")
include("rrlm.jl")

include("scylm.jl")

const XlmBasis = Union{RYlmBasis, CYlmBasis, CRlmBasis, RRlmBasis}

"""
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Test
@testset "Real Spherical Harmonics" begin include("sphericalharmonics/test_rylm.jl"); end
@testset "Complex Solid Harmonics" begin include("sphericalharmonics/test_crlm.jl"); end
@testset "Real Solid Harmonics" begin include("sphericalharmonics/test_rrlm.jl"); end
@testset "Real Spherical Harmonics via SpheriCart" begin include("sphericalharmonics/test_scylm.jl"); end

# Quantum Chemistry
@testset "Atomic Orbitals Radials" begin include("test_atorbrad.jl"); end
Expand Down
187 changes: 187 additions & 0 deletions test/sphericalharmonics/test_scylm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
using LinearAlgebra, StaticArrays, Test, Printf, SparseArrays, BlockDiagonals
using Polynomials4ML, Polynomials4ML.Testing
using Polynomials4ML: index_y, rand_sphere
using Polynomials4ML: evaluate, evaluate_d, evaluate_ed
using Polynomials4ML.Testing: print_tf, println_slim
using ACEbase.Testing: fdtest
using HyperDualNumbers: Hyper



verbose = false

@info("Testing consistency of Real and Complex SH; SpheriCart convention")

# SpheriCart R2C transformation
function ctran3(L)
AA = zeros(ComplexF64, 2L+1, 2L+1)
for i = 1:2L+1
for j in [i, 2L+2-i]
AA[i,j] = begin
if i == j == L+1
1
elseif i > L+1 && j > L+1
(-1)^(i-L-1)/sqrt(2)
elseif i < L+1 && j < L+1
im/sqrt(2)
elseif i < L+1 && j > L+1
(-1)^(i-L)/sqrt(2)*im
elseif i > L+1 && j < L+1
1/sqrt(2)
end
end
end
end
return sparse(AA)
end

function test_r2c_y(L, cY, rY)
Ts = BlockDiagonal([ ctran3(l)' for l = 0:L ]) |> sparse
return cY ≈ Ts * rY
end

##

maxL = 20
cSH = CYlmBasis(maxL)
rSH = SCYlmBasis(maxL)

for nsamples = 1:30
local R
R = rand_sphere()
cY = evaluate(cSH, R)
rY = evaluate(rSH, R)
print_tf(@test test_r2c_y(maxL, cY, rY))
end
println()


##

@info("Check consistency of serial and batched evaluation")

X = [ rand_sphere() for i = 1:23 ]
Y1 = evaluate(rSH, X)
Y2 = similar(Y1)
for i = 1:length(X)
Y2[i, :] = evaluate(rSH, X[i])
end
println_slim(@test Y1 ≈ Y2)


##

@info("Test: check derivatives of real spherical harmonics")
for nsamples = 1:30
local R, rSH, h
R = @SVector rand(3)
rSH = RYlmBasis(5)
Y, dY = evaluate_ed(rSH, R)
DY = Matrix(transpose(hcat(dY...)))
errs = []
verbose && @printf(" h | error \n")
for p = 2:10
h = 0.1^p
DYh = similar(DY)
Rh = Vector(R)
for i = 1:3
Rh[i] += h
DYh[:, i] = (evaluate(rSH, SVector(Rh...)) - Y) / h
Rh[i] -= h
end
push!(errs, norm(DY - DYh, Inf))
verbose && @printf(" %.2e | %.2e \n", h, errs[end])
end
success = (minimum(errs[2:end]) < 1e-3 * maximum(errs[1:3])) || (minimum(errs) < 1e-10)
print_tf(@test success)
end
println()


##

@info("Check consistency of serial and batched gradients")

rSH = SCYlmBasis(10)
X = [ rand_sphere() for i = 1:21 ]

x2dualwrtj(x, j) = SVector{3}([Hyper(x[i], i == j, i == j, 0) for i = 1:3])

hX = [x2dualwrtj(x, 1) for x in X]


Y0 = evaluate(rSH, X)
Y1, dY1 = evaluate_ed(rSH, X)
Y2 = similar(Y1); dY2 = similar(dY1)
for i = 1:length(X)
Y2[i, :] = evaluate(rSH, X[i])
dY2[i, :] = evaluate_ed(rSH, X[i])[2]
end
println_slim(@test Y0 ≈ Y1 ≈ Y2)
println_slim(@test dY1 ≈ dY2)


# ## -- check the laplacian implementation

# using LinearAlgebra: tr
# using ForwardDiff
# P4 = Polynomials4ML

# function fwdΔ1(rYlm, x)
# Y = evaluate(rYlm, x)
# nY = length(Y)
# _j(x) = ForwardDiff.jacobian(x -> evaluate(rYlm, x), x)[:]
# _h(x) = reshape(ForwardDiff.jacobian(_j, x), (nY, 3, 3))
# H = _h(x)
# return [ tr(H[i, :, :]) for i = 1:nY ]
# end

# for x in X
# ΔY = P4.laplacian(rSH, x)
# ΔYfwd = fwdΔ1(rSH, x)
# print_tf(@test ΔYfwd ≈ ΔY)
# end
# println()

# @info("check batched laplacian")
# ΔY1 = P4.laplacian(rSH, X)
# ΔY2 = similar(ΔY1)
# for (i, x) in enumerate(X)
# ΔY2[i, :] = P4.laplacian(rSH, x)
# end
# println_slim(@test ΔY1 ≈ ΔY2)


# @info("check eval_grad_laplace")
# Y1, dY1, ΔY1 = P4.eval_grad_laplace(rSH, X)
# Y2, dY2 = evaluate_ed(rSH, X)
# ΔY2 = P4.laplacian(rSH, X)
# println_slim(@test Y1 ≈ Y2)
# println_slim(@test dY1 ≈ dY2)
# println_slim(@test ΔY1 ≈ ΔY2)

using Zygote
@info("Test rrule")
using LinearAlgebra: dot
rSH = SCYlmBasis(10)

for ntest = 1:30
local X
local Y
local Rnl
local u

X = [ rand_sphere() for i = 1:21 ]
Y = [ rand_sphere() for i = 1:21 ]
_x(t) = X + t * Y
A = evaluate(rSH, X)
u = randn(size(A))
F(t) = dot(u, evaluate(rSH, _x(t)))
dF(t) = begin
val, pb = Zygote.pullback(rSH, _x(t)) # TODO: write a pullback??
∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument
return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) )
end
print_tf(@test fdtest(F, dF, 0.0; verbose = false))
end
println()
3 changes: 2 additions & 1 deletion test/test_lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ test_bases = [ (chebyshev_basis(10), () -> rand()),
(MonoBasis(10), ()-> rand()),
(legendre_basis(10), () -> rand()),
(CYlmBasis(5), () -> randn(SVector{3, Float64})),
(RYlmBasis(5), () -> randn(SVector{3, Float64})) ]
(RYlmBasis(5), () -> randn(SVector{3, Float64})),
(SCYlmBasis(5), () -> randn(SVector{3, Float64})), ]

for (basis, rnd) in test_bases
local B1, B2, x
Expand Down
Loading