
Merge pull request #106 from avik-pal/ap/wn_fix
Fixes WeightNorm with zero Parameter bug
avik-pal authored Jul 26, 2022
2 parents a5e3a70 + 972c503 commit 6babea5
Showing 6 changed files with 59 additions and 23 deletions.
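
For context (not part of the diff below): weight normalization reparameterizes each selected parameter as w = g * v / ||v||, so if v is all zeros the norm is zero and both the forward pass and the gradients degenerate to NaN. A minimal sketch of the failure mode, assuming only the standard reparameterization (the `loss` and variable names here are illustrative, not Lux API):

```julia
using LinearAlgebra, Zygote

# All-zero "direction" parameter v and a scale g, as in w = g .* v ./ norm(v).
v = zeros(Float32, 3)
g = 1.0f0
loss(v) = sum(g .* v ./ norm(v))

loss(v)                   # NaN: the forward pass already divides 0 by 0
Zygote.gradient(loss, v)  # gradient entries are NaN as well
```

This is the failure the commit now catches at `Lux.setup` time instead of during training.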
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
 ## v0.4.11
 
 - Introduces `Lux.Training` API for less clunky training loops.
+- WeightNorm cannot be invoked on a parameter with all elements equal to 0.
 
 ## v0.4.10
 
2 changes: 1 addition & 1 deletion docs/src/api/contrib.md
@@ -4,7 +4,7 @@ CurrentModule = Lux
 
 All features listed on this page are **experimental** which means:
 
-1. No SemVar Guarantees. We use code here to iterate fast and most users should wait for
+1. No SemVer Guarantees. We use code here to iterate fast and most users should wait for
    these features to be marked non-experimental.
 2. The code will probably be moved into a separate repository in the future.
 3. Expect edge-cases and report them. It will help us move these features out of
2 changes: 1 addition & 1 deletion docs/src/devdocs/style_guide.md
@@ -91,7 +91,7 @@ away those optimizations (these can be tested via `Zygote.@code_ir`).
 
 ## Deprecation
 
-Deprecations should be handled according to SemVar recommendations, i.e. there should be
+Deprecations should be handled according to SemVer recommendations, i.e. there should be
 atleast one version where we throw a deprecation warning. This ensures users know how to
 modify their code for upcoming releases.
 
10 changes: 9 additions & 1 deletion src/layers/normalize.jl
@@ -392,7 +392,15 @@ function initialparameters(rng::AbstractRNG,
     i = 1
     for k in propertynames(ps_layer)
         v = ps_layer[k]
-        if k ∈ which_params
+        if k in which_params
+            if all(iszero, v)
+                msg = ("Parameter $(k) is completely zero. This will result in NaN " *
+                       "gradients. Either remove this parameter from `which_params` or " *
+                       "modify the initialization in the actual layer. Typically this is " *
+                       "controlled using the `init_$(k)` keyword argument.")
+                # FIXME(@avik-pal): This is not really an ArgumentError
+                throw(ArgumentError(msg))
+            end
             dim = wn.dims === nothing ? ndims(v) : wn.dims[i]
             push!(ps_normalized, Symbol(string(k) * "_g") => _norm_except(v; dims=dim))
             push!(ps_normalized, Symbol(string(k) * "_v") => v)
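To show how the new guard surfaces to users, here is a usage sketch; it mirrors the "Normalizing Zero Parameters" test added later in this commit rather than introducing any new behavior. `Conv`'s default bias is zero-initialized, so normalizing `:bias` now fails fast at setup time:

```julia
using Lux, Random

rng = Random.default_rng()

# The default bias is all zeros, so including :bias in which_params now throws.
c = Conv((3, 3), 3 => 3)
wn = WeightNorm(c, (:weight, :bias))
try
    Lux.setup(rng, wn)
catch err
    @assert err isa ArgumentError   # "Parameter bias is completely zero. ..."
end

# A non-zero bias initializer (as the updated tests use) avoids the error.
c_ok = Conv((3, 3), 3 => 3; init_bias=Lux.ones32)
ps, st = Lux.setup(rng, WeightNorm(c_ok, (:weight, :bias)))
```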
55 changes: 37 additions & 18 deletions test/layers/normalize.jl
@@ -231,82 +231,101 @@ end
     end
 
     @testset "Conv" begin
-        c = Conv((3, 3), 3 => 3)
+        c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32)
 
         wn = WeightNorm(c, (:weight, :bias))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 3, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight,))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 3, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight, :bias), (2, 2))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 3, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(c, (:weight,), (2,))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 3, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
     end
 
     @testset "Dense" begin
-        d = Dense(3 => 3)
+        d = Dense(3 => 3; init_bias=Lux.ones32)
 
         wn = WeightNorm(d, (:weight, :bias))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight,))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight, :bias), (2, 2))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
 
         wn = WeightNorm(d, (:weight,), (2,))
         println(wn)
         ps, st = Lux.setup(rng, wn)
         x = randn(rng, Float32, 3, 1)
 
         run_JET_tests(wn, x, ps, st)
-        test_gradient_correctness_fdm(x -> sum(first(wn(x, ps, st))), x; atol=1.0f-3,
-                                      rtol=1.0f-3)
+        test_gradient_correctness_fdm((x, ps) -> sum(first(wn(x, ps, st))), x, ps;
+                                      atol=1.0f-3, rtol=1.0f-3)
     end
 
+    # See https://github.com/avik-pal/Lux.jl/issues/95
+    @testset "Normalizing Zero Parameters" begin
+        c = Conv((3, 3), 3 => 3)
+
+        wn = WeightNorm(c, (:weight, :bias))
+        @test_throws ArgumentError Lux.setup(rng, wn)
+
+        wn = WeightNorm(c, (:weight,))
+        @test_nowarn Lux.setup(rng, wn)
+
+        c = Conv((3, 3), 3 => 3; init_bias=Lux.ones32)
+
+        wn = WeightNorm(c, (:weight, :bias))
+        @test_nowarn Lux.setup(rng, wn)
+
+        wn = WeightNorm(c, (:weight,))
+        @test_nowarn Lux.setup(rng, wn)
+    end
 end
12 changes: 10 additions & 2 deletions test/test_utils.jl
@@ -1,4 +1,4 @@
-using FiniteDifferences, JET, Lux, Optimisers, Random, Test, Zygote
+using ComponentArrays, FiniteDifferences, JET, Lux, Optimisers, Random, Test, Zygote
 
 function Base.isapprox(x, y; kwargs...)
     @warn "`isapprox` is not defined for ($(typeof(x)), $(typeof(y))). Using `==` instead."
@@ -22,10 +22,18 @@ function Base.isapprox(t1::NTuple{N, T}, t2::NTuple{N, T}; kwargs...) where {N,
     return all(checkapprox, zip(t1, t2))
 end
 
+Base.isapprox(::Nothing, v::AbstractArray; kwargs...) = length(v) == 0
+Base.isapprox(v::AbstractArray, ::Nothing; kwargs...) = length(v) == 0
+
 # Test the gradients generated using AD against the gradients generated using Finite Differences
+_named_tuple(x::ComponentArray) = NamedTuple(x)
+_named_tuple(x) = x
+
 function test_gradient_correctness_fdm(f::Function, args...; kwargs...)
     gs_ad = Zygote.gradient(f, args...)
-    gs_fdm = FiniteDifferences.grad(FiniteDifferences.central_fdm(5, 1), f, args...)
+    gs_fdm = FiniteDifferences.grad(FiniteDifferences.central_fdm(5, 1), f,
+                                    ComponentArray.(args)...)
+    gs_fdm = _named_tuple.(gs_fdm)
     for (g_ad, g_fdm) in zip(gs_ad, gs_fdm)
         @test isapprox(g_ad, g_fdm; kwargs...)
     end
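Why the ComponentArrays detour: Lux parameters are NamedTuples, and the helper presumably needs a flat, array-like object that FiniteDifferences can perturb; the gradient that comes back is then converted to a NamedTuple so `isapprox` can compare it structurally against the Zygote gradient. A small sketch of the assumed round trip (array names and shapes are illustrative, not taken from the tests):

```julia
using ComponentArrays

ps = (weight=rand(Float32, 3, 3), bias=rand(Float32, 3))

ca = ComponentArray(ps)  # flat vector backing with named, shaped views
nt = NamedTuple(ca)      # back to the NamedTuple structure Lux parameters use

@assert size(nt.weight) == (3, 3)
@assert nt.bias == ps.bias
```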
