Merge pull request #114 from avik-pal/ap/dense
Go through the dense bias deprecation
avik-pal authored Jul 30, 2022
2 parents 11ac3e4 + 6af4f33 commit 33d1eb1
Showing 4 changed files with 78 additions and 34 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,10 @@
 # v0.4
 
+## v0.4.12
+
+- Deprecate `bias` in favor of `use_bias` for `Dense` and `Scale`.
+- Deprecate `elementwise_*` and `applyactivation` functions.
+
 ## v0.4.11
 
 - Introduces `Lux.Training` API for less clunky training loops.
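In user code the change looks like this — a minimal sketch, assuming Lux v0.4.12 and the usual `Lux.setup` workflow:

    using Lux, Random

    rng = Random.default_rng()

    # New spelling: `use_bias` controls whether a bias vector is allocated.
    layer = Dense(10 => 100, relu; use_bias=false)
    ps, st = Lux.setup(rng, layer)
    haskey(ps, :bias)  # false; only `ps.weight` is created

    # Old spelling still works until v0.5 but emits a deprecation warning:
    Dense(10 => 100, relu; bias=false)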
81 changes: 52 additions & 29 deletions src/layers/basic.jl
@@ -539,8 +539,8 @@ Create a traditional fully connected layer, whose forward pass is given by:
 - `init_weight`: initializer for the weight matrix
   (`weight = init_weight(rng, out_dims, in_dims)`)
-- `init_bias`: initializer for the bias vector (ignored if `bias=false`)
-- `bias`: whether to include a bias vector
+- `init_bias`: initializer for the bias vector (ignored if `use_bias=false`)
+- `use_bias`: Trainable bias can be disabled entirely by setting this to `false`
 
 ## Input
@@ -556,48 +556,57 @@ Create a traditional fully connected layer, whose forward pass is given by:
 - `weight`: Weight Matrix of size `out_dims × in_dims`
 - `bias`: Bias of size `out_dims × 1` (present if `bias=true`)
 """
-struct Dense{bias, F1, F2, F3} <: AbstractExplicitLayer
+struct Dense{use_bias, F1, F2, F3} <: AbstractExplicitLayer
     activation::F1
     in_dims::Int
     out_dims::Int
     init_weight::F2
     init_bias::F3
 end
 
-function Base.show(io::IO, d::Dense{bias}) where {bias}
+function Base.show(io::IO, d::Dense{use_bias}) where {use_bias}
     print(io, "Dense($(d.in_dims) => $(d.out_dims)")
     (d.activation == identity) || print(io, ", $(d.activation)")
-    bias || print(io, ", bias=false")
+    use_bias || print(io, ", bias=false")
     return print(io, ")")
 end

 function Dense(mapping::Pair{<:Int, <:Int}, activation=identity; init_weight=glorot_uniform,
-               init_bias=zeros32, bias::Bool=true)
-    return Dense(first(mapping), last(mapping), activation; init_weight=init_weight,
-                 init_bias=init_bias, bias=bias)
+               init_bias=zeros32, use_bias::Bool=true, bias::Union{Missing, Bool}=missing)
+    return Dense(first(mapping), last(mapping), activation; init_weight, init_bias,
+                 use_bias, bias)
 end
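The rewritten forwarding uses Julia's keyword shorthand (since Julia 1.5, `f(; kw)` is equivalent to `f(; kw=kw)`), so behavior is unchanged — a small illustration:

    init_weight = glorot_uniform  # any initializer in scope
    Dense(3 => 2; init_weight)    # same as Dense(3 => 2; init_weight=init_weight)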

 function Dense(in_dims::Int, out_dims::Int, activation=identity; init_weight=glorot_uniform,
-               init_bias=zeros32, bias::Bool=true)
+               init_bias=zeros32, use_bias::Bool=true, bias::Union{Missing, Bool}=missing)
     activation = NNlib.fast_act(activation)
-    return Dense{bias, typeof(activation), typeof(init_weight), typeof(init_bias)}(activation,
-                                                                                   in_dims,
-                                                                                   out_dims,
-                                                                                   init_weight,
-                                                                                   init_bias)
+
+    # Deprecated Functionality (Remove in v0.5)
+    if !ismissing(bias)
+        Base.depwarn("`bias` argument to `Dense` has been deprecated and will be removed" *
+                     " in v0.5. Use `use_bias` kwarg instead.", :Dense)
+        if !use_bias
+            throw(ArgumentError("Both `bias` and `use_bias` are set. Please only use " *
+                                "the `use_bias` keyword argument."))
+        end
+        use_bias = bias
+    end
+
+    dtype = (use_bias, typeof(activation), typeof(init_weight), typeof(init_bias))
+    return Dense{dtype...}(activation, in_dims, out_dims, init_weight, init_bias)
 end
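The `bias::Union{Missing, Bool}=missing` default is a sentinel: `missing` lets the constructor distinguish "caller never passed `bias`" from an explicit `true`/`false`. A stripped-down sketch of the pattern (the `configure` function is hypothetical, not part of Lux):

    function configure(; new_flag::Bool=true, old_flag::Union{Missing, Bool}=missing)
        if !ismissing(old_flag)  # the deprecated kwarg was explicitly passed
            Base.depwarn("`old_flag` is deprecated; use `new_flag` instead.", :configure)
            # Refuse ambiguous input; like the Dense code above, this only catches
            # a conflict when the new kwarg was set to a non-default value.
            new_flag || throw(ArgumentError("set only one of `old_flag` / `new_flag`"))
            new_flag = old_flag
        end
        return new_flag
    end

    configure(old_flag=false)  # warns, returns false
    configure(new_flag=false)  # silent, returns false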

-function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias}
-    if bias
+function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias}
+    if use_bias
         return (weight=d.init_weight(rng, d.out_dims, d.in_dims),
                 bias=d.init_bias(rng, d.out_dims, 1))
     else
         return (weight=d.init_weight(rng, d.out_dims, d.in_dims),)
     end
 end
 
-function parameterlength(d::Dense{bias}) where {bias}
-    return bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims
+function parameterlength(d::Dense{use_bias}) where {use_bias}
+    return use_bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims
 end
 statelength(d::Dense) = 0
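As a quick sanity check on `parameterlength` (assuming the definitions above), `Dense(10 => 100)` stores a `100 × 10` weight plus a `100 × 1` bias:

    Lux.parameterlength(Dense(10 => 100))                  # 100 * (10 + 1) = 1100
    Lux.parameterlength(Dense(10 => 100; use_bias=false))  # 100 * 10 = 1000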

@@ -666,8 +675,8 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)`
 - `init_weight`: initializer for the weight matrix
   (`weight = init_weight(rng, out_dims, in_dims)`)
-- `init_bias`: initializer for the bias vector (ignored if `bias=false`)
-- `bias`: whether to include a bias vector
+- `init_bias`: initializer for the bias vector (ignored if `use_bias=false`)
+- `use_bias`: Trainable bias can be disabled entirely by setting this to `false`
 
 ## Input
 
@@ -688,7 +697,7 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .* x .+ bias)`
 `Scale` with multiple dimensions requires at least Lux 0.4.3.
 """
-struct Scale{bias, F1, D, F2, F3} <: AbstractExplicitLayer
+struct Scale{use_bias, F1, D, F2, F3} <: AbstractExplicitLayer
     activation::F1
     dims::D
     init_weight::F2
@@ -702,9 +711,22 @@ function Base.show(io::IO, d::Scale)
 end
 
 function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
-               init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true)
+               init_weight=glorot_uniform, init_bias=zeros32, use_bias::Bool=true,
+               bias::Union{Missing, Bool}=missing)
     activation = NNlib.fast_act(activation)
-    return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight),
+
+    # Deprecated Functionality (Remove in v0.5)
+    if !ismissing(bias)
+        Base.depwarn("`bias` argument to `Scale` has been deprecated and will be removed" *
+                     " in v0.5. Use `use_bias` kwarg instead.", :Scale)
+        if !use_bias
+            throw(ArgumentError("Both `bias` and `use_bias` are set. Please only use " *
+                                "the `use_bias` keyword argument."))
+        end
+        use_bias = bias
+    end
+
+    return Scale{use_bias, typeof(activation), typeof(dims), typeof(init_weight),
                  typeof(init_bias)}(activation, dims, init_weight, init_bias)
 end
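Keeping `use_bias` as a type parameter (a `Bool` lifted into the type domain) means `Scale{true}` and `Scale{false}` are distinct concrete types, so `initialparameters` and `parameterlength` can branch at compile time. A toy illustration of the idiom (the `Toy` type is hypothetical, not from Lux):

    struct Toy{flag} end

    describe(::Toy{flag}) where {flag} = flag ? "with bias" : "without bias"

    describe(Toy{true}())   # "with bias"
    describe(Toy{false}())  # "without bias"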

@@ -713,14 +735,15 @@ function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
 end
 Scale(size_act...; kw...) = Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
 
-function initialparameters(rng::AbstractRNG, d::Scale{true})
-    return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
-end
-function initialparameters(rng::AbstractRNG, d::Scale{false})
-    return (weight=d.init_weight(rng, d.dims...),)
+function initialparameters(rng::AbstractRNG, d::Scale{use_bias}) where {use_bias}
+    if use_bias
+        return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
+    else
+        return (weight=d.init_weight(rng, d.dims...),)
+    end
 end

-parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * prod(d.dims)
+parameterlength(d::Scale{use_bias}) where {use_bias} = (1 + use_bias) * prod(d.dims)
 statelength(d::Scale) = 0

 function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
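The same parameter counting applies to `Scale`, elementwise over `dims` — a sketch, assuming the definitions above:

    Lux.parameterlength(Scale(10, 100, relu))            # (1 + 1) * prod((10, 100)) = 2000
    Lux.parameterlength(Scale(10, 100; use_bias=false))  # 1000: weight only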
18 changes: 16 additions & 2 deletions test/layers/basic.jl
@@ -215,7 +215,7 @@ end
     @test size(ps.bias) == (100, 1)
     @test layer.activation == identity
 
-    layer = Dense(10, 100, relu; bias=false)
+    layer = Dense(10, 100, relu; use_bias=false)
     ps, st = Lux.setup(rng, layer)
 
     @test !haskey(ps, :bias)
@@ -256,6 +256,13 @@ end
             first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...))
         end == [10 20; 10 20]
     end
+
+    # Deprecated Functionality (Remove in v0.5)
+    @testset "Deprecations" begin
+        @test_deprecated Dense(10, 100, relu; bias=false)
+        @test_deprecated Dense(10, 100, relu; bias=true)
+        @test_throws ArgumentError Dense(10, 100, relu; bias=false, use_bias=false)
+    end
 end
 
 @testset "Scale" begin
@@ -267,7 +274,7 @@ end
     @test size(ps.bias) == (10, 100)
     @test layer.activation == identity
 
-    layer = Scale(10, 100, relu; bias=false)
+    layer = Scale(10, 100, relu; use_bias=false)
     ps, st = Lux.setup(rng, layer)
 
     @test !haskey(ps, :bias)
@@ -303,4 +310,11 @@ end
             first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...))
         end == zeros(2, 2)
     end
+
+    # Deprecated Functionality (Remove in v0.5)
+    @testset "Deprecations" begin
+        @test_deprecated Scale(10, 100, relu; bias=false)
+        @test_deprecated Scale(10, 100, relu; bias=true)
+        @test_throws ArgumentError Scale(10, 100, relu; bias=false, use_bias=false)
+    end
 end
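These testsets lean on `Test.@test_deprecated`, which passes when the expression emits a deprecation warning (package test suites typically run with deprecation warnings enabled, so `Base.depwarn` actually logs). A minimal illustration outside Lux (`old_api` is hypothetical):

    using Test

    old_api(x) = (Base.depwarn("`old_api` is deprecated.", :old_api); x + 1)

    @test_deprecated old_api(1)  # passes: the call warns and still returns 2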
8 changes: 5 additions & 3 deletions test/nnlib.jl
@@ -13,12 +13,9 @@ include("test_utils.jl")
 
         # On CPU the fallback should always work
         @test Lux.elementwise_add(x, y) == x .+ y
-        @test_deprecated Lux.elementwise_add(x, y)
         @test Lux.elementwise_mul(x, y) == x .* y
-        @test_deprecated Lux.elementwise_mul(x, y)
         @test Lux.applyactivation(tanh, x) == tanh.(x)
         @test Lux.applyactivation(custom_activation, x) == custom_activation.(x)
-        @test_deprecated Lux.applyactivation(tanh, x)
 
         if T <: Real
             # Gradient for complex outputs are not defined
@@ -36,5 +33,10 @@ include("test_utils.jl")
             # Custom Activation test
             @test Lux.applyactivation(custom_activation, x_g) == custom_activation.(x_g)
         end
+
+        # Deprecated Functionality (Remove in v0.5)
+        @test_deprecated Lux.elementwise_add(x, y)
+        @test_deprecated Lux.elementwise_mul(x, y)
+        @test_deprecated Lux.applyactivation(tanh, x)
     end
 end
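Migration for the deprecated helpers is plain broadcasting, exactly as the equality tests above spell out:

    x .+ y    # replaces Lux.elementwise_add(x, y)
    x .* y    # replaces Lux.elementwise_mul(x, y)
    tanh.(x)  # replaces Lux.applyactivation(tanh, x)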
