Skip to content

Commit

Permalink
Rnn bias deprecation (#120)
Browse files Browse the repository at this point in the history
* Deprecate bias in favor of use_bias for RNNCell

* RNNCell.bias deprecation: fix tests

* Format code

* add NNlib to recurrent test

* Remove `::Union{ComponentArray, NamedTuple}` from arguments

* Fix show message: use_bias -> bias

* Bump version

* Add deprecation message to CHANGELOG

* Format code

* Add comment about deprecated bias kwarg to RNNCell

Co-authored-by: Avik Pal <[email protected]>
  • Loading branch information
lungd and avik-pal authored Aug 7, 2022
1 parent cbda1c5 commit 6041cdb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 34 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# v0.4

## v0.4.14
- Deprecate `bias` in favor of `use_bias` for `RNNCell`.
- Add `use_bias` kwarg to `LSTMCell` and `GRUCell`

## v0.4.12
Expand Down
41 changes: 25 additions & 16 deletions src/layers/recurrent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ An Elman RNNCell cell with `activation` (typically set to `tanh` or `relu`).
- `rng`: Controls the randomness (if any) in the initial state generation
"""
struct RNNCell{bias, A, B, W, S} <: AbstractExplicitLayer
struct RNNCell{use_bias, A, B, W, S} <: AbstractExplicitLayer
activation::A
in_dims::Int
out_dims::Int
Expand All @@ -46,17 +46,29 @@ struct RNNCell{bias, A, B, W, S} <: AbstractExplicitLayer
init_state::S
end

function RNNCell((in_dims, out_dims)::Pair{<:Int, <:Int}, activation=tanh; bias::Bool=true,
init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32)
return RNNCell{bias, typeof(activation), typeof(init_bias), typeof(init_weight),
# Construct an `RNNCell` from an `in_dims => out_dims` pair and an `activation`.
#
# Keyword arguments:
#   - `use_bias`: becomes the first type parameter of the returned cell.
#   - `bias`: deprecated spelling of `use_bias`. When it is supplied (anything other
#     than `missing`) a deprecation warning is emitted and its value replaces
#     `use_bias`.
#   - `init_bias`, `init_weight`, `init_state`: forwarded unchanged to the struct.
#
# Throws `ArgumentError` when `bias` is supplied together with `use_bias=false`.
function RNNCell((in_dims, out_dims)::Pair{<:Int, <:Int}, activation=tanh;
                 use_bias::Bool=true, bias::Union{Missing, Bool}=missing,
                 init_bias=zeros32, init_weight=glorot_uniform, init_state=ones32)
    # Deprecated Functionality (Remove in v0.5)
    if bias !== missing
        Base.depwarn("`bias` argument to `RNNCell` has been deprecated and will be removed" *
                     " in v0.5. Use `use_bias` kwarg instead.", :RNNCell)
        # `use_bias=true` cannot be told apart from the default, so only an explicit
        # `use_bias=false` is detectable as a conflicting setting.
        use_bias || throw(ArgumentError("Both `bias` and `use_bias` are set. Please only use " *
                                        "the `use_bias` keyword argument."))
        use_bias = bias
    end

    cell_type = RNNCell{use_bias, typeof(activation), typeof(init_bias),
                        typeof(init_weight), typeof(init_state)}
    return cell_type(activation, in_dims, out_dims, init_bias, init_weight, init_state)
end

function initialparameters(rng::AbstractRNG, rnn::RNNCell{bias}) where {bias}
function initialparameters(rng::AbstractRNG, rnn::RNNCell{use_bias}) where {use_bias}
ps = (weight_ih=rnn.init_weight(rng, rnn.out_dims, rnn.in_dims),
weight_hh=rnn.init_weight(rng, rnn.out_dims, rnn.out_dims))
if bias
if use_bias
ps = merge(ps, (bias=rnn.init_bias(rng, rnn.out_dims),))
end
return ps
Expand All @@ -68,48 +80,45 @@ function initialstates(rng::AbstractRNG, ::RNNCell)
return (rng=replicate(rng),)
end

function (rnn::RNNCell)(x::AbstractMatrix, ps::Union{ComponentArray, NamedTuple},
st::NamedTuple)
# Forward pass used when only the input `x` is supplied (no hidden state).
# The RNG held in `st` is replicated, the replica is written back into the state that
# is passed on, a hidden state is produced by `_init_hidden_state`, and the call is
# then delegated to the `(x, hidden_state)` method of the same cell.
function (rnn::RNNCell)(x::AbstractMatrix, ps, st::NamedTuple)
    fresh_rng = replicate(st.rng)
    @set! st.rng = fresh_rng
    h0 = _init_hidden_state(fresh_rng, rnn, x)
    return rnn((x, h0), ps, st)
end

function (rnn::RNNCell{true})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix},
ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
ps, st::NamedTuple)
h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias)
return h_new, st
end

function (rnn::RNNCell{true, typeof(identity)})((x,
hidden_state)::Tuple{<:AbstractMatrix,
<:AbstractMatrix},
ps::Union{ComponentArray, NamedTuple},
<:AbstractMatrix}, ps,
st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state .+ ps.bias
return h_new, st
end

function (rnn::RNNCell{false})((x, hidden_state)::Tuple{<:AbstractMatrix, <:AbstractMatrix},
ps::Union{ComponentArray, NamedTuple}, st::NamedTuple)
ps, st::NamedTuple)
h_new = rnn.activation.(ps.weight_ih * x .+ ps.weight_hh * hidden_state)
return h_new, st
end

function (rnn::RNNCell{false, typeof(identity)})((x,
hidden_state)::Tuple{<:AbstractMatrix,
<:AbstractMatrix},
ps::Union{ComponentArray, NamedTuple},
st::NamedTuple)
ps, st::NamedTuple)
h_new = ps.weight_ih * x .+ ps.weight_hh * hidden_state
return h_new, st
end

function Base.show(io::IO, r::RNNCell{bias}) where {bias}
function Base.show(io::IO, r::RNNCell{use_bias}) where {use_bias}
print(io, "RNNCell($(r.in_dims) => $(r.out_dims)")
(r.activation == identity) || print(io, ", $(r.activation)")
bias || print(io, ", bias=false")
use_bias || print(io, ", bias=false")
return print(io, ")")
end

Expand Down
44 changes: 26 additions & 18 deletions test/layers/recurrent.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,39 @@
using JET, Lux, Random, Test
using JET, Lux, NNlib, Random, Test

include("../test_utils.jl")

rng = Random.default_rng()
Random.seed!(rng, 0)

@testset "RNNCell" begin for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh),
RNNCell(3 => 5, tanh; bias=false),
RNNCell(3 => 5, identity; bias=false))
println(rnncell)
ps, st = Lux.setup(rng, rnncell)
x = randn(rng, Float32, 3, 2)
h, st_ = Lux.apply(rnncell, x, ps, st)
@testset "RNNCell" begin
for rnncell in (RNNCell(3 => 5, identity), RNNCell(3 => 5, tanh),
RNNCell(3 => 5, tanh; use_bias=false),
RNNCell(3 => 5, identity; use_bias=false))
println(rnncell)
ps, st = Lux.setup(rng, rnncell)
x = randn(rng, Float32, 3, 2)
h, st_ = Lux.apply(rnncell, x, ps, st)

run_JET_tests(rnncell, x, ps, st)
run_JET_tests(rnncell, (x, h), ps, st_)
run_JET_tests(rnncell, x, ps, st)
run_JET_tests(rnncell, (x, h), ps, st_)

function loss_loop_rnncell(p)
h, st_ = rnncell(x, p, st)
for i in 1:10
h, st_ = rnncell((x, h), p, st_)
function loss_loop_rnncell(p)
h, st_ = rnncell(x, p, st)
for i in 1:10
h, st_ = rnncell((x, h), p, st_)
end
return sum(abs2, h)
end
return sum(abs2, h)
end

test_gradient_correctness_fdm(loss_loop_rnncell, ps; atol=1e-3, rtol=1e-3)
end end
test_gradient_correctness_fdm(loss_loop_rnncell, ps; atol=1e-3, rtol=1e-3)
end
# Deprecated Functionality (Remove in v0.5)
@testset "Deprecations" begin
@test_deprecated RNNCell(3 => 5, relu; bias=false)
@test_deprecated RNNCell(3 => 5, relu; bias=true)
@test_throws ArgumentError RNNCell(3 => 5, relu; bias=false, use_bias=false)
end
end

@testset "LSTMCell" begin for lstmcell in (LSTMCell(3 => 5),
LSTMCell(3 => 5; use_bias=true),
Expand Down

0 comments on commit 6041cdb

Please sign in to comment.