diff --git a/src/sphericalharmonics/scylm.jl b/src/sphericalharmonics/scylm.jl index 7e4d5cc..3b8f61a 100644 --- a/src/sphericalharmonics/scylm.jl +++ b/src/sphericalharmonics/scylm.jl @@ -30,11 +30,8 @@ Base.show(io::IO, basis::SCYlmBasis{L, normalisation, static, T}) where {L, norm # ---------------------- Interfaces -evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm!(Y,compute(basis.basis,X)) - -evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm_ed!(Y,dY,compute_with_gradients(basis.basis,X)[1],compute_with_gradients(basis.basis,X)[2]) - -scYlm!(Y,val) = Y .= val +evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = [Y[i] = compute(basis.basis,X)[i] for i = 1:length(Y)] # scYlm!(Y,compute(basis.basis,X)) +evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm_ed!(Y,dY,compute_with_gradients(basis.basis,X)...) function scYlm_ed!(Y,dY,val,dval) Y .= val @@ -45,5 +42,19 @@ function scYlm_ed!(Y,dY,val,dval) end evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute!(Y,basis.basis,X) - -evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute_with_gradients!(Y,dY,basis.basis,X) \ No newline at end of file +evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute_with_gradients!(Y,dY,basis.basis,X) + +# rrule +function ChainRulesCore.rrule(::typeof(evaluate), basis::SCYlmBasis, X) + A, dX = evaluate_ed(basis, X) + function pb(∂A) + @assert size(∂A) == (length(X), length(basis)) + T∂X = promote_type(eltype(∂A), eltype(dX)) + ∂X = similar(X, SVector{3, T∂X}) + for i = 1:length(X) + ∂X[i] = sum([∂A[i,j] * dX[i,j] for j = 1:length(dX[i,:])]) + end + return NoTangent(), NoTangent(), ∂X + end + return A, pb +end \ No newline at end of file diff --git a/test/sphericalharmonics/test_scylm.jl b/test/sphericalharmonics/test_scylm.jl index 9939aa2..98a2b6d 100644 --- a/test/sphericalharmonics/test_scylm.jl +++ b/test/sphericalharmonics/test_scylm.jl @@ -3,7 +3,7 @@ using Polynomials4ML, Polynomials4ML.Testing using Polynomials4ML: index_y, rand_sphere using Polynomials4ML: evaluate, evaluate_d, evaluate_ed using Polynomials4ML.Testing: print_tf, println_slim -# using ACEbase.Testing: fdtest +using ACEbase.Testing: fdtest using HyperDualNumbers: Hyper @@ -160,49 +160,28 @@ println_slim(@test dY1 ≈ dY2) # println_slim(@test dY1 ≈ dY2) # println_slim(@test ΔY1 ≈ ΔY2) -# using Zygote -# @info("Test rrule") -# using LinearAlgebra: dot -# rSH = SCYlmBasis(10) +using Zygote +@info("Test rrule") +using LinearAlgebra: dot +rSH = SCYlmBasis(10) -# for ntest = 1:30 -# local X -# local Y -# local Rnl -# local u +for ntest = 1:30 + local X + local Y + local Rnl + local u -# X = [ rand_sphere() for i = 1:21 ] -# Y = [ rand_sphere() for i = 1:21 ] -# _x(t) = X + t * Y -# A = evaluate(rSH, X) -# u = randn(size(A)) -# F(t) = dot(u, evaluate(rSH, _x(t))) -# dF(t) = begin -# val, pb = Zygote.pullback(rSH, _x(t)) # TODO: write a pullback?? -# ∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument -# return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) ) -# end -# print_tf(@test fdtest(F, dF, 0.0; verbose = false)) -# end -# println() - -# ## Debugging code -# X = [ rand_sphere() for i = 1:21 ] -# Y = [ rand_sphere() for i = 1:21 ] -# _x(t) = X + t * Y -# A = evaluate(rSH, X) -# u = randn(size(A)) -# F(t) = dot(u, evaluate(rSH, _x(t))) -# t = 1 -# val, pb = Zygote.pullback(rSH, _x(t)) -# pb -# ∂BB = pb(A)[1] - -# dF(t) = begin -# val, pb = Zygote.pullback(rSH, _x(t)) -# ∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument -# return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) ) -# end -# fdtest(F, dF, 0.0; verbose = false) - -# print_tf(@test fdtest(F, dF, 0.0; verbose = false)) + X = [ rand_sphere() for i = 1:21 ] + Y = [ rand_sphere() for i = 1:21 ] + _x(t) = X + t * Y + A = evaluate(rSH, X) + u = randn(size(A)) + F(t) = dot(u, evaluate(rSH, _x(t))) + dF(t) = begin + val, pb = Zygote.pullback(rSH, _x(t)) # TODO: write a pullback?? + ∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument + return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) ) + end + print_tf(@test fdtest(F, dF, 0.0; verbose = false)) +end +println() \ No newline at end of file