Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/EconForge/Dolo.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
albop committed Oct 28, 2024
2 parents af4f695 + f955fe2 commit ca2beb4
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1.9','1.10']
julia-version: ['1.10','1.11']
julia-arch: [x64]
os: [ubuntu-latest, windows-latest, macOS-latest]

Expand Down
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,16 @@ StringDistances = "88034a9c-02f8-509d-84a9-84ec65e18404"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[weakdeps]
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
DolooneAPIExt = "oneAPI"
DoloCUDAExt = "CUDA"
DolooneAPIExt = "oneAPI"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
[compat]
Dolang = "3.3.0"
LabelledArrays = "≥1.16.0"
Plots = "1.40.8"
86 changes: 81 additions & 5 deletions demo.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,86 @@
model = include("examples/ymodels/rbc_mc.jl")
using Dolo
using DoloYAML

dm = Dolo.discretize(model)
# model = include("examples/ymodels/consumption_savings.jl")
model = include("examples/ymodels/rbc_iid.yaml")
Dolo.time_iteration(model; improve=false)

@time Dolo.time_iteration(dm, verbose=falsefalse);
dmodel = Dolo.discretize(model)

@time Dolo.time_iteration(dm; verbose=true, improve=true);
wk = Dolo.time_iteration_workspace(dmodel; improve=true)

@time wk = Dolo.time_iteration_workspace(dm);
using FiniteDiff

function residual_1(dmodel,wk,v)
x = Dolo.unravel(wk.x0, v)
r = Dolo.F(dmodel,x, wk.ψ)
Dolo.ravel(r)
end

ψ = deepcopy(wk.φ)

function residual_2(dmodel,wk,v)
x = Dolo.unravel(wk.x0, v)
Dolo.fit!(ψ, x)
r = Dolo.F(dmodel,wk.x0, ψ)
Dolo.ravel(r)
end



v = Dolo.ravel(wk.x0)

residual(dmodel, wk, v, wk.φ)

Jdiff = FiniteDiff.finite_difference_jacobian(u->residual(dmodel, wk, u, ψ), v)

J = Dolo.dF_1(dmodel, wk.x0, wk.φ)
Jmat = convert(Matrix, J)

@assert maximum( abs.(Jdiff - Jmat)./(1 .+ abs.(Jdiff)) ) <1e-6


Jdiff_2 = FiniteDiff.finite_difference_jacobian(u->residual_2(dmodel, wk, u), v)
J_2 = Dolo.dF_2(dmodel, wk.x0, wk.φ)

Jmat_2 = convert(Matrix, J_2)

using LinearMaps

L_2 = convert(LinearMap, J_2)

using Plots
plot(xlims=(-10, 10), ylims=(-2,2))

spy(Jmat)
spy(Jmat_2, xlims=(1,150), ylims=(1,150))

spy(Jmat_2)


Δ = Jdiff_2 - Jmat_2
maximum(abs.(Δ)./( 1 .+ Jmat_2))
# # model = DoloYAML.yaml_import("examples/ymodels/consumption_savings.yaml")


# wk = Dolo.time_iteration_workspace(dmodel; improve=true)

# # @time Dolo.time_iteration(dmodel, wk; verbose=false, improve=false, improve_wait=10);

# @time Dolo.time_iteration(dmodel, wk;
# verbose=false, improve=false, improve_wait=10, engine=nothing);


# dm.grid[1]

# [dm.grid...]


# # wk = Dolo.time_iteration_workspace(model; improve=true, interp_mode=:cubic)
# # @time Dolo.time_iteration(model, wk; verbose=true, improve=true, improve_wait=20);



# # (r0, model, x0, φ, t_engine) = Dolo.time_iteration(model; verbose=false, improve=true);

# # Dolo.F!(r0, model, x0, φ, t_engine)
4 changes: 2 additions & 2 deletions examples/ymodels/consumption_savings_iid.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ model = let
(:w,:r,:e),
SVector( [Q[i] for i=1:size(Q,1)]... )
)×CartesianSpace(;
y=(0.01, 100)
y=(0.01, 100.0)
)

controls = CartesianSpace(;
c=(0,Inf),
c=(0.0,Inf),
)
exogenous = Dolo.MarkovChain(
(:w,:r,:e), P,Q
Expand Down
8 changes: 5 additions & 3 deletions src/algos/time_iteration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,9 @@ end
function time_iteration(model::YModel; kwargs...)
discr_options = get(kwargs, :discretization, Dict())
interp_mode = get(kwargs, :interpolation, :cubic)
improve = get(kwargs, :improve, false)
dmodel = discretize(model, discr_options)
wksp = time_iteration_workspace(dmodel; interp_mode=interp_mode)
wksp = time_iteration_workspace(dmodel; interp_mode=interp_mode, improve=improve)
kwargs2 = pairs(NamedTuple( k=>v for (k,v) in kwargs if !(k in (:discretization, :interpolation))))
time_iteration(dmodel, wksp; kwargs2...)
end
Expand Down Expand Up @@ -202,7 +203,7 @@ function time_iteration(model::DYModel,
# mem = typeof(workspace) <: Nothing ? time_iteration_workspace(model) : workspace
mbsteps = 5

(;x0, x1, x2, dx, r0, J, φ) = workspace
(;x0, x1, x2, r0, dx, J, φ) = workspace


local η_0 = NaN
Expand Down Expand Up @@ -235,6 +236,7 @@ function time_iteration(model::DYModel,
trace && push!(ti_trace.data, deepcopy(φ))

F!(r0, model, x0, φ, t_engine)

# r0 = F(model, x0, φ)

ε = norm(r0)
Expand All @@ -254,7 +256,7 @@ function time_iteration(model::DYModel,
for k=1:max_bsteps

F!(r0, model, x1, φ, t_engine)

ε_n = norm(r0)
if ε_n<tol_ε
iterations = t
Expand Down
3 changes: 2 additions & 1 deletion src/dev_L2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ function *(L::LL, x0)
Tf = getprecision(x0)
r.data .*= convert(Tf, 0.0)
mul!(r, L, x0)

r

end

# this takes 0.2 s !
Expand Down
23 changes: 16 additions & 7 deletions src/funs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ function DFun(domain, values, itp)
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
return DFun(domain, values, itp, vars)
end
Expand All @@ -53,7 +54,8 @@ function DFun(states, values::GVector{G,V}, vars=nothing; interp_mode=:linear) w
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
end

Expand All @@ -80,7 +82,8 @@ function DFun(states, values::GVector{G,V}, vars=nothing; interp_mode=:linear) w
elseif eltype(values) <: SVector # && length(eltype(values)) == 1
vars = (:y,)
else
println(values)
# println(values)
nothing
end
end

Expand Down Expand Up @@ -109,13 +112,19 @@ function fit!(φ::DFun, x::GVector{G}) where G<:CGrid

end

## PGrid
## PGrid ( SGrid × CGrid )


function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
f(x.loc...)
end

function (f::DFun{A,B,I,vars})(i::Int, x::SVector{d2, U}) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U
f.itp[i](x)
end

function (f::DFun{A,B,I,vars})(i::Int64, j::Int64) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U
f((i,j))
function (f::DFun{A,B,I,vars})(i::Int, j::Int) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
f.values[i,j]
end

function (f::DFun{A,B,I,vars})(x::QP) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
Expand All @@ -136,7 +145,7 @@ function (f::DFun{A,B,I,vars})(loc::Tuple{Tuple{Int64}, SVector{d2, U}}) where
f.itp[loc[1][1]](x)
end

function (f::DFun{A,B,I,vars})(loc::Tuple{Int64, Int64}) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars where d2 where U
function (f::DFun{A,B,I,vars})(loc::Tuple{Int64, Int64}) where A where B<:GArray{G,V} where V where I where G<:PGrid{G1,G2} where G1<:SGrid where G2<:CGrid where vars
# TODO: not beautiful
i,j = loc
x = f.values.grid.grids[2][j]
Expand Down
2 changes: 0 additions & 2 deletions src/grids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ end
const PGrid = ProductGrid

function ProductGrid(A::AGrid{d1},B::AGrid{d2}) where d1 where d2
println(d1)
println(d2)
ProductGrid{typeof(A), typeof(B), d1+d2}( (A,B) )
end

Expand Down

0 comments on commit ca2beb4

Please sign in to comment.