-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathephys_features.py
executable file
·1450 lines (1143 loc) · 55.4 KB
/
ephys_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact [email protected] for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import warnings
import logging
import numpy as np
import scipy.signal as signal
from scipy.optimize import curve_fit
from functools import partial
def detect_putative_spikes(v, t, start=None, end=None, filter=10., dv_cutoff=20.):
    """Perform initial detection of spikes and return their indexes.

    Spikes are detected as positive-going crossings of `dv_cutoff` in the
    (optionally filtered) dV/dt of the voltage within [start, end]; candidate
    crossings are then pruned so that dV/dt returns below zero between
    consecutive detections (otherwise they are treated as one event).

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of time window for spike detection (optional, default t[0])
    end : end of time window for spike detection (optional, default t[-1])
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)

    Returns
    -------
    putative_spikes : numpy array of preliminary spike indexes, expressed in
        the index space of the full `v`/`t` arrays (not the analysis window)

    Raises
    ------
    TypeError : if `v` or `t` is not a numpy array
    FeatureError : if `v` and `t` do not have the same shape
    """
    if not isinstance(v, np.ndarray):
        raise TypeError("v is not an np.ndarray")
    if not isinstance(t, np.ndarray):
        raise TypeError("t is not an np.ndarray")
    if v.shape != t.shape:
        raise FeatureError("Voltage and time series do not have the same dimensions")
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]
    start_index = find_time_index(t, start)
    end_index = find_time_index(t, end)
    v_window = v[start_index:end_index + 1]
    t_window = t[start_index:end_index + 1]
    # dV/dt is computed only on the analysis window
    dvdt = calculate_dvdt(v_window, t_window, filter)
    # Find positive-going crossings of dV/dt cutoff level
    putative_spikes = np.flatnonzero(np.diff(np.greater_equal(dvdt, dv_cutoff).astype(int)) == 1)
    # If the trace already starts above the cutoff there is no crossing at
    # index 0, so add it explicitly
    if dvdt[0] > dv_cutoff:
        putative_spikes = np.insert(putative_spikes, 0, 0)
    if len(putative_spikes) <= 1:
        # Set back to original index space (not just window)
        return np.array(putative_spikes) + start_index
    # Only keep spike times if dV/dt has dropped all the way to zero between putative spikes
    putative_spikes = [putative_spikes[0]] + [s for i, s in enumerate(putative_spikes[1:])
                                              if np.any(dvdt[putative_spikes[i]:s] < 0)]
    # Set back to original index space (not just window)
    return np.array(putative_spikes) + start_index
def find_peak_indexes(v, t, spike_indexes, end=None):
    """Find indexes of spike peaks.

    Each peak is the maximum voltage between a spike onset and the following
    spike onset; the last spike's search interval is closed by `end`.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    end : end of time window for spike detection (optional, default t[-1])

    Returns
    -------
    peak_indexes : numpy array of peak indexes
    """
    if not end:
        end = t[-1]
    end_index = find_time_index(t, end)

    # Search-interval boundaries: each spike onset, plus the window end to
    # close the final interval.
    boundaries = np.append(spike_indexes, end_index)
    peaks = []
    for left, right in zip(boundaries[:-1], boundaries[1:]):
        peaks.append(left + np.argmax(v[left:right]))
    return np.array(peaks)
def filter_putative_spikes(v, t, spike_indexes, peak_indexes, min_height=2.,
                           min_peak=-30., filter=10., dvdt=None):
    """Filter out events that are unlikely to be spikes based on:
    * Voltage failing to go down between peak and the next spike's threshold
    * Height (threshold to peak)
    * Absolute peak level

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
    min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    spike_indexes : numpy array of threshold indexes
    peak_indexes : numpy array of peak indexes
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([]), np.array([])

    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    # Require dV/dt to go negative between each peak and the next spike's
    # threshold; otherwise collapse the pair into one event by dropping the
    # earlier peak and the later threshold.
    went_down = [bool(np.any(dvdt[pk:thr] < 0))
                 for pk, thr in zip(peak_indexes[:-1], spike_indexes[1:])]
    peak_indexes = peak_indexes[np.array(went_down + [True])]
    spike_indexes = spike_indexes[np.array([True] + went_down)]

    # Reject events whose absolute peak voltage is too low.
    keep = v[peak_indexes] >= min_peak
    spike_indexes = spike_indexes[keep]
    peak_indexes = peak_indexes[keep]

    # Reject events whose threshold-to-peak height is too small.
    keep = (v[peak_indexes] - v[spike_indexes]) >= min_height
    return spike_indexes[keep], peak_indexes[keep]
def find_upstroke_indexes(v, t, spike_indexes, peak_indexes, filter=10., dvdt=None):
    """Find indexes of maximum upstroke of spike.

    The upstroke is the point of maximum dV/dt between each spike's
    threshold and its peak.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of preliminary spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    upstroke_indexes : numpy array of upstroke indexes
    """
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    upstrokes = []
    for thr, pk in zip(spike_indexes, peak_indexes):
        # Fastest rate of rise within the threshold-to-peak interval
        upstrokes.append(thr + np.argmax(dvdt[thr:pk]))
    return np.array(upstrokes)
def refine_threshold_indexes_based_on_third_derivative(v, t, peak_indexes, upstroke_indexes, filter=10., dvdt=None):
    """Refine threshold detection of previously-found spikes. Simple code to handle too steep depolarisations in the beginning
    now too.

    For each spike, searches backward from the upstroke for the maximum of the
    third derivative of the voltage and takes that point as the threshold.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    peak_indexes : numpy array of indexes of action potential peaks
    upstroke_indexes : numpy array of indexes of spike upstrokes
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if not upstroke_indexes.size:
        return np.array([])
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)
    # Successive differentiation; t is shortened to match each derivative's length
    dvdt2 = calculate_dvdt(dvdt, t[:-1], filter)  # Second derivative of the voltage
    dvdt3 = calculate_dvdt(dvdt2, t[:-2], filter)  # Third derivative of the voltage
    threshold_indexes = []
    # Pair each upstroke with the previous spike's peak (0 for the first spike)
    peak_indexes_and_start = np.append(np.array([0]), peak_indexes)
    for peak, upstroke in zip(peak_indexes_and_start[:-1], upstroke_indexes):
        # Reversed slice walks backward from the upstroke; its stop point is
        # offset by the voltage minimum between the previous peak and this
        # upstroke.
        # NOTE(review): `peak + np.argmin(v[peak:upstroke])` is the index of
        # the inter-spike voltage minimum - confirm this is the intended lower
        # bound of the backward search rather than the upstroke-relative offset.
        thresh_index = np.argmax(dvdt3[upstroke : peak + np.argmin(v[peak:upstroke]) : -1])
        threshold_indexes.append(upstroke - thresh_index)
    return np.array(threshold_indexes)
def refine_threshold_indexes_updated(v, t, upstroke_indexes, start = None, thresh_frac=0.05, filter=10., dvdt=None):
    """Refine threshold detection of previously-found spikes. Simple code to handle too steep depolarisations in the beginning
    now too.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
    start : start of the analysis window; the backward search for the first
        spike's threshold stops here (optional, default t[0])
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if start is None:
        start = t[0]
    start_index = find_time_index(t, start)
    if not upstroke_indexes.size:
        return np.array([])
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)
    # Threshold level is a fixed fraction of the mean maximal upstroke
    avg_upstroke = dvdt[upstroke_indexes].mean()
    target = avg_upstroke * thresh_frac
    # Each spike's backward search is bounded by the previous upstroke
    # (or the window start for the first spike)
    upstrokes_and_start = np.append(np.array([start_index]), upstroke_indexes)
    threshold_indexes = []
    for upstk, upstk_prev in zip(upstrokes_and_start[1:], upstrokes_and_start[:-1]):
        # First spike only: when the interval starts right at the window start,
        # take the minimum dV/dt in the interval instead of the target crossing
        # ("too steep" depolarisation at stimulus onset never drops below target).
        # NOTE(review): the comparison against find_time_index(t, 0.1) looks
        # protocol-specific (presumably the stimulus-onset time of 0.1 s) -
        # confirm against the caller.
        if (upstk_prev == start_index and not upstk_prev == find_time_index(t, 0.1)): # Too steep depolarisations
            threshold_indexes.append(upstk - np.argmin(dvdt[upstk:upstk_prev:-1]))
            continue;
        # Walk backward from the upstroke to the first point at or below target
        potential_indexes = np.flatnonzero(dvdt[upstk:upstk_prev:-1] <= target)
        if not potential_indexes.size:
            # couldn't find a matching value for threshold,
            # so just going to the start of the search interval
            threshold_indexes.append(upstk_prev)
        else:
            threshold_indexes.append(upstk - potential_indexes[0])
    return np.array(threshold_indexes)
def refine_threshold_indexes(v, t, upstroke_indexes, thresh_frac=0.05, filter=10., dvdt=None):
    """Refine threshold detection of previously-found spikes.

    For each spike, walks backward from the upstroke until dV/dt drops to a
    fixed fraction of the mean maximal upstroke; that crossing becomes the
    refined threshold index.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    upstroke_indexes : numpy array of indexes of spike upstrokes (for threshold target calculation)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    threshold_indexes : numpy array of threshold indexes
    """
    if not upstroke_indexes.size:
        return np.array([])
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)

    # Target dV/dt level: a fraction of the mean maximal upstroke
    threshold_level = thresh_frac * dvdt[upstroke_indexes].mean()

    # Each backward search is bounded by the previous upstroke
    # (index 0 for the first spike)
    search_bounds = np.append(np.array([0]), upstroke_indexes)
    refined = []
    for prev, upstk in zip(search_bounds[:-1], search_bounds[1:]):
        candidates = np.flatnonzero(dvdt[upstk:prev:-1] <= threshold_level)
        if candidates.size:
            refined.append(upstk - candidates[0])
        else:
            # No crossing found; fall back to the start of the search interval
            refined.append(prev)
    return np.array(refined)
def check_thresholds_and_peaks(v, t, spike_indexes, peak_indexes, upstroke_indexes, end=None,
                               max_interval=0.005, thresh_frac=0.05, filter=10., dvdt=None,
                               tol=1.0):
    """Validate thresholds and peaks for set of spikes

    Check that peaks and thresholds for consecutive spikes do not overlap
    Spikes with overlapping thresholds and peaks will be merged.
    Check that peaks and thresholds for a given spike are not too far apart.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of indexes of spike peaks
    upstroke_indexes : numpy array of indexes of spike upstrokes
    end : end of time window (optional, default t[-1])
    max_interval : maximum allowed time between start of spike and time of peak in sec (default 0.005)
    thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)
    tol : tolerance for returning to threshold in mV (optional, default 1)

    Returns
    -------
    spike_indexes : numpy array of modified spike indexes
    peak_indexes : numpy array of modified spike peak indexes
    upstroke_indexes : numpy array of modified spike upstroke indexes
    clipped : numpy array of clipped status of spikes
    """
    if not end:
        end = t[-1]
    # Merge overlapping spikes: if a spike's threshold falls at or before the
    # previous spike's peak (+1 sample), drop the later threshold and the
    # earlier peak/upstroke so the pair collapses into a single spike.
    overlaps = np.flatnonzero(spike_indexes[1:] <= peak_indexes[:-1] + 1)
    if overlaps.size:
        spike_mask = np.ones_like(spike_indexes, dtype=bool)
        spike_mask[overlaps + 1] = False
        spike_indexes = spike_indexes[spike_mask]
        peak_mask = np.ones_like(peak_indexes, dtype=bool)
        peak_mask[overlaps] = False
        peak_indexes = peak_indexes[peak_mask]
        upstroke_mask = np.ones_like(upstroke_indexes, dtype=bool)
        upstroke_mask[overlaps] = False
        upstroke_indexes = upstroke_indexes[upstroke_mask]
    # Validate that peaks don't occur too long after the threshold
    # If they do, try to re-find threshold from the peak
    too_long_spikes = []
    for i, (spk, peak) in enumerate(zip(spike_indexes, peak_indexes)):
        if t[peak] - t[spk] >= max_interval:
            logging.info("Need to recalculate threshold-peak pair that exceeds maximum allowed interval ({:f} s)".format(max_interval))
            too_long_spikes.append(i)
    if too_long_spikes:
        if dvdt is None:
            dvdt = calculate_dvdt(v, t, filter)
        # Same target level as refine_threshold_indexes: a fraction of the
        # mean maximal upstroke
        avg_upstroke = dvdt[upstroke_indexes].mean()
        target = avg_upstroke * thresh_frac
        drop_spikes = []
        for i in too_long_spikes:
            # First guessing that threshold is wrong and peak is right:
            # search backward from the upstroke, but no further than
            # max_interval before the peak
            peak = peak_indexes[i]
            t_0 = find_time_index(t, t[peak] - max_interval)
            below_target = np.flatnonzero(dvdt[upstroke_indexes[i]:t_0:-1] <= target)
            if not below_target.size:
                # Now try to see if threshold was right but peak was wrong
                # Find the peak in a window twice the size of our allowed window
                spike = spike_indexes[i]
                t_0 = find_time_index(t, t[spike] + 2 * max_interval)
                new_peak = np.argmax(v[spike:t_0]) + spike
                # If that peak is okay (not outside the allowed window, not past the next spike)
                # then keep it
                if t[new_peak] - t[spike] < max_interval and \
                        (i == len(spike_indexes) - 1 or t[new_peak] < t[spike_indexes[i + 1]]):
                    peak_indexes[i] = new_peak
                else:
                    # Otherwise, log and get rid of the spike
                    logging.info("Could not redetermine threshold-peak pair - dropping that pair")
                    drop_spikes.append(i)
                    # raise FeatureError("Could not redetermine threshold")
            else:
                spike_indexes[i] = upstroke_indexes[i] - below_target[0]
        if drop_spikes:
            spike_indexes = np.delete(spike_indexes, drop_spikes)
            peak_indexes = np.delete(peak_indexes, drop_spikes)
            upstroke_indexes = np.delete(upstroke_indexes, drop_spikes)
    # Check that last spike was not cut off too early by end of stimulus
    # by checking that the membrane potential returned to at least the threshold
    # voltage - otherwise, mark it as clipped
    clipped = np.zeros_like(spike_indexes, dtype=bool)
    end_index = find_time_index(t, end)
    if len(spike_indexes) > 0 and not np.any(v[peak_indexes[-1]:end_index + 1] <= v[spike_indexes[-1]] + tol):
        logging.debug("Failed to return to threshold voltage + tolerance (%.2f) after last spike (min %.2f) - marking last spike as clipped", v[spike_indexes[-1]] + tol, v[peak_indexes[-1]:end_index + 1].min())
        clipped[-1] = True
    return spike_indexes, peak_indexes, upstroke_indexes, clipped
def check_threshold_w_peak(v, t, spike_indexes, peak_indexes, clipped):
    """Drop putative spikes whose threshold-to-peak interval is implausibly short.

    Noisy deflections can be detected with the threshold almost on top of the
    peak; a real spike needs more than ~0.2 ms between the two. This is an
    alternative to lowering the filter cutoff, which risks missing real spikes.

    Returns the filtered (spike_indexes, peak_indexes, clipped) arrays.
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([]), np.array([]), np.array([], dtype=bool)

    # Keep only pairs separated by more than roughly a fifth of a millisecond
    keep = [(t[pk] - t[thr]) > 0.0002
            for pk, thr in zip(peak_indexes, spike_indexes)]
    return spike_indexes[keep], peak_indexes[keep], clipped[keep]
def find_trough_indexes(v, t, spike_indexes, peak_indexes, clipped=None, end=None):
    """
    Find indexes of minimum voltage (trough) between spikes.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean array - True for spikes clipped by the window edge (optional)
    end : end of time window (optional, default t[-1])

    Returns
    -------
    trough_indexes : float numpy array of trough indexes; the last entry is
        NaN when the last spike is clipped by the end of the window
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([])

    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)
    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)

    trough_indexes = np.zeros_like(spike_indexes, dtype=float)

    # Troughs between consecutive spikes: minimum voltage from each peak to
    # the following spike's threshold
    for i, (pk, next_thr) in enumerate(zip(peak_indexes[:-1], spike_indexes[1:])):
        trough_indexes[i] = pk + v[pk:next_thr].argmin()

    if clipped[-1]:
        # If last spike is cut off by the end of the window, trough is undefined
        trough_indexes[-1] = np.nan
    else:
        trough_indexes[-1] = peak_indexes[-1] + v[peak_indexes[-1]:end_index].argmin()

    return trough_indexes
def check_trough_w_peak(spike_indexes, upstroke_indexes, peak_indexes, trough_indexes, clipped, filter = 10., dvdt = None):
    """Remove 'spikes' whose peak and trough fall on the same sample.

    Stimulus artifacts (e.g. at stimulus offset) can produce small bumps that
    get detected as spikes. Lowering the filter cutoff can help, but only at
    values low enough to also miss real spikes. For these bumps the detected
    peak and trough coincide, which is impossible for a real spike, so such
    events are dropped here.

    Returns the filtered (spike_indexes, upstroke_indexes, peak_indexes,
    trough_indexes, clipped) arrays.
    """
    if not spike_indexes.size or not peak_indexes.size:
        return (np.array([]), np.array([]), np.array([]), np.array([]),
                np.array([], dtype=bool))

    keep = [pk != tr for pk, tr in zip(peak_indexes, trough_indexes)]
    return (spike_indexes[keep], upstroke_indexes[keep], peak_indexes[keep],
            trough_indexes[keep], clipped[keep])
def find_downstroke_indexes(v, t, peak_indexes, trough_indexes, clipped=None, filter=10., dvdt=None):
    """Find indexes of maximum downstroke (most negative dV/dt) between peaks and troughs.

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes
    clipped : boolean array - False if spike not clipped by edge of window
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    downstroke_indexes : float numpy array of downstroke indexes (NaN for
        clipped spikes, which have no valid trough)

    Raises
    ------
    FeatureError : if there are more troughs than peaks
    """
    if not trough_indexes.size:
        return np.array([])
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)
    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype=bool)
    if len(peak_indexes) < len(trough_indexes):
        raise FeatureError("Cannot have more troughs than peaks")

    # Clipped spikes have no valid trough, so only search the unclipped ones
    valid_peaks = peak_indexes[~clipped].astype(int)
    valid_troughs = trough_indexes[~clipped].astype(int)

    downstroke_indexes = np.zeros_like(peak_indexes) * np.nan
    downstroke_indexes[~clipped] = [pk + np.argmin(dvdt[pk:tr])
                                    for pk, tr in zip(valid_peaks, valid_troughs)]
    return downstroke_indexes
def find_widths(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None):
    """Find widths at half-height for spikes.

    Widths are only returned when heights are defined

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes (may contain NaN)
    clipped : boolean array - True for spikes clipped by the window edge (optional)

    Returns
    -------
    widths : numpy array of spike widths in sec (NaN where undefined)
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([])
    if len(spike_indexes) < len(trough_indexes):
        raise FeatureError("Cannot have more troughs than spikes")
    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    # Only evaluate spikes that have a defined trough and are not clipped
    use_indexes = ~np.isnan(trough_indexes)
    use_indexes[clipped] = False

    # Height is trough-to-peak; the width level is halfway up from the trough
    heights = np.zeros_like(trough_indexes) * np.nan
    heights[use_indexes] = v[peak_indexes[use_indexes]] - v[trough_indexes[use_indexes].astype(int)]
    width_levels = np.zeros_like(trough_indexes) * np.nan
    width_levels[use_indexes] = heights[use_indexes] / 2. + v[trough_indexes[use_indexes].astype(int)]

    # Alternative level: halfway between threshold and peak
    thresh_to_peak_levels = np.zeros_like(trough_indexes) * np.nan
    thresh_to_peak_levels[use_indexes] = (v[peak_indexes[use_indexes]] - v[spike_indexes[use_indexes]]) / 2. + v[spike_indexes[use_indexes]]

    # Some spikes in burst may have deep trough but short height, so can't use same
    # definition for width
    width_levels[width_levels < v[spike_indexes]] = \
        thresh_to_peak_levels[width_levels < v[spike_indexes]]

    # Width start: walk backward from the peak toward the threshold until the
    # voltage drops to the width level (NaN if never crossed)
    width_starts = np.zeros_like(trough_indexes) * np.nan
    width_starts[use_indexes] = np.array([pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0] if
                                          np.flatnonzero(v[pk:spk:-1] <= wl).size > 0 else np.nan for pk, spk, wl
                                          in zip(peak_indexes[use_indexes], spike_indexes[use_indexes], width_levels[use_indexes])])
    # Width end: walk forward from the peak toward the trough
    width_ends = np.zeros_like(trough_indexes) * np.nan
    width_ends[use_indexes] = np.array([pk + np.flatnonzero(v[pk:tr] <= wl)[0] if
                                        np.flatnonzero(v[pk:tr] <= wl).size > 0 else np.nan for pk, tr, wl
                                        in zip(peak_indexes[use_indexes], trough_indexes[use_indexes].astype(int), width_levels[use_indexes])])

    # Width in seconds; NaN where either edge was not found
    missing_widths = np.isnan(width_starts) | np.isnan(width_ends)
    widths = np.zeros_like(width_starts, dtype=np.float64)
    widths[~missing_widths] = t[width_ends[~missing_widths].astype(int)] - \
        t[width_starts[~missing_widths].astype(int)]
    if any(missing_widths):
        widths[missing_widths] = np.nan
    return widths
def find_widths_wrt_threshold(v, t, spike_indexes, peak_indexes, trough_indexes, clipped=None):
    """Find widths at half-height for spikes but based this time on heights w.r.t. to the thresholds (not troughs)

    Trough_indexes are still necessary to find the index for which the AP reaches the half-height after the peak
    Widths are only returned when heights are defined

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes (they are actually the thresholds)
    peak_indexes : numpy array of spike peak indexes
    trough_indexes : numpy array of trough indexes (may contain NaN)
    clipped : boolean array - True for spikes clipped by the window edge (optional)

    Returns
    -------
    widths : numpy array of spike widths in sec (NaN where undefined)
    """
    if not spike_indexes.size or not peak_indexes.size:
        return np.array([])
    if len(spike_indexes) < len(trough_indexes):
        raise FeatureError("Cannot have more troughs than spikes")
    if clipped is None:
        clipped = np.zeros_like(spike_indexes, dtype=bool)

    # Only evaluate spikes that have a defined trough and are not clipped
    use_indexes = ~np.isnan(trough_indexes)
    use_indexes[clipped] = False

    # Height is threshold-to-peak; the width level is halfway up from threshold
    heights = np.zeros_like(trough_indexes) * np.nan
    heights[use_indexes] = v[peak_indexes[use_indexes]] - v[spike_indexes[use_indexes].astype(int)]
    width_levels = np.zeros_like(trough_indexes) * np.nan
    width_levels[use_indexes] = heights[use_indexes] / 2. + v[spike_indexes[use_indexes].astype(int)]

    # Width start: walk backward from the peak toward the threshold until the
    # voltage drops to the width level (NaN if never crossed)
    width_starts = np.zeros_like(trough_indexes) * np.nan
    width_starts[use_indexes] = np.array([pk - np.flatnonzero(v[pk:spk:-1] <= wl)[0] if
                                          np.flatnonzero(v[pk:spk:-1] <= wl).size > 0 else np.nan for pk, spk, wl
                                          in zip(peak_indexes[use_indexes], spike_indexes[use_indexes], width_levels[use_indexes])])
    # Width end: walk forward from the peak toward the trough
    width_ends = np.zeros_like(trough_indexes) * np.nan
    width_ends[use_indexes] = np.array([pk + np.flatnonzero(v[pk:tr] <= wl)[0] if
                                        np.flatnonzero(v[pk:tr] <= wl).size > 0 else np.nan for pk, tr, wl
                                        in zip(peak_indexes[use_indexes], trough_indexes[use_indexes].astype(int), width_levels[use_indexes])])

    # Width in seconds; NaN where either edge was not found
    missing_widths = np.isnan(width_starts) | np.isnan(width_ends)
    widths = np.zeros_like(width_starts, dtype=np.float64)
    widths[~missing_widths] = t[width_ends[~missing_widths].astype(int)] - \
        t[width_starts[~missing_widths].astype(int)]
    if any(missing_widths):
        widths[missing_widths] = np.nan
    return widths
def analyze_trough_details(v, t, spike_indexes, peak_indexes, clipped=None, end=None, filter=10.,
                           heavy_filter=1., term_frac=0.003, adp_thresh=0.5, tol=0.5,
                           flat_interval=0.002, adp_max_delta_t=0.01, adp_max_delta_v=10., dvdt=None):
    """Analyze trough to determine if an ADP exists and whether the reset is a 'detour' or 'direct'

    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    spike_indexes : numpy array of spike indexes
    peak_indexes : numpy array of spike peak indexes
    clipped : boolean array - True for spikes clipped by the window edge (optional)
    end : end of time window (optional)
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 10)
    heavy_filter : lower cutoff frequency for 4-pole low-pass Bessel filter in kHz (default 1)
    term_frac : fraction of the maximum downstroke at which the spike is considered terminated (default 0.003)
    adp_thresh : minimum dV/dt in V/s to exceed to be considered to have an ADP (optional, default 0.5)
    tol : tolerance for evaluating whether Vm drops appreciably further after end of spike (default 0.5 mV)
    flat_interval : if the trace is flat for this duration, stop looking for an ADP (default 0.002 s)
    adp_max_delta_t : max possible ADP delta t (default 0.01 s)
    adp_max_delta_v : max possible ADP delta v (default 10 mV)
    dvdt : pre-calculated time-derivative of voltage (optional)

    Returns
    -------
    output : sequence of four numpy arrays -
        isi_types : numpy array of isi reset types (direct or detour)
        fast_trough_indexes : numpy array of indexes at the start of the trough (i.e. end of the spike)
        adp_indexes : numpy array of adp indexes (np.nan if there was no ADP in that ISI)
        slow_trough_indexes : numpy array of indexes at the minimum of the slow phase of the trough
            (if there wasn't just a fast phase)
    clipped : updated clipped status array (spikes whose fast trough could not
        be identified are additionally marked clipped)
    """
    if end is None:
        end = t[-1]
    end_index = find_time_index(t, end)
    if clipped is None:
        clipped = np.zeros_like(peak_indexes, dtype = bool)
    # Can't evaluate for spikes that are clipped by the window
    orig_len = len(peak_indexes)
    valid_spike_indexes = spike_indexes[~clipped]
    valid_peak_indexes = peak_indexes[~clipped]
    if dvdt is None:
        dvdt = calculate_dvdt(v, t, filter)
    # More aggressively smoothed derivative used for ADP detection, to avoid
    # reacting to small transients
    dvdt_hvy = calculate_dvdt(v, t, heavy_filter)
    # Writing as for loop - see if I can vectorize any later
    fast_trough_indexes = []
    adp_indexes = []
    slow_trough_indexes = []
    isi_types = []
    update_clipped = []
    for peak, next_spk in zip(valid_peak_indexes, np.append(valid_spike_indexes[1:], end_index)):
        # Spike "termination": point after the maximum downstroke where dV/dt
        # recovers to term_frac of the downstroke value
        downstroke = dvdt[peak:next_spk].argmin() + peak
        target = term_frac * dvdt[downstroke]
        terminated_points = np.flatnonzero(dvdt[downstroke:next_spk] >= target)
        if terminated_points.size:
            terminated = terminated_points[0] + downstroke
            update_clipped.append(False)
        else:
            # No termination found before the next spike - treat as clipped
            logging.debug("Could not identify fast trough - marking spike as clipped")
            isi_types.append(np.nan)
            fast_trough_indexes.append(np.nan)
            adp_indexes.append(np.nan)
            slow_trough_indexes.append(np.nan)
            update_clipped.append(True)
            continue
        # Could there be an ADP?
        adp_index = np.nan
        dv_over_thresh = np.flatnonzero(dvdt_hvy[terminated:next_spk] >= adp_thresh)
        if dv_over_thresh.size:
            cross = dv_over_thresh[0] + terminated
            # only want to look for ADP before things get pretty flat
            # otherwise, could just pick up random transients long after the spike
            if t[cross] - t[terminated] < flat_interval:
                # Going back up fast, but could just be going into another spike
                # so need to check for a reversal (zero-crossing) in dV/dt
                zero_return_vals = np.flatnonzero(dvdt_hvy[cross:next_spk] <= 0)
                if zero_return_vals.size:
                    putative_adp_index = zero_return_vals[0] + cross
                    min_index = v[putative_adp_index:next_spk].argmin() + putative_adp_index
                    # Accept the ADP only if it is followed by an appreciable
                    # drop and stays within the allowed delta-V / delta-t bounds
                    if (v[putative_adp_index] - v[min_index] >= tol and
                            v[putative_adp_index] - v[terminated] <= adp_max_delta_v and
                            t[putative_adp_index] - t[terminated] <= adp_max_delta_t):
                        adp_index = putative_adp_index
                        slow_phase_min_index = min_index
                        isi_type = "detour"
        if np.isnan(adp_index):
            # No ADP: classify the reset by whether Vm keeps dropping after
            # the spike terminates
            v_term = v[terminated]
            min_index = v[terminated:next_spk].argmin() + terminated
            if v_term - v[min_index] >= tol:
                # dropped further after end of spike -> detour reset
                isi_type = "detour"
                slow_phase_min_index = min_index
            else:
                isi_type = "direct"
        isi_types.append(isi_type)
        fast_trough_indexes.append(terminated)
        adp_indexes.append(adp_index)
        if isi_type == "detour":
            slow_trough_indexes.append(slow_phase_min_index)
        else:
            slow_trough_indexes.append(np.nan)
    # If we had to kick some spikes out before, need to add nans at the end
    output = []
    output.append(np.array(isi_types))
    for d in (fast_trough_indexes, adp_indexes, slow_trough_indexes):
        output.append(np.array(d, dtype=float))
    if orig_len > len(isi_types):
        extra = np.zeros(orig_len - len(isi_types)) * np.nan
        output = tuple((np.append(o, extra) for o in output))
    # The ADP and slow trough for the last spike in a train are not reliably
    # calculated, and usually extreme when wrong, so we will NaN them out.
    #
    # Note that this will result in a 0 value when delta V or delta T is
    # calculated, which may not be strictly accurate to the trace, but the
    # magnitude of the difference will be less than in many of the erroneous
    # cases seen otherwise
    output[2][-1] = np.nan  # ADP
    output[3][-1] = np.nan  # slow trough
    clipped[~clipped] = update_clipped
    return output, clipped
def find_time_index(t, t_0):
    """Find the index value of a given time (t_0) in a time series (t)."""
    at_or_after = np.flatnonzero(t >= t_0)
    if at_or_after.size == 0:
        raise FeatureError("Could not find given time in time vector")
    return at_or_after[0]
def calculate_dvdt(v, t, filter=None):
    """Low-pass filters (if requested) and differentiates voltage by time.
    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default None)
    Returns
    -------
    dvdt : numpy array of time-derivative of voltage (V/s = mV/ms)
    """
    # Filtering only makes sense with a uniform sampling interval
    if has_fixed_dt(t) and filter:
        sample_interval = t[1] - t[0]
        sample_freq = 1. / sample_interval
        # Convert cutoff from kHz to Hz, then express as a fraction of Nyquist
        filt_coeff = (filter * 1e3) / (sample_freq / 2.)
        if filt_coeff < 0 or filt_coeff >= 1:
            raise ValueError("bessel coeff ({:f}) is outside of valid range [0,1); cannot filter sampling frequency {:.1f} kHz with cutoff frequency {:.1f} kHz.".format(filt_coeff, sample_freq / 1e3, filter))
        b, a = signal.bessel(4, filt_coeff, "low")
        smoothed_v = signal.filtfilt(b, a, v, axis=0)
        dv = np.diff(smoothed_v)
    else:
        dv = np.diff(v)
    dt = np.diff(t)
    dvdt = 1e-3 * dv / dt  # in V/s = mV/ms
    # Drop NaN entries (produced wherever dt == 0)
    return dvdt[~np.isnan(dvdt)]
def get_isis(t, spikes):
    """Find interspike intervals in sec between spikes (as indexes)."""
    if len(spikes) > 1:
        return np.diff(t[spikes])
    return np.array([])
def average_voltage(v, t, start=None, end=None):
    """Calculate average voltage between start and end.
    Parameters
    ----------
    v : numpy array of voltage time series in mV
    t : numpy array of times in seconds
    start : start of time window for spike detection (optional, default None)
    end : end of time window for spike detection (optional, default None)
    Returns
    -------
    v_avg : average voltage
    """
    # Default to the full extent of the trace when no window is given
    start = t[0] if start is None else start
    end = t[-1] if end is None else end
    i_start = find_time_index(t, start)
    i_end = find_time_index(t, end)
    return v[i_start:i_end].mean()
def adaptation_index(isis):
    """Calculate adaptation index of `isis`."""
    if len(isis):
        return norm_diff(isis)
    return np.nan
def latency(t, spikes, start):
    """Calculate time to the first spike."""
    if not len(spikes):
        return np.nan
    # With no explicit window start, measure from the beginning of the trace
    reference = t[0] if start is None else start
    return t[spikes[0]] - reference
def average_rate(t, spikes, start, end):
    """Calculate average firing rate during interval between `start` and `end`.
    Parameters
    ----------
    t : numpy array of times in seconds
    spikes : numpy array of spike indexes
    start : start of time window for spike detection
    end : end of time window for spike detection
    Returns
    -------
    avg_rate : average firing rate in spikes/sec
    """
    if start is None:
        start = t[0]
    if end is None:
        end = t[-1]
    # Count spikes whose times fall inside the (inclusive) window
    n_spikes = sum(1 for spk in spikes if start <= t[spk] <= end)
    return n_spikes / (end - start)
def norm_diff(a):
    """Calculate average of (a[i+1] - a[i]) / (a[i+1] + a[i]).

    Note: the computation is over consecutive pairs in forward order
    (later element minus earlier element), so a monotonically increasing
    series yields a positive result.

    Returns np.nan for arrays with fewer than two elements.
    """
    if len(a) <= 1:
        return np.nan
    a = a.astype(float)
    # If every consecutive pair sums to ~0, all ratios are undefined; define as 0.
    if np.allclose((a[1:] + a[:-1]), 0.):
        return 0.
    norm_diffs = (a[1:] - a[:-1]) / (a[1:] + a[:-1])
    # A 0/0 pair produces NaN; define "no change" pairs of zeros as 0.
    norm_diffs[(a[1:] == 0) & (a[:-1] == 0)] = 0.
    with warnings.catch_warnings():
        # nanmean emits a RuntimeWarning when all entries are NaN; suppress it
        warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
        avg = np.nanmean(norm_diffs)
    return avg
def norm_sq_diff(a):
    """Calculate average of (a[i+1] - a[i])^2 / (a[i] + a[i+1])^2."""
    if len(a) < 2:
        return np.nan
    vals = a.astype(float)
    diffs = vals[1:] - vals[:-1]
    sums = vals[1:] + vals[:-1]
    return (np.square(diffs) / np.square(sums)).mean()
def isi_adaptation(a):
    """Calculate average of (a[1:]/a[:-1])."""
    if len(a) < 2:
        return np.nan
    vals = a.astype(float)
    # Ratio of each interval to the previous one, averaged over the train
    return np.mean(vals[1:] / vals[:-1])
def ap_amp_adaptation(a):
    """Calculate average of (a[1:]/a[:-1])."""
    if len(a) < 2:
        return np.nan
    amps = a.astype(float)
    # Successive amplitude ratios, averaged across the spike train
    ratios = amps[1:] / amps[:-1]
    return ratios.mean()
def has_fixed_dt(t):
    """Check that all time intervals are identical."""
    intervals = np.diff(t)
    # Compare every interval against the first one (within float tolerance)
    return np.allclose(intervals, np.full_like(intervals, intervals[0]))
def fit_membrane_time_constant(v, t, start, end, min_rsme=1e-4):
"""Fit an exponential to estimate membrane time constant between start and end
Parameters
----------
v : numpy array of voltages in mV
t : numpy array of times in seconds
start : start of time window for exponential fit
end : end of time window for exponential fit
min_rsme: minimal acceptable root mean square error (default 1e-4)
Returns
-------
a, inv_tau, y0 : Coeffients of equation y0 + a * exp(-inv_tau * x)
returns np.nan for values if fit fails
"""
start_index = find_time_index(t, start)
end_index = find_time_index(t, end)
guess = (v[start_index] - v[end_index], 50., v[end_index])
t_window = (t[start_index:end_index] - t[start_index]).astype(np.float64)
v_window = v[start_index:end_index].astype(np.float64)
try:
popt, pcov = curve_fit(_exp_curve, t_window, v_window, p0=guess)
except RuntimeError:
logging.info("Curve fit for membrane time constant failed")
return np.nan, np.nan, np.nan
pred = _exp_curve(t_window, *popt)
#print('pred: ', pred)
#print('voltage: ', v_window)