Skip to content

Commit

Permalink
Codespell. Fixed some typos
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick Nicodemus committed Jan 14, 2025
1 parent 20d23cf commit 5cce045
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 31 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,4 @@ skip = '.git*,*.css,pyproject.toml,data,swc,*.swc,obj,*.obj,*.pyi,*.c'
check-hidden = true
# embedded images into jupyter notebooks, acronyms and names starting with capital letter
ignore-regex = '(^\s*"image/\S+": ".*|\b([A-Z][a-zA-Z]+|scl/fo/|ser: Series|ot\.lp|networkx\.algorithms\.mis)\b)'
ignore-words-list = 'coo'
ignore-words-list = 'coo,ot'
2 changes: 1 addition & 1 deletion src/cajal/ugw/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def rho_of(gw_cost: float, mass_kept: float):
add some overhead to the algorithm because of the time to
compute the GW values.""",
"",
""":param mass_kept: A real number between 0 and 1, the minumum fraction of mass
""":param mass_kept: A real number between 0 and 1, the minimum fraction of mass
to be preserved by UGW transport plans between two cells in the same neighborhood.""",
_eps_docstring,
""":param dmats: An array of squareform distance matrices of shape (k,n,n),
Expand Down
49 changes: 21 additions & 28 deletions src/cajal/ugw/unbalanced_gw_core.fut
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module unbalanced_gw_core (M: real) = {
(gw.GW_cost X Y P) M.+
(let P1 = map M.sum P in
rho1 M.* (KL2 (tensor P1 P1) (tensor mu mu))) M.+
(let P2 = map M.sum (transpose P) in
(let P2 = map M.sum (transpose P) in
rho2 M.* (KL2 (tensor P2 P2) (tensor nu nu)))

def F [n][m] (A : [n][n]t) (B :[m][m]t) pi gamma mu nu rho1 rho2 =
Expand All @@ -30,7 +30,7 @@ module unbalanced_gw_core (M: real) = {
-- Of course we have the same for
-- D_phi rho1 (\pi2 \otimes \gamma2 ) (\nu\otimes\nu).

-- This is the thing we are trying to optimize.
-- This is the thing we are trying to optimize.
def Feps (r : sinkhorn.otp[][]) A B pi gamma =
F A B pi gamma r.mu r.nu r.rho1 r.rho2
M.+ r.eps M.* KL4 pi gamma r.mu r.nu
Expand All @@ -45,7 +45,7 @@ module unbalanced_gw_core (M: real) = {
def compensate [m][n] (r : sinkhorn.otp[m][n]) =
let m_mu = (M.sum r.mu) in
let m_nu = (M.sum r.nu) in
let m_mu_nu = m_mu M.* m_nu in
let m_mu_nu = m_mu M.* m_nu in
M.(r.rho1 * m_mu * m_mu +
r.rho2 * m_nu * m_nu
+ r.eps * m_mu_nu * m_mu_nu
Expand All @@ -62,15 +62,12 @@ module unbalanced_gw_core (M: real) = {
let c_eps = r.eps M.* KLu2 r.C r.mu r.nu in
let c_rho1 = r.rho1 M.* KLu pi_X r.mu in
let c_rho2 = r.rho2 M.* KLu pi_Y r.nu in
let ans =
M.(gw_cost + (i64 2 * (c_eps +
let x =
M.(gw_cost + (i64 2 * (c_eps +
c_rho1 + c_rho2) * mP) - ((mP * mP) * (r.eps +
r.rho1 + r.rho2)))
in
ans M.+ compensate r
-- let diff = M.(abs(ans + compensate r - UGW_eps r X Y)) in
-- -- if (diff < f64 1e-5) then (UGW_eps r X Y)
-- assert M.( (#[trace]diff) < f64 1e-5) (UGW_eps r X Y)
x M.+ compensate r

-- The UGW cost together with the constituent costs.
def UGW_cost_arr (r : sinkhorn.otp[][]) X Y =
Expand Down Expand Up @@ -183,9 +180,9 @@ module unbalanced_gw_core (M: real) = {
eps : t
}

-- def ugw_loop_structure [m][n]
-- def ugw_loop_structure [m][n]
-- (initialization_function : (sinkhorn.otp[m][n]) -> [m][n]M.t -> [m][n]M.t -> params -> [m]M.t * [n]M.t * [m][n]M.t)
-- (update : (sinkhorn.otp[m][n]) -> [m][n]M.t -> [m][n]M.t -> params -> [m]M.t * [n]M.t * [m][n]M.t))
-- (update : (sinkhorn.otp[m][n]) -> [m][n]M.t -> [m][n]M.t -> params -> [m]M.t * [n]M.t * [m][n]M.t))

def unbalanced_gw_init [n][m] (r : sinkhorn.otp[][]) X Y params tol_outerloop =
let (u0, v0, p0) =
Expand Down Expand Up @@ -228,7 +225,7 @@ module unbalanced_gw_core (M: real) = {
-- \nabla UGW_eps = \nabla_L + eps * \nabla KL4.
map2 (map2 (\c a -> M.fma a r.eps c)) (nabla_L r.rho1 r.rho2 A r.mu B r.nu P) (nabla_KL4 P r.mu r.nu)
|> map (map M.(\a -> max (neg (f64 1e100)) a))

def nabla_UGW_eps_debug [m][n] rho1 rho2 epsilon
(A: [m][m]t) (mu :[m]t) (B: [n][n]t) (nu: [n]t) (P: [m][n]t) (diff: [m][n]t)
: t =
Expand All @@ -239,35 +236,35 @@ module unbalanced_gw_core (M: real) = {
let current_gw_loss = trace (gw.GW_cost A B P) in
let new_gw_loss = trace (gw.GW_cost A B P') in
let linear_loss_diff = trace (frobenius (gw.nabla_G A B P) diff) in
let guess_ratio_gw =
let guess_ratio_gw =
M.( (new_gw_loss - current_gw_loss)/linear_loss_diff) in
let guess_ratio_gw = #[trace] guess_ratio_gw in
mytrace guess_ratio_gw new_gw_loss
in
in

let margin_A =
let current_marginal_loss_A = KL2 (tensor (map M.sum P) (map M.sum P)) (tensor mu mu) in
let new_marginal_loss_A = KL2 (tensor (map M.sum P') (map M.sum P')) (tensor mu mu) in
let linear_loss_diff = frobenius (nabla_marginal mu P) diff in
let guess_ratio_margin =
let guess_ratio_margin =
M.((new_marginal_loss_A - current_marginal_loss_A)/linear_loss_diff) in
let guess_ratio_margin = #[trace] guess_ratio_margin in
mytrace guess_ratio_margin (rho1 M.* new_marginal_loss_A )
in
in

let margin_B =
-- let current_marginal_loss_B = KL (map M.sum (transpose P)) nu in
let current_marginal_loss_B = KL2 (tensor (map M.sum (transpose P)) (map M.sum (transpose P))) (tensor nu nu) in
let new_marginal_loss_B = KL2 (tensor (map M.sum (transpose P')) (map M.sum (transpose P'))) (tensor nu nu) in
-- let new_marginal_loss_B = KL (map2 (M.+) (map M.sum (transpose P)) (map M.sum (transpose diff))) nu in
let linear_loss_diff = frobenius (nabla_marginal nu (transpose P)) (transpose diff) in
let guess_ratio_margin =
let guess_ratio_margin =
M.((new_marginal_loss_B - current_marginal_loss_B)/linear_loss_diff) in
let guess_ratio_margin = #[trace] guess_ratio_margin in
mytrace guess_ratio_margin (rho2 M.* new_marginal_loss_B)
in

let KL4_div =
let KL4_div =
let current_loss_KL4 = (KL4 P P mu nu) in

let new_loss_KL4 = (KL4 P' P' mu nu) in
Expand All @@ -277,15 +274,11 @@ module unbalanced_gw_core (M: real) = {
mytrace (#[trace] guess_ratio) (epsilon M.* new_loss_KL4)
in

let ans = gw_cost M.+ margin_A M.+ margin_B M.+ KL4_div in
let x = gw_cost M.+ margin_A M.+ margin_B M.+ KL4_div in
let other_val = UGW_eps {rho1, rho2, eps=epsilon, mu, nu, C = (map2 (map2 (M.+)) P diff) } A B in
mytrace
(#[trace] other_val)
(#[trace] ans)

-- assert M.(abs(ans - other_val) < f64 1e-10) other_val

-- assert M.(KL4_div - (epsilon * (KL4_div_1 + KL4_div_2)) M.< M.f64 1e-10)
(#[trace] x)

def safe_starting_diff [m][n] (point: [m][n]M.t) (step: [m][n]M.t) : [m][n]M.t =
if M.(map2 (map2 (+)) point step) |> map (all (M.i64 0 M.<)) |> reduce (&&) true then step else
Expand Down Expand Up @@ -377,18 +370,18 @@ module unbalanced_gw_core (M: real) = {
let diff = (map2 (map2 (M.-)) x a.C) in
armijo_line_search loss_fn (initial_loss) a.C diff
(\A B -> gradient_fn A |> frobenius B) (M.f64 0.5) (M.f64 0.5)
in
in
(u1, v1, map2 (map2 (M.+)) a.C diff, loss)
else
(map (\_ -> zero) u1, map (\_ -> zero) v1, map (map (\_ -> zero)) C' , M.i64 0)
def ratio_err_ok tol a b =

def ratio_err_ok tol a b =
M.(a * (one + tol) >= b && b * (one + tol) >= a)

def init [n][m] (r : sinkhorn.otp[][]) X Y params tol_outerloop =
let (u0, v0, p0) =
unbalanced_gw_init_step r X Y params in
let initial_loss = UGW_eps_1 r X Y in
let initial_loss = UGW_eps_1 r X Y in
let update (u: [n]t) (v:[m]t) (p: [n][m]t) (loss: t): ([n]t, [m]t, [n][m]t, t) =
descent ((r with C = p) : sinkhorn.otp[][]) X Y u v params loss
in
Expand Down
2 changes: 1 addition & 1 deletion src/cajal/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def write_npz(
) -> list[tuple[str, Err[T]]]:
"""
Write the stream to an npz file. This writing method keeps all data in memory
at one time so it may be inappropriate for situtations where the point clouds are
at one time so it may be inappropriate for situations where the point clouds are
large or there are many cells.
:param sidelength: The side length of all matrices in dist_mats.
Expand Down

0 comments on commit 5cce045

Please sign in to comment.