Skip to content

Commit

Permalink
extenging withalloc to sparseprodpool
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Jun 3, 2024
1 parent e85fcac commit 902120e
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ ChainRulesCore = "1"
Combinatorics = "1"
ForwardDiff = "0.10"
HyperDualNumbers = "4.0.10"
LinearAlgebra = "1.8"
LinearAlgebra = "1.9"
LoopVectorization = "0.12"
LuxCore = "0.1.3"
NamedTupleTools = "0.14.3"
ObjectPools = "0.3.1"
Printf = "1.8"
Printf = "1.9"
QuadGK = "2"
Random = "1.8"
SparseArrays = "1.8"
Random = "1.9"
SparseArrays = "1.9"
SpecialFunctions = "2.2"
SpheriCart = "0.0.3"
StaticArrays = "1.5"
StrideArrays = "0.1.25"
Test = "1.8"
Test = "1.9"
WithAlloc = "0.0.4"
julia = "1.8.0, 1.9.0, 1.10.0"
julia = "1.9.0, 1.10.0"

[extras]
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand Down
6 changes: 4 additions & 2 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ _gradtype(basis::AbstractPoly4MLBasis, BB::Tuple) =
_alloc(basis::PooledSparseProduct, BB::TupVecMat) =
acquire!(basis.pool, :A, (length(basis), ), _valtype(basis, BB) )

_out_size(basis::PooledSparseProduct, BB::TupVecMat) =
length(basis)

# _alloc_d(basis::AbstractPoly4MLBasis, BB::TupVecMat) =
# acquire!(basis.pool, _outsym(BB), (length(basis), ), _gradtype(basis, BB) )
Expand Down Expand Up @@ -109,7 +111,7 @@ function evaluate!(A, basis::PooledSparseProduct{NB}, BB::TupVec) where {NB}
b = ntuple(t -> BB[t][ϕ[t]], NB)
@inbounds A[iA] += @fastmath prod(b)
end
return nothing
return A
end


Expand Down Expand Up @@ -147,7 +149,7 @@ function evaluate!(A, basis::PooledSparseProduct{NB}, BB::TupMat,
A[iA] = a
end

return nothing
return A
end


Expand Down
8 changes: 4 additions & 4 deletions src/testing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,19 +105,19 @@ function _allocations_inner(basis, x; ed = true, ed2 = true)
s = sum(P)
if ed
P1, dP1 = @withalloc evaluate_ed!(basis, x)
s1 = s + sum(P1) + sum(dP1)
# s += s + sum(P1) + sum(dP1)
end
if ed2
P2, dP2, ddP2 = @withalloc evaluate_ed2!(basis, x)
s2 = s1 + sum(P2) + sum(dP2) + sum(ddP2)
# s2 = s1 + sum(P2) + sum(dP2) + sum(ddP2)
end
nothing
end
return s2
return s
end

function test_withalloc(basis, x; allowed_allocs = 0, kwargs...)
nalloc = @allocated ( _allocations_inner(basis, x) )
nalloc = @allocated ( _allocations_inner(basis, x; kwargs...) )
P1 = basis(x)
@no_escape begin
P2 = @withalloc evaluate!(basis, x)
Expand Down
13 changes: 12 additions & 1 deletion test/ace/test_sparseprodpool.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

using BenchmarkTools, Test, Polynomials4ML, ChainRulesCore
using Polynomials4ML: PooledSparseProduct, evaluate, evaluate!
using ACEbase.Testing: fdtest, println_slim, print_tf
using ACEbase.Testing: fdtest, println_slim, print_tf,
test_withalloc

test_evaluate(basis::PooledSparseProduct, BB::Tuple{Vararg{AbstractVector}}) =
[prod(BB[j][basis.spec[i][j]] for j = 1:length(BB))
Expand Down Expand Up @@ -72,6 +73,16 @@ println()

##

@info(" testing withalloc")
basis = _generate_basis(; order=3)
BB = _rand_input1(basis)
bBB = _rand_input(basis)
println_slim(@test test_withalloc(basis, BB; ed = false, ed2 = false) )
println_slim(@test test_withalloc(basis, bBB; ed = false, ed2 = false) )

##


@info("Testing rrule")
using LinearAlgebra: dot

Expand Down

0 comments on commit 902120e

Please sign in to comment.