Skip to content

Commit

Permalink
Checking derivatives.
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Oct 22, 2024
1 parent b91c00a commit f955fe2
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 13 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[weakdeps]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

Expand All @@ -53,3 +54,4 @@ DolooneAPIExt = "oneAPI"
[compat]
Dolang = "3.3.0"
LabelledArrays = "≥1.16.0"
Plots = "1.40.8"
84 changes: 80 additions & 4 deletions demo.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,86 @@
a = 4

using Dolo
using DoloYAML

# model = include("examples/ymodels/consumption_savings.jl")
model = include("examples/ymodels/rbc_iid.yaml")
Dolo.time_iteration(model; improve=false)

dmodel = Dolo.discretize(model)

wk = Dolo.time_iteration_workspace(dmodel; improve=true)

using FiniteDiff

function residual_1(dmodel,wk,v)
x = Dolo.unravel(wk.x0, v)
r = Dolo.F(dmodel,x, wk.ψ)
Dolo.ravel(r)
end

ψ = deepcopy(wk.φ)

function residual_2(dmodel,wk,v)
x = Dolo.unravel(wk.x0, v)
Dolo.fit!(ψ, x)
r = Dolo.F(dmodel,wk.x0, ψ)
Dolo.ravel(r)
end



v = Dolo.ravel(wk.x0)

residual(dmodel, wk, v, wk.φ)

Jdiff = FiniteDiff.finite_difference_jacobian(u->residual(dmodel, wk, u, ψ), v)

J = Dolo.dF_1(dmodel, wk.x0, wk.φ)
Jmat = convert(Matrix, J)

@assert maximum( abs.(Jdiff - Jmat)./(1 .+ abs.(Jdiff)) ) <1e-6


Jdiff_2 = FiniteDiff.finite_difference_jacobian(u->residual_2(dmodel, wk, u), v)
J_2 = Dolo.dF_2(dmodel, wk.x0, wk.φ)

Jmat_2 = convert(Matrix, J_2)

using LinearMaps

L_2 = convert(LinearMap, J_2)

using Plots
plot(xlims=(-10, 10), ylims=(-2,2))

spy(Jmat)
spy(Jmat_2, xlims=(1,150), ylims=(1,150))

spy(Jmat_2)


Δ = Jdiff_2 - Jmat_2
maximum(abs.(Δ)./( 1 .+ Jmat_2))
# # model = DoloYAML.yaml_import("examples/ymodels/consumption_savings.yaml")


# wk = Dolo.time_iteration_workspace(dmodel; improve=true)

# # @time Dolo.time_iteration(dmodel, wk; verbose=false, improve=false, improve_wait=10);

# @time Dolo.time_iteration(dmodel, wk;
# verbose=false, improve=false, improve_wait=10, engine=nothing);


# dm.grid[1]

# [dm.grid...]


# # wk = Dolo.time_iteration_workspace(model; improve=true, interp_mode=:cubic)
# # @time Dolo.time_iteration(model, wk; verbose=true, improve=true, improve_wait=20);

model = include("examples/ymodels/consumption_savings.jl")

Dolo.time_iteration(model)

# # (r0, model, x0, φ, t_engine) = Dolo.time_iteration(model; verbose=false, improve=true);

# # Dolo.F!(r0, model, x0, φ, t_engine)
4 changes: 2 additions & 2 deletions examples/ymodels/consumption_savings_iid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ model = let
(:w,:r,:e),
SVector( [Q[i] for i=1:size(Q,1)]... )
)×CartesianSpace(;
y=[0.01, 100]
y=(0.01, 100.0)
)

controls = CartesianSpace(;
c=(0,Inf),
c=(0.0,Inf),
)
exogenous = Dolo.MarkovChain(
(:w,:r,:e), P,Q
Expand Down
8 changes: 5 additions & 3 deletions src/algos/time_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ end
function time_iteration(model::YModel; kwargs...)
discr_options = get(kwargs, :discretization, Dict())
interp_mode = get(kwargs, :interpolation, :cubic)
improve = get(kwargs, :improve, false)
dmodel = discretize(model, discr_options)
wksp = time_iteration_workspace(dmodel; interp_mode=interp_mode)
wksp = time_iteration_workspace(dmodel; interp_mode=interp_mode, improve=improve)
kwargs2 = pairs(NamedTuple( k=>v for (k,v) in kwargs if !(k in (:discretization, :interpolation))))
time_iteration(dmodel, wksp; kwargs2...)
end
Expand Down Expand Up @@ -202,7 +203,7 @@ function time_iteration(model::DYModel,
# mem = typeof(workspace) <: Nothing ? time_iteration_workspace(model) : workspace
mbsteps = 5

(;x0, x1, x2, dx, r0, J, φ) = workspace
(;x0, x1, x2, r0, dx, J, φ) = workspace


local η_0 = NaN
Expand Down Expand Up @@ -235,6 +236,7 @@ function time_iteration(model::DYModel,
trace && push!(ti_trace.data, deepcopy(φ))

F!(r0, model, x0, φ, t_engine)

# r0 = F(model, x0, φ)

ε = norm(r0)
Expand All @@ -254,7 +256,7 @@ function time_iteration(model::DYModel,
for k=1:max_bsteps

F!(r0, model, x1, φ, t_engine)

ε_n = norm(r0)
if ε_n<tol_ε
iterations = t
Expand Down
3 changes: 2 additions & 1 deletion src/dev_L2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ function *(L::LL, x0)
Tf = getprecision(x0)
r.data .*= convert(Tf, 0.0)
mul!(r, L, x0)

r

end

# this takes 0.2 s !
Expand Down
9 changes: 6 additions & 3 deletions src/funs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function DFun(domain, values, itp)
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
return DFun(domain, values, itp, vars)
end
Expand All @@ -53,7 +54,8 @@ function DFun(states, values::GVector{G,V}, vars=nothing; interp_mode=:linear) w
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
end

Expand All @@ -80,7 +82,8 @@ function DFun(states, values::GVector{G,V}, vars=nothing; interp_mode=:linear) w
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
end

Expand Down

0 comments on commit f955fe2

Please sign in to comment.