Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
dmetivie committed Dec 28, 2023
1 parent 4f86333 commit 4303fbb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 42 deletions.
2 changes: 0 additions & 2 deletions docs/src/benchmarks.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ You can find the benchmark code in [here](https://github.com/dmetivie/Expectatio

![timing_K_2_rudimentary_wo_memory_leak](https://user-images.githubusercontent.com/46794064/227195619-c75b9276-932b-4029-8b49-6cce919acc87.svg)

<!-- I guess that to increase performance in this package, it would be nice to be able to do in place `fit_mle` for large multidimensional cases. -->

[^1]: Note that `@btime` with `RCall` and `PyCall` might produce a small-time overhead compare to the true R/Python time see [here for example](https://discourse.julialang.org/t/benchmarking-julia-vs-python-vs-r-with-pycall-and-rcall/37308).
I did compare with `R` `microbenchmark` and Python `timeit` and it produces very similar timing but in my experience `BenchmarkTools` is smarter and simpler to use, i.e. it will figure out alone the number of repetition to do in function of the run.

Expand Down
80 changes: 40 additions & 40 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ using Random
y = rand(mix_true, N)
mix_guess = MixtureModel([Exponential(1), Gamma(0.5, 1)], [0.5, 1 - 0.5])
mix_mle =
fit_mle(mix_guess, y; display = :none, atol = 1e-3, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, atol=1e-3, robust=false, infos=false)

p = params(mix_mle)[1]
@test isapprox([β, 1 - β], probs(mix_mle); rtol = rtol)
@test isapprox(θ₁, p[1]...; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
@test isapprox([β, 1 - β], probs(mix_mle); rtol=rtol)
@test isapprox(θ₁, p[1]...; rtol=rtol)
@test isapprox(α, p[2][1]; rtol=rtol)
@test isapprox(θ₂, p[2][2]; rtol=rtol)

# Test rtol
mix_mle2 =
fit_mle(mix_guess, y; display = :none, rtol = 1e-8, atol = 0, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, rtol=1e-8, atol=0, robust=false, infos=false)
p = params(mix_mle2)[1]
@test isapprox([β, 1 - β], probs(mix_mle2); rtol = rtol)
@test isapprox(θ₁, p[1]...; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
@test isapprox([β, 1 - β], probs(mix_mle2); rtol=rtol)
@test isapprox(θ₁, p[1]...; rtol=rtol)
@test isapprox(α, p[2][1]; rtol=rtol)
@test isapprox(θ₂, p[2][2]; rtol=rtol)
end

@testset "Stochastic EM Univariate continuous Mixture Exponential + Laplace" begin
Expand All @@ -49,36 +49,36 @@ end
mix_mle = fit_mle(
mix_guess,
y;
display = :none,
atol = 1e-3,
robust = false,
infos = false,
method = StochasticEM(),
display=:none,
atol=1e-3,
robust=false,
infos=false,
method=StochasticEM(),
)

p = params(mix_mle)[1]
@test isapprox([β, 1 - β], probs(mix_mle); rtol = rtol)
@test isapprox(θ₁, p[1][2]; rtol = rtol)
@test isapprox(μ, p[1][1]; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
@test isapprox([β, 1 - β], probs(mix_mle); rtol=rtol)
@test isapprox(θ₁, p[1][2]; rtol=rtol)
@test isapprox(μ, p[1][1]; rtol=rtol)
@test isapprox(α, p[2][1]; rtol=rtol)
@test isapprox(θ₂, p[2][2]; rtol=rtol)

mix_mle2 = fit_mle(
mix_guess,
y;
display = :none,
atol = 0,
rtol = 1e-6,
robust = false,
infos = false,
method = StochasticEM(),
display=:none,
atol=0,
rtol=1e-6,
robust=false,
infos=false,
method=StochasticEM(),
)
p = params(mix_mle2)[1]
@test isapprox([β, 1 - β], probs(mix_mle2); rtol = rtol)
@test isapprox(θ₁, p[1][2]; rtol = rtol)
@test isapprox(μ, p[1][1]; rtol = rtol)
@test isapprox(α, p[2][1]; rtol = rtol)
@test isapprox(θ₂, p[2][2]; rtol = rtol)
@test isapprox([β, 1 - β], probs(mix_mle2); rtol=rtol)
@test isapprox(θ₁, p[1][2]; rtol=rtol)
@test isapprox(μ, p[1][1]; rtol=rtol)
@test isapprox(α, p[2][1]; rtol=rtol)
@test isapprox(θ₂, p[2][2]; rtol=rtol)
end

@testset "Multivariate Gaussian Mixture" begin
Expand Down Expand Up @@ -111,12 +111,12 @@ end

# Fit MLE
mix_mle =
fit_mle(mix_guess, y; display = :none, atol = 1e-3, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, atol=1e-3, robust=false, infos=false)

p = params(mix_mle)[1]
@test isapprox([β, 1 - β], probs(mix_mle); rtol = rtol)
@test isapprox(collect(p[1]), [θ₁, Σ₁], rtol = rtol)
@test isapprox(collect(p[2]), [θ₂, Σ₂], rtol = rtol)
@test isapprox([β, 1 - β], probs(mix_mle); rtol=rtol)
@test isapprox(collect(p[1]), [θ₁, Σ₁], rtol=rtol)
@test isapprox(collect(p[2]), [θ₂, Σ₂], rtol=rtol)
end

# Bernoulli Mixture i.e. Mixture of Bernoulli Product (S = 10 term and K = 3 mixture components).
Expand Down Expand Up @@ -149,11 +149,11 @@ end

# Fit MLE
mix_mle =
fit_mle(mix_guess, y; display = :none, atol = 1e-3, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, atol=1e-3, robust=false, infos=false)

p = params(mix_mle)[1]
@test isapprox([β / 2, 1 - β, β / 2], probs(mix_mle); rtol = rtol)
@test isapprox(first.(hcat(p...)), θ, rtol = rtol)
@test isapprox([β / 2, 1 - β, β / 2], probs(mix_mle); rtol=rtol)
@test isapprox(first.(hcat(p...)), θ, rtol=rtol)
end

@testset "Univariate continuous Mixture of (mixture + Normal)" begin
Expand Down Expand Up @@ -185,7 +185,7 @@ end

mix_guess = MixtureModel([d1_guess, d2_guess], [β + 0.1, 1 - β - 0.1])
mix_mle =
fit_mle(mix_guess, y; display = :none, atol = 1e-3, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, atol=1e-3, robust=false, infos=false)
y_guess = rand(mix_mle, N)

@test probs(mix_mle) [β, 1 - β] rtol = rtol
Expand Down Expand Up @@ -228,7 +228,7 @@ end

mix_guess = MixtureModel([d1_guess, d2_guess], [β + 0.1, 1 - β - 0.1])
mix_mle =
fit_mle(mix_guess, y; display = :none, atol = 1e-3, robust = false, infos = false)
fit_mle(mix_guess, y; display=:none, atol=1e-3, robust=false, infos=false)
# without print
# 1.368 s (17002715 allocations: 1.48 GiB)
# 1.485 s (17853393 allocations: 1.61 GiB)
Expand Down

0 comments on commit 4303fbb

Please sign in to comment.