Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
dmetivie committed Nov 5, 2024
1 parent 7b06d69 commit 93e505e
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 96 deletions.
8 changes: 4 additions & 4 deletions src/ExpectationMaximization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ argmaxrow(M) = [argmax(r) for r in eachrow(M)]
Evaluate the most likely category for each observations given a `MixtureModel`.
- `robust = true` will prevent the (log)likelihood to overflow to `-∞` or `∞`.
"""
function predict(mix::MixtureModel, y::AbstractVecOrMat; robust = false)
return argmaxrow(predict_proba(mix, y; robust = robust))
function predict(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
return argmaxrow(predict_proba(mix, y; robust=robust))
end

"""
predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
Evaluate the probability for each observations to belong to a category given a `MixtureModel`..
- `robust = true` will prevent the (log)likelihood to under(overflow)flow to `-∞` (or `∞`).
"""
function predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust = false)
function predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust=false)
# evaluate likelihood for each components k
dists = mix.components
α = probs(mix)
Expand All @@ -44,7 +44,7 @@ function predict_proba(mix::MixtureModel, y::AbstractVecOrMat; robust = false)
LL = zeros(N, K)
γ = similar(LL)
c = zeros(N)
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)
return γ
end

Expand Down
32 changes: 16 additions & 16 deletions src/classic_em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ function fit_mle!(
dists::AbstractVector{F} where {F<:Distribution},
y::AbstractVecOrMat,
method::ClassicEM;
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
)

@argcheck display in [:none, :iter, :final]
Expand All @@ -37,7 +37,7 @@ function fit_mle!(
c = zeros(N)

# E-step
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtot = sum(c)
Expand All @@ -49,7 +49,7 @@ function fit_mle!(

# E-step
# evaluate likelihood for each type k
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtotp = sum(c)
Expand Down Expand Up @@ -84,13 +84,13 @@ end
For the `ClassicEM` the weigths `γ` computed at E-step for each observation in `y` are used to update `α` and `dists`.
"""
function M_step!(α, dists, y::AbstractVecOrMat, γ, method::ClassicEM)
α[:] = mean(γ, dims = 1)
α[:] = mean(γ, dims=1)
dists[:] = [fit_mle(dists[k], y, γₖ) for (k, γₖ) in enumerate(eachcol(γ))]
end

#TODO: could probably replace γ, w by γ*w,
function M_step!(α, dists, y::AbstractVecOrMat, γ, w, method::ClassicEM)
α[:] = mean(γ, weights(w), dims = 1)
α[:] = mean(γ, weights(w), dims=1)
dists[:] = [fit_mle(dists[k], y, w[:] .* γₖ) for (k, γₖ) in enumerate(eachcol(γ))]
end

Expand All @@ -100,11 +100,11 @@ function fit_mle!(
y::AbstractVecOrMat,
w::AbstractVector,
method::ClassicEM;
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
)

@argcheck display in [:none, :iter, :final]
Expand All @@ -120,7 +120,7 @@ function fit_mle!(
c = zeros(N)

# E-step
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtot = sum(w[n] * c[n] for n = 1:N) #dot(w, c)
Expand All @@ -132,7 +132,7 @@ function fit_mle!(

# E-step
# evaluate likelihood for each type k
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtotp = sum(w[n] * c[n] for n in eachindex(c)) #dot(w, c)
Expand Down
78 changes: 39 additions & 39 deletions src/fit_em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ function fit_mle(
mix::MixtureModel,
y::AbstractVecOrMat,
weights...;
method = ClassicEM(),
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
infos = false,
method=ClassicEM(),
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
infos=false,
)

# Initial parameters
Expand All @@ -34,11 +34,11 @@ function fit_mle(
dists,
y,
method;
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
display=display,
maxiter=maxiter,
atol=atol,
rtol=rtol,
robust=robust,
)
else
history = fit_mle!(
Expand All @@ -47,11 +47,11 @@ function fit_mle(
y,
weights...,
method;
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
display=display,
maxiter=maxiter,
atol=atol,
rtol=rtol,
robust=robust,
)
end

Expand All @@ -68,39 +68,39 @@ function fit_mle(
mix::AbstractArray{<:MixtureModel},
y::AbstractVecOrMat,
weights...;
method = ClassicEM(),
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
infos = false,
method=ClassicEM(),
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
infos=false,
)

mx_max, history_max = fit_mle(
mix[1],
y,
weights...;
method = method,
display = display,
maxiter = maxiter,
atol = atol,
robust = robust,
infos = true,
method=method,
display=display,
maxiter=maxiter,
atol=atol,
robust=robust,
infos=true,
)
for j in eachindex(mix)[2:end]
try
mx_new, history_new = fit_mle(
mix[j],
y,
weights...;
method = method,
display = display,
maxiter = maxiter,
atol = atol,
rtol = rtol,
robust = robust,
infos = true,
method=method,
display=display,
maxiter=maxiter,
atol=atol,
rtol=rtol,
robust=robust,
infos=true,
)
if history_max["logtots"][end] < history_new["logtots"][end]
mx_max = mx_new
Expand All @@ -122,7 +122,7 @@ function E_step!(
dists::AbstractVector{F} where {F<:Distribution},
α::AbstractVector,
y::AbstractVector{<:Real};
robust = false,
robust=false,
) where {T<:AbstractFloat}
# evaluate likelihood for each type k
for k in eachindex(dists)
Expand All @@ -141,7 +141,7 @@ function E_step!(
dists::AbstractVector{F} where {F<:Distribution},
α::AbstractVector,
y::AbstractMatrix;
robust = false,
robust=false,
)
# evaluate likelihood for each type k
@views for k in eachindex(dists)
Expand Down
36 changes: 18 additions & 18 deletions src/stochastic_em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ function fit_mle!(
dists::AbstractVector{F} where {F<:Distribution},
y::AbstractVecOrMat,
method::StochasticEM;
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
)

@argcheck display in [:none, :iter, :final]
Expand All @@ -43,7 +43,7 @@ function fit_mle!(
c = zeros(N)
= zeros(Int, N)
# E-step
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtot = sum(c)
Expand All @@ -59,7 +59,7 @@ function fit_mle!(

# E-step
# evaluate likelihood for each type k
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtotp = sum(c)
Expand Down Expand Up @@ -95,22 +95,22 @@ For the `StochasticEM` the `cat` drawn at S-step for each observation in `y` is
"""
function M_step!(α, dists, y::AbstractVector, cat, method::StochasticEM)
#
α[:] = length.(cat)/size_sample(y)
α[:] = length.(cat) / size_sample(y)
dists[:] = [fit_mle(dists[k], y[cₖ]) for (k, cₖ) in enumerate(cat)]
end

function M_step!(α, dists, y::AbstractMatrix, cat, method::StochasticEM)
α[:] = length.(cat)/size_sample(y)
α[:] = length.(cat) / size_sample(y)
dists[:] = [fit_mle(dists[k], y[:, cₖ]) for (k, cₖ) in enumerate(cat)]
end

function M_step!(α, dists, y::AbstractVector, cat, w, method::StochasticEM)
α[:] = [sum(w[cₖ]) for cₖ in cat]/sum(w)
α[:] = [sum(w[cₖ]) for cₖ in cat] / sum(w)
dists[:] = [fit_mle(dists[k], y[cₖ], w[cₖ]) for (k, cₖ) in enumerate(cat)]
end

function M_step!(α, dists, y::AbstractMatrix, cat, w, method::StochasticEM)
α[:] = [sum(w[cat[k]]) for k in 1:K]/sum(w)
α[:] = [sum(w[cat[k]]) for k in 1:K] / sum(w)
dists[:] = [fit_mle(dists[k], y[:, cₖ], w[cₖ]) for (k, cₖ) in enumerate(cat)]
end

Expand All @@ -121,11 +121,11 @@ function fit_mle!(
y::AbstractVecOrMat,
w::AbstractVector,
method::StochasticEM;
display = :none,
maxiter = 1000,
atol = 1e-3,
rtol = nothing,
robust = false,
display=:none,
maxiter=1000,
atol=1e-3,
rtol=nothing,
robust=false,
)

@argcheck display in [:none, :iter, :final]
Expand All @@ -141,7 +141,7 @@ function fit_mle!(
c = zeros(N)

# E-step
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtot = sum(w[n] * c[n] for n = 1:N) #dot(w, c)
Expand All @@ -157,7 +157,7 @@ function fit_mle!(

# E-step
# evaluate likelihood for each type k
E_step!(LL, c, γ, dists, α, y; robust = robust)
E_step!(LL, c, γ, dists, α, y; robust=robust)

# Loglikelihood
logtotp = sum(w[n] * c[n] for n in eachindex(c)) #dot(w, c)
Expand Down
Loading

0 comments on commit 93e505e

Please sign in to comment.