Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] pairwise phase consistency #392

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
10 changes: 5 additions & 5 deletions elephant/phase_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,14 @@ def pairwise_phase_consistency(phases, method='ppc0'):
# Compute the distance between each pair of phases using dot product
# Optimize computation time using array multiplications instead of for
# loops
p_cos_2d = np.tile(np.cos(phase_array), reps=(n_trials, 1)) # TODO: optimize memory usage
p_sin_2d = np.tile(np.sin(phase_array), reps=(n_trials, 1))
p_cos_2d = np.broadcast_to(np.cos(phase_array), (n_trials, n_trials))
p_sin_2d = np.broadcast_to(np.sin(phase_array), (n_trials, n_trials))

# By doing the element-wise multiplication of this matrix with its
# transpose, we get the distance between phases for all possible pairs
# of elements in 'phase'
dot_prod = np.multiply(p_cos_2d, p_cos_2d.T) + \
np.multiply(p_sin_2d, p_sin_2d.T)
dot_prod = np.multiply(p_cos_2d, p_cos_2d.T, dtype=np.float32) + \
np.multiply(p_sin_2d, p_sin_2d.T, dtype=np.float32) # TODO: agree on using this precision or not

# Now average over all elements in temp_results (the diagonal are 1
# and should not be included)
Expand All @@ -270,7 +270,7 @@ def pairwise_phase_consistency(phases, method='ppc0'):
# Note: each pair i,j is computed twice in dot_prod. do not
# multiply by 2. n_trial * n_trials - n_trials = nr of filled elements
# in dot_prod
ppc = np.sum(dot_prod) / (n_trials * n_trials - n_trials)
ppc = np.sum(dot_prod) / (n_trials * n_trials - n_trials) # TODO: handle nan's
return ppc

elif method == 'ppc1':
Expand Down
126 changes: 126 additions & 0 deletions elephant/test/test_phase_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,5 +202,131 @@ def test_regression_269(self):
self.assertEqual(len(phases_noint[0]), 2)


class PairwisePhaseConsistencyTestCase(unittest.TestCase):

@classmethod
def setUpClass(cls): # Note: using setUp makes the class call this
TRuikes marked this conversation as resolved.
Show resolved Hide resolved
# function per test, while this way the function is called only
# 1 time per TestCase, slightly more efficient (0.5s tough)

# Same setup as SpikeTriggerePhaseTestCase
tlen0 = 100 * pq.s
f0 = 20. * pq.Hz
fs0 = 1 * pq.ms
t0 = np.arange(
0, tlen0.rescale(pq.s).magnitude,
fs0.rescale(pq.s).magnitude) * pq.s
cls.anasig0 = AnalogSignal(
np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)

# Spiketrain with perfect locking
cls.st_perfect = SpikeTrain(
np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

# Spiketrain with inperfect locking
cls.st_inperfect = SpikeTrain(
[100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

# Generate 2 'bursting' spiketrains, both locking on sinus period,
# but with different strengths
n_spikes = 3 # n spikes per burst
burst_interval = (1 / f0.magnitude) * pq.s
burst_start_times = np.arange(
0,
tlen0.rescale('ms').magnitude,
burst_interval.rescale('ms').magnitude
)

# Spiketrain with strong locking
burst_freq_strong = 200. * pq.Hz # strongly locking unit
burst_spike_interval = (1 / burst_freq_strong.magnitude) * pq.s
st_in_burst = np.arange(
0,
burst_spike_interval.rescale('ms').magnitude * n_spikes,
burst_spike_interval.rescale('ms').magnitude
)
st = [st_in_burst + t_offset for t_offset in burst_start_times]
st = np.hstack(st) * pq.ms
cls.st_bursting_strong = SpikeTrain(st,
t_start=0 * pq.ms,
t_stop=tlen0
)

# Spiketrain with weak locking
burst_freq_weak = 100. * pq.Hz # weak locking unit
burst_spike_interval = (1 / burst_freq_weak.magnitude) * pq.s
st_in_burst = np.arange(
0,
burst_spike_interval.rescale('ms').magnitude * n_spikes,
burst_spike_interval.rescale('ms').magnitude
)
st = [st_in_burst + t_offset for t_offset in burst_start_times]
st = np.hstack(st) * pq.ms
cls.st_bursting_weak = SpikeTrain(st,
t_start=0 * pq.ms,
t_stop=tlen0
)

def test_perfect_locking(self):
phases, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_perfect,
interpolate=True
)
# Pass input as single array
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases[0], method='ppc0'
)
self.assertEqual(ppc0, 1)
self.assertIsInstance(ppc0, float)

# Pass input as list of arrays
n_phases = int(phases[0].shape[0] / 2)
phases_cut = [phases[0][i * 2:i * 2 + 2] for i in range(n_phases)]
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases_cut, method='ppc0'
)
self.assertEqual(ppc0, 1)
self.assertIsInstance(ppc0, float)

def test_inperfect_locking(self):
phases, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_inperfect,
interpolate=True
)
# Pass input as single array
ppc0 = elephant.phase_analysis.pairwise_phase_consistency(
phases[0], method='ppc0'
)
self.assertLess(ppc0, 1)
self.assertIsInstance(ppc0, float)

def test_strong_vs_weak_locking(self):
phases_weak, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_bursting_weak,
interpolate=True
)
# Pass input as single array
ppc0_weak = elephant.phase_analysis.pairwise_phase_consistency(
phases_weak[0], method='ppc0'
)
phases_strong, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st_bursting_strong,
interpolate=True
)
# Pass input as single array
ppc0_strong = elephant.phase_analysis.pairwise_phase_consistency(
phases_strong[0], method='ppc0'
)

self.assertLess(ppc0_weak, ppc0_strong)


if __name__ == '__main__':
unittest.main()