Skip to content

Commit

Permalink
updates cleaning up code, adding debug, other tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
njwfish committed Dec 3, 2023
1 parent de1f164 commit c26311f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 110 deletions.
33 changes: 23 additions & 10 deletions src/ConeProj.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ConeProj


include("QRupdate.jl")
include("UpdatableQR.jl")

using LinearAlgebra
Expand All @@ -21,9 +22,10 @@ function nnls(A, b; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, maxit=nothi
bhat = zeros(n)
coefs = zeros(m + p)

if (passive_set === nothing) | (uqr === nothing)
if (passive_set == nothing) | (uqr == nothing)
if p == 0
passive_set = Vector{Int}()
uqr = nothing
else
passive_set = Vector(1:p)
uqr = UpdatableQR(A[:, passive_set])
Expand All @@ -43,16 +45,16 @@ function nnls(A, b; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, maxit=nothi
coefs[passive_set] .= 0

if length(passive_set) == p
push!(passive_set, max_ind)
if (p == 0) & (uqr === nothing)
uqr = UpdatableQR(A[:, passive_set])
if uqr == nothing
uqr = UpdatableQR(A[:, max_ind:max_ind])
else
add_column!(uqr, A[:, max_ind])
end
push!(passive_set, max_ind)
end

for i in 1:maxit
A_passive = view(A, :, passive_set)
A_passive = A[:, passive_set]
coef_passive = solvex(uqr.R1, A_passive, b)
if length(coef_passive) > p
min_ind = p + partialsortperm(coef_passive[p+1:end], 1, rev=false)
Expand Down Expand Up @@ -84,7 +86,7 @@ function nnls(A, b; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, maxit=nothi
end

#TODO #3 Investiage the issue even with the 1D case where we end up with negative coefficients
function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, maxit=nothing)
function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, maxit=nothing, debug=false, debug_proj=false)
optimal = true
n, = size(b)
q = size(d)
Expand All @@ -104,6 +106,7 @@ function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, max
if (passive_set == nothing) | (uqr == nothing)
if p == 0
passive_set = Vector{Int}()
uqr = nothing
else
passive_set = Vector(1:p)
uqr = UpdatableQR(A[:, passive_set])
Expand All @@ -122,13 +125,13 @@ function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, max
max_ind = feasible_constraint_set[
partialsortperm(proj_resid[feasible_constraint_set], 1, rev=true)
]
# constraint_set = [max_ind]
push!(passive_set, max_ind)
if (p == 0) & (uqr === nothing)
uqr = UpdatableQR(A[:, passive_set])
constraint_set = [max_ind]
if uqr == nothing
uqr = UpdatableQR(A[:, max_ind:max_ind])
else
add_column!(uqr, A[:, max_ind])
end
push!(passive_set, max_ind)
# _, r = qr(A)
# Cp = (pinv(r) * C')'
# print(size(Cp))
Expand All @@ -153,6 +156,13 @@ function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, max
for it in 1:maxit
A_passive = A[:, passive_set]
coef_passive, lambd = solvexeq(uqr.R1, A_passive, b, C[:, passive_set], d)
if debug
println("iter: ", it)
println("\tpassive_set: ", passive_set)
println("\tcoefs: ", coef_passive)
println("\tlambd: ", lambd)

end
if length(coef_passive) > p
min_ind = p + partialsortperm(coef_passive[p+1:end], 1, rev=false)
# println("min_ind", min_ind, " ", coef_passive[min_ind], " ", sort(coef_passive))
Expand All @@ -163,6 +173,9 @@ function ecnnls(A, b, C, d; p=0, passive_set=nothing, uqr=nothing, tol=1e-8, max
bhat = A_passive * coef_passive
proj_resid = (A' * (b - bhat) + C' * lambd) / n
max_ind = partialsortperm(proj_resid, 1, rev=true)
if debug_proj
println("\tproj_resid: ", proj_resid[max_ind], " ", proj_resid)
end
if proj_resid[max_ind] < 2 * tol
coefs[passive_set] = coef_passive
@goto done
Expand Down
2 changes: 1 addition & 1 deletion src/QRupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,4 @@ function sutaddcol(
m = (c - M * R[1:end-1, end]) / R[end, end]
M = [M m]
return M
end
end
105 changes: 6 additions & 99 deletions src/UpdatableQR.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using LinearAlgebra

mutable struct UpdatableQR{T} <: Factorization{T}
"""
Gives the qr factorization an (n, m) matrix as Q1*R1
Expand Down Expand Up @@ -28,10 +26,11 @@ mutable struct UpdatableQR{T} <: Factorization{T}
UpperTriangular(view(R, 1:m, 1:m)))
end

end

function add_column!(F::Nothing, a::AbstractVector{T}) where {T}
return UpdatableQR(reshape(a, length(a), 1))
function Base.copy(F::UpdatableQR{T}) where {T}
new{T}(copy(F.Q), copy(F.R), F.n, F.m,
view(F.Q, :, 1:F.m), view(F.Q, :, F.m+1:F.n),
UpperTriangular(view(F.R, 1:F.m, 1:F.m)))
end
end

function add_column!(F::UpdatableQR{T}, a::AbstractVector{T}) where {T}
Expand Down Expand Up @@ -90,96 +89,4 @@ function update_views!(F::UpdatableQR{T}) where {T}
F.R1 = UpperTriangular(view(F.R, 1:F.m, 1:F.m))
F.Q1 = view(F.Q, :, 1:F.m)
F.Q2 = view(F.Q, :, F.m+1:F.n)
end


"""
Same as csne but only for x
"""
function solvex(Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}) where {T}
R = UpperTriangular(Rin)
q = A' * b
x = R' \ q
x = R \ x
return x
end

"""
Solve the corrected semi-normal equations `R'Rx=A'b`.
x, r = csne(R, A, b) solves the least-squares problem
minimize ||r||_2, where r := b - A*x
using the corrected semi-normal equation approach described by
Bjork (1987). Assumes that `R` is upper triangular.
"""
function csne(Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}) where {T}

R = UpperTriangular(Rin)
q = A'*b
x = R' \ q

bnorm2 = sum(b.^2)
xnorm2 = sum(x.^2)
d2 = bnorm2 - xnorm2

x = R \ x

# Apply one step of iterative refinement.
r = b - A*x
q = A'*r
dx = R' \ q
dx = R \ dx
x += dx
r = b - A*x
return (x, r)
end

function solvexeq(
Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}, C::AbstractMatrix{T}, d::Vector{T}
) where {T}
x = solvex(Rin, A, b)
R = UpperTriangular(Rin)
M = C / R
_, U = qr(M')
y = U' \ (d - C * x)
y = U \ y
z = R' \ (C' * y)
z = R \ z
x = x + z
return x, y
end

# TODO: There should be a more efficient implementation of this. Empirically when we add a row/col
# to R that just adds a col to C, leaving the rest unchanged. There should be some way to compute this.
# If we can cheaply compute that additional column then we can keep a qless factorization of M' below by
# using the addrow functions above. I believe this will be faster than the full qr decomposition each time.
function solvexeq(
Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}, C::AbstractMatrix{T}, d::AbstractMatrix{T}
) where {T}
x = solvex(Rin, A, b)
R = UpperTriangular(Rin)
M = C / R
_, U = qr(M')
y = U' \ (d - C * x)
y = U \ y
z = R' \ (C' * y)
z = R \ z
x = x + z
return x, y
end

function solvexeq(
Rin::AbstractMatrix{T}, A::AbstractMatrix{T}, b::Vector{T}, Uin::AbstractMatrix{T}, C::AbstractMatrix{T}, d::AbstractMatrix{T}
) where {T}
x = solvex(Rin, A, b)
R = UpperTriangular(Rin)
U = UpperTriangular(Uin)
y = U' \ (d - C * x)
y = U \ y
z = R' \ (C' * y)
z = R \ z
x = x + z
return x, y
end
end

0 comments on commit c26311f

Please sign in to comment.