diff --git a/CHANGELOG.md b/CHANGELOG.md
index 80b8bf32e..5c02b90b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 # v0.4
 
+## v0.4.12
+
+  - Deprecate `bias` in favor of `use_bias` for `Dense` and `Scale`.
+  - Deprecate `elementwise_*` and `applyactivation` functions.
+
 ## v0.4.11
 
   - Introduces `Lux.Training` API for less clunky training loops.
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 23b1d2c77..d18795240 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -539,8 +539,8 @@ Create a traditional fully connected layer, whose forward pass is given by:
   - `init_weight`: initializer for the weight matrix
     (`weight = init_weight(rng, out_dims, in_dims)`)
-  - `init_bias`: initializer for the bias vector (ignored if `bias=false`)
-  - `bias`: whether to include a bias vector
+  - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`)
+  - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`
 
 ## Input
 
@@ -556,7 +556,7 @@ Create a traditional fully connected layer, whose forward pass is given by:
   - `weight`: Weight Matrix of size `out_dims × in_dims`
   - `bias`: Bias of size `out_dims × 1` (present if `bias=true`)
 """
-struct Dense{bias, F1, F2, F3} <: AbstractExplicitLayer
+struct Dense{use_bias, F1, F2, F3} <: AbstractExplicitLayer
     activation::F1
     in_dims::Int
     out_dims::Int
@@ -564,31 +564,40 @@ struct Dense{bias, F1, F2, F3} <: AbstractExplicitLayer
     init_bias::F3
 end
 
-function Base.show(io::IO, d::Dense{bias}) where {bias}
+function Base.show(io::IO, d::Dense{use_bias}) where {use_bias}
     print(io, "Dense($(d.in_dims) => $(d.out_dims)")
     (d.activation == identity) || print(io, ", $(d.activation)")
-    bias || print(io, ", bias=false")
+    use_bias || print(io, ", bias=false")
     return print(io, ")")
 end
 
 function Dense(mapping::Pair{<:Int, <:Int}, activation=identity; init_weight=glorot_uniform,
-               init_bias=zeros32, bias::Bool=true)
-    return Dense(first(mapping), last(mapping), activation; init_weight=init_weight,
-                 init_bias=init_bias, bias=bias)
+               init_bias=zeros32, use_bias::Bool=true, bias::Union{Missing, Bool}=missing)
+    return Dense(first(mapping), last(mapping), activation; init_weight, init_bias,
+                 use_bias, bias)
 end
 
 function Dense(in_dims::Int, out_dims::Int, activation=identity; init_weight=glorot_uniform,
-               init_bias=zeros32, bias::Bool=true)
+               init_bias=zeros32, use_bias::Bool=true, bias::Union{Missing, Bool}=missing)
     activation = NNlib.fast_act(activation)
-    return Dense{bias, typeof(activation), typeof(init_weight), typeof(init_bias)}(activation,
-                                                                                   in_dims,
-                                                                                   out_dims,
-                                                                                   init_weight,
-                                                                                   init_bias)
+
+    # Deprecated Functionality (Remove in v0.5)
+    if !ismissing(bias)
+        Base.depwarn("`bias` argument to `Dense` has been deprecated and will be removed" *
+                     " in v0.5. Use `use_bias` kwarg instead.", :Dense)
+        if !use_bias
+            throw(ArgumentError("Both `bias` and `use_bias` are set. Please only use " *
+                                "the `use_bias` keyword argument."))
+        end
+        use_bias = bias
+    end
+
+    dtype = (use_bias, typeof(activation), typeof(init_weight), typeof(init_bias))
+    return Dense{dtype...}(activation, in_dims, out_dims, init_weight, init_bias)
 end
 
-function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias}
-    if bias
+function initialparameters(rng::AbstractRNG, d::Dense{use_bias}) where {use_bias}
+    if use_bias
         return (weight=d.init_weight(rng, d.out_dims, d.in_dims),
                 bias=d.init_bias(rng, d.out_dims, 1))
     else
@@ -596,8 +605,8 @@ function initialparameters(rng::AbstractRNG, d::Dense{bias}) where {bias}
     end
 end
 
-function parameterlength(d::Dense{bias}) where {bias}
-    return bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims
+function parameterlength(d::Dense{use_bias}) where {use_bias}
+    return use_bias ? d.out_dims * (d.in_dims + 1) : d.out_dims * d.in_dims
 end
 
 statelength(d::Dense) = 0
@@ -666,8 +675,8 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .*
   - `init_weight`: initializer for the weight matrix
     (`weight = init_weight(rng, out_dims, in_dims)`)
-  - `init_bias`: initializer for the bias vector (ignored if `bias=false`)
-  - `bias`: whether to include a bias vector
+  - `init_bias`: initializer for the bias vector (ignored if `use_bias=false`)
+  - `use_bias`: Trainable bias can be disabled entirely by setting this to `false`
 
 ## Input
 
@@ -688,7 +697,7 @@ Elements are non-zero). The forward pass is given by: `y = activation.(weight .*
     `Scale` with multiple dimensions requires at least Lux 0.4.3.
 """
-struct Scale{bias, F1, D, F2, F3} <: AbstractExplicitLayer
+struct Scale{use_bias, F1, D, F2, F3} <: AbstractExplicitLayer
     activation::F1
     dims::D
     init_weight::F2
@@ -702,9 +711,22 @@ function Base.show(io::IO, d::Scale)
 end
 
 function Scale(dims::Tuple{Vararg{Integer}}, activation=identity;
-               init_weight=glorot_uniform, init_bias=zeros32, bias::Bool=true)
+               init_weight=glorot_uniform, init_bias=zeros32, use_bias::Bool=true,
+               bias::Union{Missing, Bool}=missing)
     activation = NNlib.fast_act(activation)
-    return Scale{bias, typeof(activation), typeof(dims), typeof(init_weight),
+
+    # Deprecated Functionality (Remove in v0.5)
+    if !ismissing(bias)
+        Base.depwarn("`bias` argument to `Scale` has been deprecated and will be removed" *
+                     " in v0.5. Use `use_bias` kwarg instead.", :Scale)
+        if !use_bias
+            throw(ArgumentError("Both `bias` and `use_bias` are set. Please only use " *
+                                "the `use_bias` keyword argument."))
+        end
+        use_bias = bias
+    end
+
+    return Scale{use_bias, typeof(activation), typeof(dims), typeof(init_weight),
                  typeof(init_bias)}(activation, dims, init_weight, init_bias)
 end
 
@@ -713,14 +735,15 @@ function Scale(s1::Integer, s23::Integer...; _act=identity, kw...)
 end
 Scale(size_act...; kw...) = Scale(size_act[1:(end - 1)]...; _act=size_act[end], kw...)
 
-function initialparameters(rng::AbstractRNG, d::Scale{true})
-    return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
-end
-function initialparameters(rng::AbstractRNG, d::Scale{false})
-    return (weight=d.init_weight(rng, d.dims...),)
+function initialparameters(rng::AbstractRNG, d::Scale{use_bias}) where {use_bias}
+    if use_bias
+        return (weight=d.init_weight(rng, d.dims...), bias=d.init_bias(rng, d.dims...))
+    else
+        return (weight=d.init_weight(rng, d.dims...),)
+    end
 end
 
-parameterlength(d::Scale{bias}) where {bias} = (1 + bias) * prod(d.dims)
+parameterlength(d::Scale{use_bias}) where {use_bias} = (1 + use_bias) * prod(d.dims)
 statelength(d::Scale) = 0
 
 function (d::Scale{true})(x::AbstractArray, ps, st::NamedTuple)
diff --git a/test/layers/basic.jl b/test/layers/basic.jl
index b0d09419b..aaf7acf0e 100644
--- a/test/layers/basic.jl
+++ b/test/layers/basic.jl
@@ -215,7 +215,7 @@ end
         @test size(ps.bias) == (100, 1)
        @test layer.activation == identity
 
-        layer = Dense(10, 100, relu; bias=false)
+        layer = Dense(10, 100, relu; use_bias=false)
         ps, st = Lux.setup(rng, layer)
 
         @test !haskey(ps, :bias)
@@ -256,6 +256,13 @@ end
                  first(Lux.apply(layer, [ones(10, 1) 2 * ones(10, 1)], Lux.setup(rng, layer)...))
              end == [10 20; 10 20]
     end
+
+    # Deprecated Functionality (Remove in v0.5)
+    @testset "Deprecations" begin
+        @test_deprecated Dense(10, 100, relu; bias=false)
+        @test_deprecated Dense(10, 100, relu; bias=true)
+        @test_throws ArgumentError Dense(10, 100, relu; bias=false, use_bias=false)
+    end
 end
 
 @testset "Scale" begin
@@ -267,7 +274,7 @@ end
         @test size(ps.bias) == (10, 100)
         @test layer.activation == identity
 
-        layer = Scale(10, 100, relu; bias=false)
+        layer = Scale(10, 100, relu; use_bias=false)
         ps, st = Lux.setup(rng, layer)
 
         @test !haskey(ps, :bias)
@@ -303,4 +310,11 @@ end
                  first(Lux.apply(layer, [1 2; 3 4], Lux.setup(rng, layer)...))
              end == zeros(2, 2)
     end
+
+    # Deprecated Functionality (Remove in v0.5)
+    @testset "Deprecations" begin
+        @test_deprecated Scale(10, 100, relu; bias=false)
+        @test_deprecated Scale(10, 100, relu; bias=true)
+        @test_throws ArgumentError Scale(10, 100, relu; bias=false, use_bias=false)
+    end
 end
diff --git a/test/nnlib.jl b/test/nnlib.jl
index 3c2422d40..84483e9f0 100644
--- a/test/nnlib.jl
+++ b/test/nnlib.jl
@@ -13,12 +13,9 @@ include("test_utils.jl")
 
         # On CPU the fallback should always work
         @test Lux.elementwise_add(x, y) == x .+ y
-        @test_deprecated Lux.elementwise_add(x, y)
         @test Lux.elementwise_mul(x, y) == x .* y
-        @test_deprecated Lux.elementwise_mul(x, y)
         @test Lux.applyactivation(tanh, x) == tanh.(x)
         @test Lux.applyactivation(custom_activation, x) == custom_activation.(x)
-        @test_deprecated Lux.applyactivation(tanh, x)
 
         if T <: Real
             # Gradient for complex outputs are not defined
@@ -36,5 +33,10 @@ include("test_utils.jl")
             # Custom Activation test
             @test Lux.applyactivation(custom_activation, x_g) == custom_activation.(x_g)
         end
+
+        # Deprecated Functionality (Remove in v0.5)
+        @test_deprecated Lux.elementwise_add(x, y)
+        @test_deprecated Lux.elementwise_mul(x, y)
+        @test_deprecated Lux.applyactivation(tanh, x)
     end
 end
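
For reference, a minimal usage sketch (not part of the patch) of the behaviour introduced above, assuming a Lux version with these changes applied:

```julia
using Lux, Random

rng = Random.default_rng()

# Preferred spelling: `use_bias` controls whether a bias parameter is created.
layer = Dense(10 => 5, relu; use_bias=false)
ps, st = Lux.setup(rng, layer)
@assert !haskey(ps, :bias)          # no bias parameter when use_bias=false

# Deprecated spelling: still accepted during the v0.4 series, but it emits a
# deprecation warning (visible with `--depwarn=yes`) and forwards to `use_bias`.
layer_old = Dense(10 => 5, relu; bias=false)

# Mixing the two is rejected: passing `bias` together with `use_bias=false`
# throws an ArgumentError, as exercised by the new tests.
# Dense(10 => 5, relu; bias=true, use_bias=false)  # -> ArgumentError
```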