Skip to content

Commit

Permalink
small bugfix in lux
Browse files Browse the repository at this point in the history
  • Loading branch information
cortner committed Jun 22, 2024
1 parent 813d8cb commit 901632b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Polynomials4ML"
uuid = "03c4bcba-a943-47e9-bfa1-b1661fc2974f"
authors = ["Christoph Ortner <[email protected]> and contributors"]
version = "0.3.0"
version = "0.3.1-dev"

[deps]
ACEbase = "14bae519-eb20-449c-a949-9c58ed33163e"
Expand Down
9 changes: 2 additions & 7 deletions src/ace/sparseprodpool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -514,15 +514,13 @@ import LuxCore: AbstractExplicitLayer, initialparameters, initialstates
struct PooledSparseProductLayer{NB} <: AbstractExplicitLayer
basis::PooledSparseProduct{NB}
meta::Dict{String, Any}
release_input::Bool
end

function lux(basis::PooledSparseProduct;
name = String(nameof(typeof(basis))),
meta = Dict{String, Any}("name" => name),
release_input = true)
meta = Dict{String, Any}("name" => name))
@assert haskey(meta, "name")
return PooledSparseProductLayer(basis, meta, release_input)
return PooledSparseProductLayer(basis, meta)
end

initialparameters(rng::AbstractRNG, layer::PooledSparseProductLayer) =
Expand All @@ -533,8 +531,5 @@ initialstates(rng::AbstractRNG, layer::PooledSparseProductLayer) =

(l::PooledSparseProductLayer)(BB, ps, st) = begin
out = evaluate(l.basis, BB)
if l.release_input
release!.(BB)
end
return out, st
end

0 comments on commit 901632b

Please sign in to comment.