Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numba dfun for model KIonEx #727

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 126 additions & 1 deletion tvb_library/tvb/simulator/models/k_ion_exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@

import numpy

from numba import guvectorize, float64

class KIonEx(Model):
r"""
KIonEx (Potassium K+ Ion exchange) mean-field model was developed in (Bandyopadhyay & Rabuffo et al. 2023).
Expand Down Expand Up @@ -208,7 +210,7 @@ class KIonEx(Model):
# Stvar is the variable where stimulus is applied.
stvar = numpy.array([1], dtype=numpy.int32)

def dfun(self, state_variables, coupling, local_coupling=0.0):
def _numpy_dfun(self, state_variables, coupling, local_coupling=0.0):
r"""
The mean-field approximation for a population of Hodgkin-Huxley-type neurons driven by slow potassium dynamics consists of a 5D system:

Expand Down Expand Up @@ -343,3 +345,126 @@ def V_dot_form(I_Na,I_K,I_Cl,I_pump):
derivative[4] = epsilon * (K_bath - K_o)

return derivative

def dfun(self, x, c, local_coupling=0.0):

x_ = x
c_ = c + local_coupling * x[0]
deriv = _numba_dfun(x_, c_, self.E, self.K_bath, self.J, self.eta, self.Delta, self.c_minus, self.R_minus,
self.c_plus, self.R_plus, self.Vstar, self.Cm, self.tau_n, self.gamma, self.epsilon)
return deriv

@guvectorize([(float64[:],) * 17], '(n),(m)' + ',()' * 14 + '->(n)', nopython=True)
def _numba_dfun(state_variables, coupling, E, K_bath, J, eta, Delta, c_minus, R_minus, c_plus, R_plus, Vstar, Cm,
tau_n, gamma, epsilon, dx):
r"""
The mean-field approximation for a population of Hodgkin-Huxley-type neurons driven by slow potassium dynamics consists of a 5D system:

.. math::
\frac{dx}{dt}&=
\begin{cases}
\Delta+2R_{-}(V-c_{-})x - J r x; \ V\leq V^{\star}\\
\Delta+2R_{+}(V-c_{+})x - J r x; \ V> V^{\star},
\end{cases}\\
\frac{dV}{dt}&=
\begin{cases}
-\frac{1}{C_m}(I_{Cl}+I_{Na}+I_{K}+I_{pump})-R_{-}x^2+J r(E_{syn}-V)+\overline{\eta}; \ V\leq V^{\star}\\
-\frac{1}{C_m}(I_{Cl}+I_{Na}+I_{K}+I_{pump})-R_{+}x^2+J r(E_{syn}-V)+\overline{\eta}; \ V>V^{\star},
\end{cases}\\
\frac{dn}{dt} &= \frac{n_{\infty}(V)-n}{\tau_n}, \\
\frac{d \Delta [K^{+}]_{int}}{dt} &= - \frac{\gamma}{\omega_i}(I_K - 2 I_{pump}),\\
\frac{d[K^+]_g}{dt} &= \epsilon ([K^+]_{bath} - [K^+]_{ext}\}).\\

For details refer to (Bandyopadhyay & Rabuffo et al. 2023)
"""

x = state_variables[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can probably tuple unpack this like

x, V, n, DKi, Kg = state_variables

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

V = state_variables[1]
n = state_variables[2]
DKi = state_variables[3]
Kg = state_variables[4]

Coupling_Term = coupling[0] # This zero refers to the first element of cvar (trivial in this case)

# Constants
Cnap = 21.0 # mol.m**-3
DCnap = 2.0 # mol.m**-3
Ckp = 5.5 # mol.m**-3
DCkp = 1.0 # mol.m**-3
Cmna = -24.0 # mV
DCmna = 12.0 # mV
Chn = 0.4 # dimensionless
DChn = -8.0 # dimensionless
Cnk = -19.0 # mV
DCnk = 18.0 # mV #Ok in the paper
g_Cl = 7.5 # nS #Ok in the paper # chloride conductance
g_Na = 40.0 # nS # maximal sodiumconductance
g_K = 22.0 # nS # maximal potassium conductance
g_Nal = 0.02 # nS # sodium leak conductance
g_Kl = 0.12 # nS # potassium leak conductance
rho = 250. # 250.,#pA # maximal Na/K pump current
w_i = 2160.0 # umeter**3 # intracellular volume
w_o = 720.0 # umeter**3 # extracellular volume
Na_i0 = 16.0 # mMol/m**3 # initial concentration of intracellular Na
Na_o0 = 138.0 # mMol/m**3 # initial concentration of extracellular Na
K_i0 = 130.0 # mMol/m**3 # initial concentration of intracellular K
K_o0 = 4.80 # mMol/m**3 # initial concentration of extracellular K
Cl_i0 = 5.0 # mMol/m**3 # initial concentration of intracellular Cl
Cl_o0 = 112.0 # mMol/m**3 # initial concentration of extracellular Cl

# helper functions

def m_inf(V):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does numba compile this efficiently?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed decoration with @njit is more efficient

return 1.0 / (1.0 + numpy.exp((Cmna - V) / DCmna))

def n_inf(V):
return 1.0 / (1.0 + numpy.exp((Cnk - V) / DCnk))

def h(n):
return 1.1 - 1.0 / (1.0 + numpy.exp(-8.0 * (n - 0.4)))

def I_K_form(V, n, K_o, K_i):
return (g_Kl + g_K * n) * (V - 26.64 * numpy.log(K_o / K_i))

def I_Na_form(V, Na_o, Na_i, n):
return (g_Nal + g_Na * m_inf(V) * h(n)) * (V - 26.64 * numpy.log(Na_o / Na_i))

def I_Cl_form(V):
return g_Cl * (V + 26.64 * numpy.log(Cl_o0 / Cl_i0))

def I_pump_form(Na_i, K_o):
return rho * (
1.0 / (1.0 + numpy.exp((Cnap - Na_i) / DCnap)) * (1.0 / (1.0 + numpy.exp((Ckp - K_o) / DCkp))))

def V_dot_form(I_Na, I_K, I_Cl, I_pump):
return (-1.0 / Cm) * (I_Na + I_K + I_Cl + I_pump)

beta = w_i / w_o
DNa_i = -DKi
DNa_o = -beta * DNa_i
DK_o = -beta * DKi
K_i = K_i0 + DKi
Na_i = Na_i0 + DNa_i
Na_o = Na_o0 + DNa_o
K_o = K_o0 + DK_o + Kg

ninf = n_inf(V)
I_K = I_K_form(V, n, K_o, K_i)
I_Na = I_Na_form(V, Na_o, Na_i, n)
I_Cl = I_Cl_form(V)
I_pump = I_pump_form(Na_i, K_o)

r = R_minus[0] * x / numpy.pi
Vdot = (-1.0 / Cm[0]) * (I_Na + I_K + I_Cl + I_pump)

if_xdot = Delta[0] + 2 * R_minus[0] * (V - c_minus[0]) * x - J[0] * r * x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to see these if else terms factored for readability. e.g.

Vsmall = V <= Vstar
RVc = where(Vsmall, R_minus[0]*(V-c_minus[0]), R_plus[0]*(V - c_plus[0]))
dx[0] = Delta[0] + 2*RVc*x - J[0]*r*x

or even better just

if V <= Vstar:
  R, c = R_minus[0], c_minus[0]
else:
  R, c = R_plus[0], c_plus[0]
dx[0] = Delta[0] + 2*R*(V - c)*x - J[0]*r*x

makes it clear that we're just switching two parameter values on the V <= Vstar branch not the whole expression

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed. added the necessary reshapes also

else_xdot = Delta[0] + 2 * R_plus[0] * (V - c_plus[0]) * x - J[0] * r * x

if_Vdot = Vdot - R_minus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V)
else_Vdot = Vdot - R_plus[0] * x ** 2 + eta[0] + (R_minus[0] / numpy.pi) * Coupling_Term * (E[0] - V)

dx[0] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_xdot, else_xdot)[0]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is there a [0] at the end of this expression?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the np.where was removed and so the [0]

dx[1] = numpy.where(V <= (Vstar * numpy.ones_like(V)), if_Vdot, else_Vdot)[0]
dx[2] = (ninf - n) / tau_n[0]
dx[3] = -(gamma[0] / w_i) * (I_K - 2.0 * I_pump)
dx[4] = epsilon[0] * (K_bath[0] - K_o)
Loading