Skip to content

Commit

Permalink
debug pullback, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglw0521 committed Nov 10, 2023
1 parent a0f937d commit 46dce96
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 52 deletions.
25 changes: 18 additions & 7 deletions src/sphericalharmonics/scylm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,8 @@ Base.show(io::IO, basis::SCYlmBasis{L, normalisation, static, T}) where {L, norm
# ---------------------- Interfaces


evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm!(Y,compute(basis.basis,X))

evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm_ed!(Y,dY,compute_with_gradients(basis.basis,X)[1],compute_with_gradients(basis.basis,X)[2])

scYlm!(Y,val) = Y .= val
evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = [Y[i] = compute(basis.basis,X)[i] for i = 1:length(Y)] # scYlm!(Y,compute(basis.basis,X))
evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::SVector{3}) = scYlm_ed!(Y,dY,compute_with_gradients(basis.basis,X)...)

function scYlm_ed!(Y,dY,val,dval)
Y .= val
Expand All @@ -45,5 +42,19 @@ function scYlm_ed!(Y,dY,val,dval)
end

evaluate!(Y::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute!(Y,basis.basis,X)

evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute_with_gradients!(Y,dY,basis.basis,X)
evaluate_ed!(Y::AbstractArray, dY::AbstractArray, basis::SCYlmBasis, X::AbstractVector{<: SVector{3}}) = compute_with_gradients!(Y,dY,basis.basis,X)

# rrule
function ChainRulesCore.rrule(::typeof(evaluate), basis::SCYlmBasis, X)
A, dX = evaluate_ed(basis, X)
function pb(∂A)
@assert size(∂A) == (length(X), length(basis))
T∂X = promote_type(eltype(∂A), eltype(dX))
∂X = similar(X, SVector{3, T∂X})
for i = 1:length(X)
∂X[i] = sum([∂A[i,j] * dX[i,j] for j = 1:length(dX[i,:])])
end
return NoTangent(), NoTangent(), ∂X
end
return A, pb
end
69 changes: 24 additions & 45 deletions test/sphericalharmonics/test_scylm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ using Polynomials4ML, Polynomials4ML.Testing
using Polynomials4ML: index_y, rand_sphere
using Polynomials4ML: evaluate, evaluate_d, evaluate_ed
using Polynomials4ML.Testing: print_tf, println_slim
# using ACEbase.Testing: fdtest
using ACEbase.Testing: fdtest
using HyperDualNumbers: Hyper


Expand Down Expand Up @@ -160,49 +160,28 @@ println_slim(@test dY1 ≈ dY2)
# println_slim(@test dY1 ≈ dY2)
# println_slim(@test ΔY1 ≈ ΔY2)

# using Zygote
# @info("Test rrule")
# using LinearAlgebra: dot
# rSH = SCYlmBasis(10)
using Zygote
@info("Test rrule")
using LinearAlgebra: dot
rSH = SCYlmBasis(10)

# for ntest = 1:30
# local X
# local Y
# local Rnl
# local u
for ntest = 1:30
local X
local Y
local Rnl
local u

# X = [ rand_sphere() for i = 1:21 ]
# Y = [ rand_sphere() for i = 1:21 ]
# _x(t) = X + t * Y
# A = evaluate(rSH, X)
# u = randn(size(A))
# F(t) = dot(u, evaluate(rSH, _x(t)))
# dF(t) = begin
# val, pb = Zygote.pullback(rSH, _x(t)) # TODO: write a pullback??
# ∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument
# return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) )
# end
# print_tf(@test fdtest(F, dF, 0.0; verbose = false))
# end
# println()

# ## Debugging code
# X = [ rand_sphere() for i = 1:21 ]
# Y = [ rand_sphere() for i = 1:21 ]
# _x(t) = X + t * Y
# A = evaluate(rSH, X)
# u = randn(size(A))
# F(t) = dot(u, evaluate(rSH, _x(t)))
# t = 1
# val, pb = Zygote.pullback(rSH, _x(t))
# pb
# ∂BB = pb(A)[1]

# dF(t) = begin
# val, pb = Zygote.pullback(rSH, _x(t))
# ∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument
# return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) )
# end
# fdtest(F, dF, 0.0; verbose = false)

# print_tf(@test fdtest(F, dF, 0.0; verbose = false))
X = [ rand_sphere() for i = 1:21 ]
Y = [ rand_sphere() for i = 1:21 ]
_x(t) = X + t * Y
A = evaluate(rSH, X)
u = randn(size(A))
F(t) = dot(u, evaluate(rSH, _x(t)))
dF(t) = begin
val, pb = Zygote.pullback(rSH, _x(t)) # TODO: write a pullback??
∂BB = pb(u)[1] # pb(u)[1] returns NoTangent() for basis argument
return sum( dot(∂BB[i], Y[i]) for i = 1:length(Y) )
end
print_tf(@test fdtest(F, dF, 0.0; verbose = false))
end
println()

0 comments on commit 46dce96

Please sign in to comment.