Skip to content

Commit

Permalink
[ENH] Integrate trials object with GPFA (#610)
Browse files Browse the repository at this point in the history
* add handling for `Trial` class to `GPFA`
  • Loading branch information
Moritz-Alexander-Kern authored Mar 26, 2024
1 parent b8ccde8 commit 9326804
Show file tree
Hide file tree
Showing 5 changed files with 427 additions and 160 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ jobs:
- name: Test with pytest
run: |
coverage run --source=elephant -m pytest
coveralls --service=github
coverage run --source=elephant -m pytest && coveralls --service=github || echo "Coveralls submission failed"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Expand Down Expand Up @@ -294,8 +293,7 @@ jobs:
- name: Test with pytest
run: |
mpiexec -n 1 python -m mpi4py -m coverage run --source=elephant -m pytest
coveralls --service=github
mpiexec -n 1 python -m mpi4py -m coverage run --source=elephant -m pytest && coveralls --service=github || echo "Coveralls submission failed"
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

Expand Down
287 changes: 182 additions & 105 deletions elephant/gpfa/gpfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,18 @@
:license: Modified BSD, see LICENSE.txt for details.
"""

from __future__ import division, print_function, unicode_literals

from typing import List, Union
import neo
import numpy as np
import quantities as pq
import sklearn

from elephant.gpfa import gpfa_core, gpfa_util
from elephant.trials import Trials
from elephant.utils import trials_to_list_of_spiketrainlist


__all__ = [
"GPFA"
]
__all__ = ["GPFA"]


class GPFA(sklearn.base.BaseEstimator):
Expand Down Expand Up @@ -228,9 +227,18 @@ class GPFA(sklearn.base.BaseEstimator):
... 'latent_variable'])
"""

def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
tau_init=100.0 * pq.ms, eps_init=1.0E-3, em_tol=1.0E-8,
em_max_iters=500, freq_ll=5, verbose=False):
def __init__(
self,
bin_size=20 * pq.ms,
x_dim=3,
min_var_frac=0.01,
tau_init=100.0 * pq.ms,
eps_init=1.0e-3,
em_tol=1.0e-8,
em_max_iters=500,
freq_ll=5,
verbose=False,
):
# Initialize object
self.bin_size = bin_size
self.x_dim = x_dim
Expand All @@ -241,11 +249,12 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
self.em_max_iters = em_max_iters
self.freq_ll = freq_ll
self.valid_data_names = (
'latent_variable_orth',
'latent_variable',
'Vsm',
'VsmGP',
'y')
"latent_variable_orth",
"latent_variable",
"Vsm",
"VsmGP",
"y",
)
self.verbose = verbose

if not isinstance(self.bin_size, pq.Quantity):
Expand All @@ -258,17 +267,53 @@ def __init__(self, bin_size=20 * pq.ms, x_dim=3, min_var_frac=0.01,
self.fit_info = dict()
self.transform_info = dict()

def fit(self, spiketrains):
@staticmethod
def _check_training_data(
spiketrains: List[List[neo.core.SpikeTrain]],
) -> None:
if len(spiketrains) == 0:
raise ValueError("Input spiketrains can not be empty")
if not all(
isinstance(item, neo.SpikeTrain)
for sublist in spiketrains
for item in sublist
):
raise ValueError(
"structure of the spiketrains is not "
"correct: 0-axis should be trials, 1-axis "
"neo.SpikeTrain and 2-axis spike times."
)

def _format_training_data(
self, spiketrains: List[List[neo.core.SpikeTrain]]
) -> np.recarray:
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
# Remove inactive units based on training set
self.has_spikes_bool = np.hstack(seqs["y"]).any(axis=1)
for seq in seqs:
seq["y"] = seq["y"][self.has_spikes_bool, :]
return seqs

@trials_to_list_of_spiketrainlist
def fit(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
) -> "GPFA":
"""
Fit the model with the given training data.
Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : :class:`elephant.trials.Trials`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or list of list of :class:`neo.core.SpikeTrain`
Spike train data to be fit to latent variables.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
`spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
For list of lists, the outer list corresponds to trials and the
inner list corresponds to the neurons recorded in that trial, such
that `spiketrains[l][n]` is the spike train of neuron `n` in trial
`l`.
Note that the number and order of `neo.SpikeTrain` objects per
trial must be fixed such that `spiketrains[l][n]` and
`spiketrains[k][n]` refer to spike trains of the same neuron
Expand All @@ -288,69 +333,74 @@ def fit(self, spiketrains):
If covariance matrix of input spike data is rank deficient.
"""
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)
# Check if training data covariance is full rank
y_all = np.hstack(seqs_train['y'])
y_dim = y_all.shape[0]

if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
errmesg = 'Observation covariance matrix is rank deficient.\n' \
'Possible causes: ' \
'repeated units, not enough observations.'
raise ValueError(errmesg)

if self.verbose:
print('Number of training trials: {}'.format(len(seqs_train)))
print('Latent space dimensionality: {}'.format(self.x_dim))
print('Observation dimensionality: {}'.format(
self.has_spikes_bool.sum()))

# The following does the heavy lifting.
self.params_estimated, self.fit_info = gpfa_core.fit(
seqs_train=seqs_train,
x_dim=self.x_dim,
bin_width=self.bin_size.rescale('ms').magnitude,
min_var_frac=self.min_var_frac,
em_max_iters=self.em_max_iters,
em_tol=self.em_tol,
tau_init=self.tau_init.rescale('ms').magnitude,
eps_init=self.eps_init,
freq_ll=self.freq_ll,
verbose=self.verbose)

return self

@staticmethod
def _check_training_data(spiketrains):
if len(spiketrains) == 0:
raise ValueError("Input spiketrains cannot be empty")
if not isinstance(spiketrains[0][0], neo.SpikeTrain):
raise ValueError("structure of the spiketrains is not correct: "
"0-axis should be trials, 1-axis neo.SpikeTrain"
"and 2-axis spike times")

def _format_training_data(self, spiketrains):
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
# Remove inactive units based on training set
self.has_spikes_bool = np.hstack(seqs['y']).any(axis=1)
for seq in seqs:
seq['y'] = seq['y'][self.has_spikes_bool, :]
return seqs

def transform(self, spiketrains, returned_data=['latent_variable_orth']):
if all(
isinstance(item, neo.SpikeTrain)
for sublist in spiketrains
for item in sublist
):
self._check_training_data(spiketrains)
seqs_train = self._format_training_data(spiketrains)
# Check if training data covariance is full rank
y_all = np.hstack(seqs_train["y"])
y_dim = y_all.shape[0]

if np.linalg.matrix_rank(np.cov(y_all)) < y_dim:
errmesg = (
"Observation covariance matrix is rank deficient.\n"
"Possible causes: "
"repeated units, not enough observations."
)
raise ValueError(errmesg)

if self.verbose:
print("Number of training trials: {}".format(len(seqs_train)))
print("Latent space dimensionality: {}".format(self.x_dim))
print(
"Observation dimensionality: {}".format(
self.has_spikes_bool.sum()
)
)

# The following does the heavy lifting.
self.params_estimated, self.fit_info = gpfa_core.fit(
seqs_train=seqs_train,
x_dim=self.x_dim,
bin_width=self.bin_size.rescale("ms").magnitude,
min_var_frac=self.min_var_frac,
em_max_iters=self.em_max_iters,
em_tol=self.em_tol,
tau_init=self.tau_init.rescale("ms").magnitude,
eps_init=self.eps_init,
freq_ll=self.freq_ll,
verbose=self.verbose,
)
return self
else: # TODO: implement case for continuous data
raise ValueError

@trials_to_list_of_spiketrainlist
def transform(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
returned_data: str = ["latent_variable_orth"],
) -> "GPFA":
"""
Obtain trajectories of neural activity in a low-dimensional latent
variable space by inferring the posterior mean of the obtained GPFA
model and applying an orthonormalization on the latent variable space.
Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Spike train data to be transformed to latent variables.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
`spiketrains[l][n]` is the spike train of neuron `n` in trial `l`.
For list of lists, the outer list corresponds to trials and the
inner list corresponds to the neurons recorded in that trial, such
that `spiketrains[l][n]` is the spike train of neuron `n` in trial
`l`.
Note that the number and order of `neo.SpikeTrain` objects per
trial must be fixed such that `spiketrains[l][n]` and
`spiketrains[k][n]` refer to spike trains of the same neuron
Expand Down Expand Up @@ -378,7 +428,7 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
Returns
-------
np.ndarray or dict
:class:`np.ndarray` or dict
When the length of `returned_data` is one, a single np.ndarray,
containing the requested data (the first entry in `returned_data`
keys list), is returned. Otherwise, a dict of multiple np.ndarrays
Expand Down Expand Up @@ -411,36 +461,55 @@ def transform(self, spiketrains, returned_data=['latent_variable_orth']):
If `returned_data` contains keys different from the ones in
`self.valid_data_names`.
"""
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError("'spiketrains' must contain the same number of "
"neurons as the training spiketrain data")
invalid_keys = set(returned_data).difference(self.valid_data_names)
if len(invalid_keys) > 0:
raise ValueError("'returned_data' can only have the following "
"entries: {}".format(self.valid_data_names))
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
for seq in seqs:
seq['y'] = seq['y'][self.has_spikes_bool, :]
seqs, ll = gpfa_core.exact_inference_with_ll(seqs,
self.params_estimated,
get_ll=True)
self.transform_info['log_likelihood'] = ll
self.transform_info['num_bins'] = seqs['T']
Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs)
self.transform_info['Corth'] = Corth
if len(returned_data) == 1:
return seqs[returned_data[0]]
return {x: seqs[x] for x in returned_data}

def fit_transform(self, spiketrains, returned_data=[
'latent_variable_orth']):
if all(
isinstance(item, neo.SpikeTrain)
for sublist in spiketrains
for item in sublist
):
if len(spiketrains[0]) != len(self.has_spikes_bool):
raise ValueError(
"'spiketrains' must contain the same number of "
"neurons as the training spiketrain data"
)
invalid_keys = set(returned_data).difference(self.valid_data_names)
if len(invalid_keys) > 0:
raise ValueError(
"'returned_data' can only have the following "
f"entries: {self.valid_data_names}"
)
seqs = gpfa_util.get_seqs(spiketrains, self.bin_size)
for seq in seqs:
seq["y"] = seq["y"][self.has_spikes_bool, :]
seqs, ll = gpfa_core.exact_inference_with_ll(
seqs, self.params_estimated, get_ll=True
)
self.transform_info["log_likelihood"] = ll
self.transform_info["num_bins"] = seqs["T"]
Corth, seqs = gpfa_core.orthonormalize(self.params_estimated, seqs)
self.transform_info["Corth"] = Corth
if len(returned_data) == 1:
return seqs[returned_data[0]]
return {x: seqs[x] for x in returned_data}
else: # TODO: implement case for continuous data
raise ValueError

@trials_to_list_of_spiketrainlist
def fit_transform(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
returned_data: str = ["latent_variable_orth"],
) -> "GPFA":
"""
Fit the model with `spiketrains` data and apply the dimensionality
reduction on `spiketrains`.
Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Refer to the :func:`GPFA.fit` docstring.
returned_data : list of str
Expand All @@ -465,13 +534,21 @@ def fit_transform(self, spiketrains, returned_data=[
self.fit(spiketrains)
return self.transform(spiketrains, returned_data=returned_data)

def score(self, spiketrains):
@trials_to_list_of_spiketrainlist
def score(
self,
spiketrains: Union[
List[List[neo.core.SpikeTrain]],
"Trials",
List[neo.core.spiketrainlist.SpikeTrainList],
],
) -> "GPFA":
"""
Returns the log-likelihood of the given data under the fitted model
Parameters
----------
spiketrains : list of list of neo.SpikeTrain
---------- # noqa
spiketrains : list of list of :class:`neo.core.SpikeTrain`, list of :class:`neo.core.spiketrainlist.SpikeTrainList` or :class:`elephant.trials.Trials`
Spike train data to be scored.
The outer list corresponds to trials and the inner list corresponds
to the neurons recorded in that trial, such that
Expand All @@ -487,4 +564,4 @@ def score(self, spiketrains):
Log-likelihood of the given spiketrains under the fitted model.
"""
self.transform(spiketrains)
return self.transform_info['log_likelihood']
return self.transform_info["log_likelihood"]
Loading

0 comments on commit 9326804

Please sign in to comment.