Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Jun 7, 2024
1 parent 5fdba55 commit 7979e0c
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 153 deletions.
20 changes: 13 additions & 7 deletions src/Polynomials4ML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module Polynomials4ML

# -------------- Import Bumper and related things ---------------

using Bumper, WithAlloc
using Bumper, WithAlloc, StrideArrays
import WithAlloc: whatalloc

# -------------- import ACEbase stuff
Expand All @@ -13,11 +13,19 @@ import ACEbase: evaluate, evaluate_d, evaluate_ed, evaluate_dd, evaluate_ed2,
evaluate!, evaluate_d!, evaluate_ed!, evaluate_ed2!
import ACEbase.FIO: read_dict, write_dict

using LuxCore, Random
import ChainRulesCore: rrule, frule


function natural_indices end # could rename this get_spec or similar ...
function index end
function orthpolybasis end
function degree end

function pullback_evaluate end
function pullback_evaluate! end
function pushforward_evaluate end
function pushforward_evaluate! end

# some stuff to allow bases to overload some lux functionality ...
# how much of this should go into ACEbase?
Expand Down Expand Up @@ -59,24 +67,22 @@ include("chebbasis.jl")
include("trig.jl")
include("rtrig.jl")

#=
# 3d harmonics
include("sphericalharmonics/sphericalharmonics.jl")
# include("sphericalharmonics/sphericalharmonics.jl")

# quantum chemistry
include("atomicorbitalsradials/atomicorbitalsradials.jl")
# include("atomicorbitalsradials/atomicorbitalsradials.jl")

# generating product bases (generalisation of tensor products)
include("sparseproduct.jl")
# include("sparseproduct.jl")

# LinearLayer implementation
# this is needed to better play with cached arrays + to give the correct
# behaviour when the feature dimension is different from expected.
include("linear.jl")
# include("linear.jl")

# generic machinery for wrapping poly4ml bases into lux layers
include("lux.jl")
=#

# basis components to implement cluster expansion methods
include("ace/ace.jl")
Expand Down
12 changes: 5 additions & 7 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,6 @@ _valtype(basis::PooledSparseProduct, BB::Tuple) =
_gradtype(basis::PooledSparseProduct, BB::Tuple) =
mapreduce(eltype, promote_type, BB)

function whatalloc(evaluate!, basis::PooledSparseProduct{NB}, BB::TupVecMat) where {NB}
TV = _valtype(basis, BB)
nA = length(basis)
return (TV, nA)
end




# ----------------------- evaluation kernels
Expand Down Expand Up @@ -91,6 +84,11 @@ import Base.Cartesian: @nexprs
# return nothing
# end

# Allocation spec for `evaluate!` on a PooledSparseProduct: the output A
# needs one slot per basis function, with an eltype promoted from the
# input basis tuple `BB`.
function whatalloc(evaluate!, basis::PooledSparseProduct{NB}, BB::TupVecMat) where {NB}
    return (_valtype(basis, BB), length(basis))
end

function evaluate!(A, basis::PooledSparseProduct{NB}, BB::TupVec) where {NB}
BB_batch = ntuple(i -> reshape(BB[i], (1, length(BB[i]))), NB)
Expand Down
21 changes: 0 additions & 21 deletions src/ace/sparsesymmprod.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@

using ChainRulesCore
using ChainRulesCore: NoTangent

export SparseSymmProd

@doc raw"""
Expand Down Expand Up @@ -309,24 +306,6 @@ function _pfwd_AA_N!(AA, ∂AA, A, ΔA,
end


# -------------- ChainRules integration


function rrule(::typeof(evaluate), basis::SparseSymmProd, A)
AA = evaluate(basis, A)
return AA, Δ -> (NoTangent(), NoTangent(), pullback_evaluate(Δ, basis, A))
end

function rrule(::typeof(pullback_evaluate), ∂AA, basis::SparseSymmProd, A)
∂A = pullback_evaluate(∂AA, basis, A)
function _pb(∂²)
g∂AA, gA = pb_pb_evaluate(∂², ∂AA, basis, A)
return NoTangent(), g∂AA, NoTangent(), gA
end
return ∂A, _pb
end



# -------------- Lux integration
# it needs an extra lux interface reason as in the case of the `basis`
Expand Down
146 changes: 93 additions & 53 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using StaticArrays: StaticArray, SVector, StaticVector, similar_type
using ChainRulesCore
import ChainRulesCore: rrule, frule


abstract type AbstractP4MLLayer end

Expand All @@ -20,7 +22,7 @@ allowed dimensionality of inputs to allow tensorial shapes.
"""
abstract type AbstractP4MLTensor <: AbstractP4MLLayer end

# ---------------------------------------
# ---------------------------------------------------------------------------
# some helpers to deal with the required fields:
# TODO: now that there is only meta, should this be removed?

Expand All @@ -36,7 +38,7 @@ end
_make_reqfields() = (_makemeta(), )


# ---------------------------------------
# -------------------------------------------------------------------

# SphericalCoords is defined here so it can be part of SINGLE
# TODO: retire this as soon as we fully switched to SpheriCart
Expand Down Expand Up @@ -158,86 +160,124 @@ end
# ---------------------------------------
# allocating evaluation interface

(basis::AbstractP4MLLayer)(args...) = evaluate(basis, args...)
(l::AbstractP4MLLayer)(args...) =
evaluate(l, args...)

evaluate(args...) = _with_safe_alloc(evaluate!, args...)
evaluate(l::AbstractP4MLLayer, args...) =
_with_safe_alloc(evaluate!, l, args...)

evaluate_ed(args...) = _with_safe_alloc(evaluate_ed!, args...)
evaluate_ed(l::AbstractP4MLLayer, args...) =
_with_safe_alloc(evaluate_ed!, l, args...)

evaluate_ed2(args...) = _with_safe_alloc(evaluate_ed2!, args...)
evaluate_ed2(l::AbstractP4MLLayer, args...) =
_with_safe_alloc(evaluate_ed2!, l, args...)

evaluate_d(basis, args...) = evaluate_ed(basis, args...)[2]
evaluate_d(l::AbstractP4MLLayer, args...) =
evaluate_ed(l, args...)[2]

evaluate_dd(basis, args...) = evaluate_ed2(basis, args...)[3]
evaluate_dd(l::AbstractP4MLLayer, args...) =
evaluate_ed2(l, args...)[3]

pullback_evaluate(args...) = _with_safe_alloc(pullback_evaluate!, args...)
pullback_evaluate(∂X, l::AbstractP4MLLayer, args...) =
_with_safe_alloc(pullback_evaluate!, ∂X, l, args...)

pushforward_evaluate(args...) = _with_safe_alloc(pushforward_evaluate!, args...)
pushforward_evaluate(l::AbstractP4MLLayer, args...) =
_with_safe_alloc(pushforward_evaluate!, l, args...)

pb_pb_evaluate(args...) = _with_safe_alloc(pb_pb_evaluate!, args...)
pb_pb_evaluate(∂P, ∂X, l::AbstractP4MLLayer, args...) =
_with_safe_alloc(pb_pb_evaluate!, ∂P, ∂X, l, args...)


# ---------------------------------------
# general rrules and frules interface for ChainRulesCore
# ---------------------------------------------------------------
# general rrules and frules interface for AbstractP4MLBasis

import ChainRulesCore: rrule

# ∂_xa ( ∂P : P ) = ∑_ij ∂_xa ( ∂P_ij * P_ij )
# = ∑_ij ∂P_ij * ∂_xa ( P_ij )
# = ∑_ij ∂P_ij * dP_ij δ_ia
function rrule(::typeof(evaluate),
basis::AbstractP4MLBasis,
R::AbstractVector{<: Real})
P = evaluate(basis, R)
return P, ∂P -> (NoTangent(), NoTangent(), pullback_evaluate(∂P, basis, R))
# Allocation spec for `pullback_evaluate!`: the cotangent w.r.t. X has one
# entry per input point; its eltype must hold both the basis gradient type
# and the incoming cotangent ∂P.
function whatalloc(::typeof(pullback_evaluate!),
                   ∂P, basis::AbstractP4MLBasis, X::AbstractVector)
    return (promote_type(_gradtype(basis, X), eltype(∂P)), length(X))
end

function pullback_evaluate(∂P, basis::AbstractP4MLBasis, X::AbstractVector{<: Real})
P, dP = evaluate_ed(basis, X)
@assert size(∂P) == (length(X), length(basis))
T∂R = promote_type(eltype(∂P), eltype(dP))
∂X = zeros(T∂R, length(X))
# In-place pullback of `evaluate`: accumulates into ∂X the cotangent w.r.t.
# the inputs X, given the cotangent ∂P w.r.t. the basis values.
# The Jacobian dP can be supplied via the keyword to reuse a precomputed
# value; by default it is recomputed from `evaluate_ed`.
# NOTE(review): ∂X is accumulated with `+=` and never zeroed here — this
# assumes the caller passes a zero-initialized ∂X; confirm the allocation
# path (cf. the matching `whatalloc`) guarantees that.
function pullback_evaluate!(∂X,
∂P, basis::AbstractP4MLBasis, X::AbstractVector;
dP = evaluate_ed(basis, X)[2] )
@assert size(∂P) == size(dP) == (length(X), length(basis))
@assert length(∂X) == length(X)
# manual loops to avoid any broadcasting of StrideArrays
# ∂_xa ( ∂P : P ) = ∑_ij ∂_xa ( ∂P_ij * P_ij )
# = ∑_ij ∂P_ij * ∂_xa ( P_ij )
# = ∑_ij ∂P_ij * dP_ij δ_ia
for n = 1:size(dP, 2)
@simd ivdep for a = 1:length(X)
∂X[a] += dP[a, n] * ∂P[a, n]
end
# NOTE(review): this second loop repeats the accumulation above verbatim,
# which would double every entry of ∂X. It may be a diff/merge rendering
# artifact rather than real code — verify against the repository and, if
# genuine, delete one of the two loops.
@simd ivdep for a = 1:length(X)
∂X[a] += dP[a, n] * ∂P[a, n]
end
end
return ∂X
end


function rrule(::typeof(pullback_evaluate),
∂P, basis::AbstractP4MLBasis, X::AbstractVector{<: Real})
∂X = pullback_evaluate(∂P, basis, X)
function _pb(∂2)
∂∂P, ∂X = pb_pb_evaluate(∂2, ∂P, basis, X)
return NoTangent(), ∂∂P, NoTangent(), ∂X
end
return ∂X, _pb
# ChainRules reverse rule for `evaluate` on an AbstractP4MLBasis: forward
# evaluation, plus a pullback mapping the cotangent ∂P of the basis values
# to the cotangent of X via `pullback_evaluate`.
# TODO: here we could do evaluate_ed, but need to think about how this
# works with the kwarg trick above...
function rrule(::typeof(evaluate),
               basis::AbstractP4MLBasis,
               X::AbstractVector)
    P = evaluate(basis, X)
    function _evaluate_pullback(∂P)
        ∂X = pullback_evaluate(∂P, basis, X)
        return NoTangent(), NoTangent(), ∂X
    end
    return P, _evaluate_pullback
end


function pb_pb_evaluate(∂2, ∂P, basis::AbstractP4MLBasis,
X::AbstractVector{<: Real})
# @info("_evaluate_pb2")
#=
function whatalloc(::typeof(pb_pb_evaluate!),
∂∂X, ∂P, basis::AbstractP4MLBasis, X::AbstractVector)
Nbasis = length(basis)
Nx = length(X)
P, dP, ddP = evaluate_ed2(basis, X)
# ∂2 is the dual to ∂X ∈ ℝ^N, N = length(X)
@assert ∂2 isa AbstractVector
@assert length(∂2) == Nx
@assert ∂∂X isa AbstractVector
@assert length(∂∂X) == Nx
@assert size(∂P) == (Nx, Nbasis)
T∂²P = promote_type(_valtype(basis, X), eltype(∂P), eltype(∂∂X))
T∂²X = promote_type(_gradtype(basis, X), eltype(∂P), eltype(∂∂X))
return (T∂²P, Nx, Nbasis), (T∂²X, Nx)
end
∂2_∂ = zeros(size(∂P))
∂2_X = zeros(length(X))
for n = 1:Nbasis
@simd ivdep for a = 1:Nx
∂2_∂[a, n] = ∂2[a] * dP[a, n]
∂2_X[a] += ∂2[a] * ddP[a, n] * ∂P[a, n]
function pb_pb_evaluate!(∂²P, ∂²X, # output
∂∂X, # input / perturbation of ∂X
∂P, basis::AbstractP4MLBasis, # inputs
X::AbstractVector{<: Real})
@no_escape begin
P, dP, ddP = @withalloc evaluate_ed2!(basis, X)
for n = 1:Nbasis
@simd ivdep for a = 1:Nx
∂²P[a, n] = ∂∂X[a] * dP[a, n]
∂²X[a] += ∂∂X[a] * ddP[a, n] * ∂P[a, n]
end
end
end
return ∂2_∂, ∂2_X
return ∂²P, ∂²X
end
function rrule(::typeof(pullback_evaluate),
∂P, basis::AbstractP4MLBasis, X::AbstractVector{<: Real})
∂X = pullback_evaluate(∂P, basis, X)
function _pb(∂2)
∂∂P, ∂X = pb_pb_evaluate(∂2, ∂P, basis, X)
return NoTangent(), ∂∂P, NoTangent(), ∂X
end
return ∂X, _pb
end
=#


# -------------------------------------------------------------
# general rrules and frules for AbstractP4MLTensor


# ChainRules reverse rule for `evaluate` on an AbstractP4MLTensor: forward
# evaluation plus a pullback that defers to `pullback_evaluate`.
function rrule(::typeof(evaluate),
               basis::AbstractP4MLTensor,
               X)
    P = evaluate(basis, X)
    _tensor_pullback(∂P) = (NoTangent(), NoTangent(),
                            pullback_evaluate(∂P, basis, X))
    return P, _tensor_pullback
end

5 changes: 0 additions & 5 deletions src/linear.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import ChainRulesCore: rrule
using LuxCore
using Random
using LinearAlgebra: mul!
using StrideArrays

export LinearLayer

Expand Down Expand Up @@ -36,7 +32,6 @@ x = randn(N, in_d) # batch-first
out, st = l(x, ps, st)
println(out == x * transpose(W))  # true
```
"""
struct LinearLayer{FEATFIRST} <: AbstractExplicitLayer
in_dim::Integer
Expand Down
Loading

0 comments on commit 7979e0c

Please sign in to comment.