# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2015-2016. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact [email protected] for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
import numpy as np
from pandas import DataFrame
import warnings
import logging
from collections import Counter
import ephys_features as ft
import six
# Constants for stimulus-specific analysis
RAMPS_START = 1.02
LONG_SQUARES_START = 1.02
LONG_SQUARES_END = 2.02
SHORT_SQUARES_WINDOW_START = 1.02
SHORT_SQUARES_WINDOW_END = 1.021
SHORT_SQUARE_TRIPLE_WINDOW_START = 2.02
SHORT_SQUARE_TRIPLE_WINDOW_END = 2.021
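# (These window boundaries appear to be in seconds, on the same time base as the t arrays
# passed to the extractor classes below.)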
class EphysSweepFeatureExtractor:
"""Feature calculation for a sweep (voltage and/or current time series)."""
def __init__(self, t=None, v=None, i=None, start=None, end=None, filter=10.,
dv_cutoff=20., max_interval=0.005, min_height=2., min_peak=-30.,
thresh_frac=0.05, baseline_interval=0.1, baseline_detect_thresh=0.3,
id=None):
"""Initialize SweepFeatures object.
Parameters
----------
t : ndarray of times (seconds)
v : ndarray of voltages (mV)
i : ndarray of currents (pA)
start : start of time window for feature analysis (optional)
end : end of time window for feature analysis (optional)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
baseline_interval: interval length for baseline voltage calculation (before start if start is defined, default 0.1)
baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
"""
self.id = id
self.t = t
self.v = v
self.i = i
self.start = start
self.end = end
self.filter = filter
self.dv_cutoff = dv_cutoff
self.max_interval = max_interval
self.min_height = min_height
self.min_peak = min_peak
self.thresh_frac = thresh_frac
self.baseline_interval = baseline_interval
self.baseline_detect_thresh = baseline_detect_thresh
self.stimulus_amplitude_calculator = None
self._sweep_features = {}
self._affected_by_clipping = []
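# A minimal usage sketch (assumption: t, v and i are 1-D NumPy arrays for a single
# current-clamp sweep, and the stimulus window runs from 1.02 s to 2.02 s):
#
#     sweep_ext = EphysSweepFeatureExtractor(t=t, v=v, i=i, start=1.02, end=2.02)
#     sweep_ext.process_spikes()
#     widths = sweep_ext.spike_feature("width")     # per-spike feature
#     rate = sweep_ext.sweep_feature("avg_rate")    # sweep-level feature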
def process_spikes(self):
"""Perform spike-related feature analysis"""
self._process_individual_spikes()
self._process_spike_related_features()
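# The spike analysis below follows the ephys_features (ft) pipeline: detect putative
# spikes -> find and filter peaks -> locate upstrokes and refine thresholds -> find
# troughs and downstrokes -> measure widths, then assemble everything into the
# per-spike DataFrame self._spikes_df.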
def _process_individual_spikes(self):
v = self.v
t = self.t
dvdt = ft.calculate_dvdt(v, t, self.filter)
# Basic features of spikes
putative_spikes = ft.detect_putative_spikes(v, t, self.start, self.end,
self.filter, self.dv_cutoff)
peaks = ft.find_peak_indexes(v, t, putative_spikes, self.end)
putative_spikes, peaks = ft.filter_putative_spikes(v, t, putative_spikes, peaks,
self.min_height, self.min_peak,
dvdt=dvdt, filter=self.filter)
if not putative_spikes.size:
# Save time if no spikes detected
self._spikes_df = DataFrame()
return
upstrokes = ft.find_upstroke_indexes(v, t, putative_spikes, peaks, self.filter, dvdt)
thresholds = ft.refine_threshold_indexes(v, t, upstrokes, filter = self.filter, dvdt = dvdt)
if self.start is not None and t[thresholds][0] < self.start:
# The first detected threshold falls before the start of stimulation;
# recompute the thresholds from the third derivative instead.
thresholds = ft.refine_threshold_indexes_based_on_third_derivative(v, t, peaks, upstrokes, filter=10., dvdt=dvdt)
thresholds, peaks, upstrokes, clipped = ft.check_thresholds_and_peaks(v, t, thresholds, peaks,
upstrokes, end = self.end, max_interval = self.max_interval,
filter = self.filter)
if not thresholds.size:
# Save time if no spikes detected
self._spikes_df = DataFrame()
return
# Spike list and thresholds have been refined - now find other features
upstrokes = ft.find_upstroke_indexes(v, t, thresholds, peaks, self.filter, dvdt)
troughs = ft.find_trough_indexes(v, t, thresholds, peaks, clipped, self.end)
# You can comment out the following lines if you do not want this check
thresholds, upstrokes, peaks, troughs, clipped = ft.check_trough_w_peak(thresholds, upstrokes, peaks, troughs, \
clipped, filter = 10., dvdt = None)
# The check may have removed all spikes, so return early to save time:
if not thresholds.size:
self._spikes_df = DataFrame()
return
downstrokes = ft.find_downstroke_indexes(v, t, peaks, troughs, clipped, self.filter, dvdt)
trough_details, clipped = ft.analyze_trough_details(v, t, thresholds, peaks, clipped, self.end,
self.filter, dvdt=dvdt)
widths = ft.find_widths_wrt_threshold(v, t, thresholds, peaks, trough_details[1], clipped)
base_clipped_list = []
# Points where we care about t, v, and i if available
vit_data_indexes = {
"threshold": thresholds,
"peak": peaks,
"trough": troughs,
}
base_clipped_list += ["trough"]
# Points where we care about t and dv/dt
dvdt_data_indexes = {
"upstroke": upstrokes,
"downstroke": downstrokes
}
base_clipped_list += ["downstroke"]
# Trough details
isi_types = trough_details[0]
trough_detail_indexes = dict(zip(["fast_trough", "adp", "slow_trough"], trough_details[1:]))
base_clipped_list += ["fast_trough", "adp", "slow_trough"]
# Redundant, but ensures that DataFrame has right number of rows
# Any better way to do it?
spikes_df = DataFrame(data=thresholds, columns=["threshold_index"])
spikes_df["clipped"] = clipped
for k, all_vals in six.iteritems(vit_data_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k + "_t"] = np.nan
spikes_df[k + "_v"] = np.nan
if len(vals) > 0:
spikes_df.loc[valid_ind, k + "_index"] = vals
spikes_df.loc[valid_ind, k + "_t"] = t[vals]
spikes_df.loc[valid_ind, k + "_v"] = v[vals]
if self.i is not None:
spikes_df[k + "_i"] = np.nan
if len(vals) > 0:
spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k + "_i",
]
for k, all_vals in six.iteritems(dvdt_data_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k] = np.nan
if len(vals) > 0:
spikes_df.loc[valid_ind, k + "_index"] = vals
spikes_df.loc[valid_ind, k + "_t"] = t[vals]
spikes_df.loc[valid_ind, k + "_v"] = v[vals]
spikes_df.loc[valid_ind, k] = dvdt[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k,
]
spikes_df["isi_type"] = isi_types
self._affected_by_clipping += ["isi_type"]
for k, all_vals in six.iteritems(trough_detail_indexes):
valid_ind = ~np.isnan(all_vals)
vals = all_vals[valid_ind].astype(int)
spikes_df[k + "_index"] = np.nan
spikes_df[k + "_t"] = np.nan
spikes_df[k + "_v"] = np.nan
if len(vals) > 0:
spikes_df.loc[valid_ind, k + "_index"] = vals
spikes_df.loc[valid_ind, k + "_t"] = t[vals]
spikes_df.loc[valid_ind, k + "_v"] = v[vals]
if self.i is not None:
spikes_df[k + "_i"] = np.nan
if len(vals) > 0:
spikes_df.loc[valid_ind, k + "_i"] = self.i[vals]
if k in base_clipped_list:
self._affected_by_clipping += [
k + "_index",
k + "_t",
k + "_v",
k + "_i",
]
spikes_df["width"] = widths
self._affected_by_clipping += ["width"]
spikes_df["upstroke_downstroke_ratio"] = spikes_df["upstroke"] / -spikes_df["downstroke"]
self._affected_by_clipping += ["upstroke_downstroke_ratio"]
self._spikes_df = spikes_df
def _process_spike_related_features(self):
t = self.t
if len(self._spikes_df) == 0:
self._sweep_features["avg_rate"] = 0
return
# Start recently added
peak_heights = None
if not self._spikes_df.empty:
peak_heights = self._spikes_df['peak_v'].values - self._spikes_df['threshold_v'].values
# End recently added
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(t, thresholds)
with warnings.catch_warnings():
# ignore mean of empty slice warnings here
warnings.filterwarnings("ignore", category=RuntimeWarning, module="numpy")
sweep_level_features = {
"adapt": ft.adaptation_index(isis),
"latency": ft.latency(t, thresholds, self.start),
"isi_cv": (isis.std() / isis.mean()) if len(isis) >= 1 else np.nan,
"mean_isi": isis.mean() if len(isis) > 0 else np.nan,
"median_isi": np.median(isis),
"first_isi": isis[0] if len(isis) >= 1 else np.nan,
# We need at least 3 spikes (i.e. 2 ISIs) to calculate the ISI adaptation (ratio of the second ISI to the first)
"isi_adapt": (isis[1]/isis[0]) if len(isis) >= 2 else np.nan,
# Start recently added
#"AP_amp_adapt": self._spikes_df['peak_height'][1]/self._spikes_df['peak_height'][0] if self._spikes_df.shape[1] >= 2 else np.nan,
"AP_amp_adapt": (peak_heights[1]/peak_heights[0]) if peak_heights.size >= 2 else np.nan,
#"AP_amp_change": ft.ap_amp_change(self._spikes_df['peak_height'].values) if self._spikes_df.shape.shape[1] >= 2 else np.nan,
"AP_amp_adapt_average": ft.ap_amp_adaptation(peak_heights) if peak_heights.size >= 2 else np.nan,
# End recently added
"AP_fano_factor": ((peak_heights.std()**2)/peak_heights.mean()) if peak_heights.size >=2 else np.nan,
"AP_cv": ((peak_heights.std())/peak_heights.mean()) if peak_heights.size >=2 else np.nan,
"isi_adapt_average": ft.isi_adaptation(isis) if len(isis) >= 2 else np.nan,
#"norm_sq_isis": ft.norm_sq_diff(isis) if len(isis) >= 2 else np.nan,
# In principle the Fano factor and CV could be set to 0 for n = 1 ISI, but we set them to
# NaN instead, since they are not informative with a single ISI
"fano_factor": ((isis.std()**2) / isis.mean()) if len(isis) > 1 else np.nan,
"cv": (isis.std() / isis.mean()) if len(isis) > 1 else np.nan,
"avg_rate": ft.average_rate(t, thresholds, self.start, self.end)
}
for k, v in six.iteritems(sweep_level_features):
self._sweep_features[k] = v
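# Illustrative example (hypothetical numbers): for isis = [0.1, 0.2, 0.4] seconds,
# mean = 0.233 s and std = 0.125 s, so cv = std/mean = 0.53,
# fano_factor = std**2/mean = 0.067, and isi_adapt = isis[1]/isis[0] = 2.0.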
def _process_pauses(self, cost_weight=1.0):
# Pauses are unusually long ISIs with a "detour reset" among delay resets
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
isi_types = self._spikes_df["isi_type"][:-1].values
return ft.detect_pauses(isis, isi_types, cost_weight)
def pause_metrics(self):
"""Estimate average number of pauses and average fraction of time spent in a pause
Attempts to detect pauses with a variety of conditions and averages results together.
Pauses that are consistently detected contribute more to estimates.
Returns
-------
avg_n_pauses : average number of pauses detected across conditions
avg_pause_frac : average fraction of interval (between start and end) spent in a pause
max_reliability : max fraction of times most reliable pause was detected given weights tested
n_max_rel_pauses : number of pauses detected with `max_reliability`
"""
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
weight = 1.0
pause_list = self._process_pauses(weight)
if len(pause_list) == 0:
return 0, 0.
n_pauses = len(pause_list)
pause_frac = isis[pause_list].sum()
pause_frac /= self.end - self.start
return n_pauses, pause_frac
def _process_bursts(self, tol=0.5, pause_cost=1.0):
thresholds = self._spikes_df["threshold_index"].values.astype(int)
isis = ft.get_isis(self.t, thresholds)
isi_types = self._spikes_df["isi_type"][:-1].values
fast_tr_v = self._spikes_df["fast_trough_v"].values
fast_tr_t = self._spikes_df["fast_trough_t"].values
slow_tr_v = self._spikes_df["slow_trough_v"].values
slow_tr_t = self._spikes_df["slow_trough_t"].values
thr_v = self._spikes_df["threshold_v"].values
bursts = ft.detect_bursts(isis, isi_types, fast_tr_v, fast_tr_t, slow_tr_v, slow_tr_t,
thr_v, tol, pause_cost)
return np.array(bursts)
def burst_metrics(self):
"""Find bursts and return max "burstiness" index (normalized max rate in burst vs out).
Returns
-------
max_burstiness_index : max "burstiness" index across detected bursts
num_bursts : number of bursts detected
"""
burst_info = self._process_bursts()
if burst_info.shape[0] > 0:
return burst_info[:, 0].max(), burst_info.shape[0]
else:
return 0., 0
def delay_metrics(self):
"""Calculates ratio of latency to dominant time constant of rise before spike
Returns
-------
delay_ratio : ratio of latency to tau (higher means more delay)
tau : dominant time constant of rise before spike
"""
if len(self._spikes_df) == 0:
logging.info("No spikes available for delay calculation")
return 0., 0.
start = self.start
spike_time = self._spikes_df["threshold_t"].values[0]
tau = ft.fit_prespike_time_constant(self.v, self.t, start, spike_time)
latency = spike_time - start
delay_ratio = latency / tau
return delay_ratio, tau
def _get_baseline_voltage(self):
v = self.v
t = self.t
filter_frequency = 1. # in kHz
# Look at baseline interval before start if start is defined
if self.start is not None:
return ft.average_voltage(v, t, self.start - self.baseline_interval, self.start)
# Otherwise try to find an interval where things are pretty flat
dv = ft.calculate_dvdt(v, t, filter_frequency)
non_flat_points = np.flatnonzero(np.abs(dv) >= self.baseline_detect_thresh)
flat_intervals = t[non_flat_points[1:]] - t[non_flat_points[:-1]]
long_flat_intervals = np.flatnonzero(flat_intervals >= self.baseline_interval)
if long_flat_intervals.size > 0:
interval_index = long_flat_intervals[0] + 1
baseline_end_time = t[non_flat_points[interval_index]]
return ft.average_voltage(v, t, baseline_end_time - self.baseline_interval,
baseline_end_time)
else:
logging.info("Could not find sufficiently flat interval for automatic baseline voltage", RuntimeWarning)
return np.nan
def voltage_deflection(self, deflect_type=None):
"""Measure deflection (min or max, between start and end if specified).
Parameters
----------
deflect_type : measure minimal ('min') or maximal ('max') voltage deflection
If not specified, it will check to see if the current (i) is positive or negative
between start and end, then choose 'max' or 'min', respectively
If the current is not defined, it will default to 'min'.
Returns
-------
deflect_v : voltage at the peak deflection
deflect_index : index of peak deflection
"""
deflect_dispatch = {
"min": np.argmin,
"max": np.argmax,
}
start = self.start
if not start:
start = 0
start_index = ft.find_time_index(self.t, start)
if self.end:
# Subtract 0.1 s because we do not expect a genuine trough that close to the end of the
# current stimulation; this ignores spurious voltage drops right at the stimulus end.
end = self.end - 0.1
else:
end = self.t[-1]
end_index = ft.find_time_index(self.t, end)
if deflect_type is None:
if self.i is not None:
halfway_index = ft.find_time_index(self.t, (end - start) / 2. + start)
if self.i[halfway_index] >= 0:
deflect_type = "max"
else:
deflect_type = "min"
else:
deflect_type = "min"
deflect_func = deflect_dispatch[deflect_type]
v_window = self.v[start_index:end_index]
deflect_index = deflect_func(v_window) + start_index
return self.v[deflect_index], deflect_index
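# Example: v_min, idx = sweep_ext.voltage_deflection("min") returns the most hyperpolarized
# voltage in the analysis window together with its index into self.v (sweep_ext as in the
# usage sketch above).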
def stimulus_amplitude(self):
""" """
if self.stimulus_amplitude_calculator is not None:
return self.stimulus_amplitude_calculator(self)
else:
return np.nan
def estimate_time_constant(self):
"""Calculate the membrane time constant by fitting the voltage response with a
single exponential.
Returns
-------
tau : membrane time constant in seconds
"""
# Assumes this is being done on a hyperpolarizing step
v_peak, peak_index = self.voltage_deflection("min")
v_baseline = self.sweep_feature("v_baseline")
if self.start:
start_index = ft.find_time_index(self.t, self.start)
else:
start_index = 0
frac = 0.1
search_result = np.flatnonzero(self.v[start_index:] <= frac * (v_peak - v_baseline) + v_baseline)
if not search_result.size:
raise ft.FeatureError("could not find interval for time constant estimate")
fit_start = self.t[search_result[0] + start_index]
fit_end = self.t[peak_index]
# Workaround for traces with an artefactual downward peak (one cell dipped to about -250 mV).
# You can delete this if-statement if you have normal traces; if the fit still fails, other
# hyperpolarisation sweeps can hopefully provide a tau estimate.
if self.v[peak_index] < -200:
print("A downward peak was observed going below -200 mV")
# Look for another local minimum closer to stimulus onset: from a couple of milliseconds
# after onset to roughly 50 ms before the downward peak. (The 50- and 1250-sample offsets
# are hard-coded and implicitly assume a fixed sampling rate.)
end_index = (start_index + 50) + np.argmin(self.v[start_index + 50 : peak_index - 1250])
fit_end = self.t[end_index]
fit_start = self.t[start_index + 50]
a, inv_tau, y0 = ft.fit_membrane_time_constant(self.v, self.t, fit_start, fit_end)
return 1. / inv_tau
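# Note (assumption about the ft helper): fit_membrane_time_constant is taken to fit a single
# exponential of the form v(t) ~ y0 + a * exp(-inv_tau * t) over [fit_start, fit_end], so the
# membrane time constant returned above is tau = 1 / inv_tau, in seconds.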
def estimate_time_constant_at_end(self):
"""Calculate the membrane time constant by fitting the voltage response with a single expontial at the end of a hyperpolarising
stimulus.
Returns
-------
tau : membrane time constant in seconds
"""
# Assumes this is being done on a hyperpolarizing step
v_peak, peak_index = self.voltage_deflection("min")
v_baseline = self.sweep_feature("v_baseline")
if self.end:
start_index = ft.find_time_index(self.t, self.end)
else:
start_index = ft.find_time_index(self.t, 0.7)
frac = 0.1
search_result = np.flatnonzero(self.v[start_index:] >= frac * (v_baseline - v_peak) + v_peak)
if not search_result.size:
raise ft.FeatureError("Could not find interval for time constant estimate")
fit_start = self.t[search_result[0] + start_index]
fit_end = self.t[-1]
b, inv_tau, A = ft.fit_membrane_time_constant_at_end(self.v, self.t, fit_start, fit_end)
return 1. / inv_tau
def estimate_sag(self, peak_width=0.005):
"""Calculate the sag in a hyperpolarizing voltage response.
Parameters
----------
peak_width : window width to get more robust peak estimate in sec (default 0.005)
Returns
-------
sag : fraction that membrane potential relaxes back to baseline
sag_ratio : ratio of the largest voltage deflection to the steady-state deflection (both measured from baseline)
"""
t = self.t
v = self.v
start = self.start
if not start:
start = 0
end = self.end # To calculate the steady state, not the peak deflection (see code below)
if not end:
end = self.t[-1]
v_peak, peak_index = self.voltage_deflection("min")
# Workaround for traces with an artefactual downward peak (one cell dipped to about -250 mV).
# You can delete this if-statement if you have normal traces; if the sag still cannot be
# estimated, other hyperpolarisation sweeps can hopefully be used instead.
if self.v[peak_index] < -200:
print("A downward peak was observed going below -200 mV")
# Look for another local minimum closer to stimulus onset. The artefact should only last a few
# milliseconds, so step back by the number of samples spanning 20 ms (from t = 0.1 s to 0.12 s).
peak_index = peak_index - (ft.find_time_index(t, 0.12) - ft.find_time_index(t, 0.1))
#print(t[peak_index])
v_peak_avg = ft.average_voltage(v, t, start=t[peak_index] - peak_width / 2.,
end=t[peak_index] + peak_width / 2.)
v_baseline = self.sweep_feature("v_baseline")
v_steady = ft.average_voltage(v, t, start=end - self.baseline_interval, end=end)
#print('v_stead: ', v_steady)
#print('v_baseline: ', v_baseline)
#print('v_peak_avg: ', v_peak_avg)
#print('denominater=v_stead-v_baseline: ', v_steady-v_baseline)
#print('numerator=v_peak_avg-v_baseline: ', v_peak_avg-v_baseline)
sag = (v_peak_avg - v_steady) / (v_peak_avg - v_baseline)
sag_ratio = (v_peak_avg - v_baseline)/(v_steady-v_baseline)
#print(sag_ratio)
return sag, sag_ratio
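# Worked example (hypothetical values): with v_baseline = -70 mV, v_peak_avg = -90 mV and
# v_steady = -85 mV, sag = (-90 - (-85)) / (-90 - (-70)) = 0.25 and
# sag_ratio = (-90 - (-70)) / (-85 - (-70)) = 1.33.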
def spikes(self):
"""Get all features for each spike as a list of records."""
return self._spikes_df.to_dict('records')
def spike_feature(self, key, include_clipped=False, force_exclude_clipped=False):
"""Get specified feature for every spike.
Parameters
----------
key : feature name
include_clipped: return values for every identified spike, even when clipping means they will be incorrect/undefined
Returns
-------
spike_feature_values : ndarray of features for each spike
"""
if not hasattr(self, "_spikes_df"):
raise AttributeError("EphysSweepFeatureExtractor instance attribute with spike information does not exist yet - have spikes been processed?")
if len(self._spikes_df) == 0:
return np.array([])
if key not in self._spikes_df.columns:
raise KeyError("requested feature '{:s}' not available".format(key))
values = self._spikes_df[key].values
if include_clipped and force_exclude_clipped:
raise ValueError("include_clipped and force_exclude_clipped cannot both be true")
if not include_clipped and self.is_spike_feature_affected_by_clipping(key):
values = values[~self._spikes_df["clipped"].values]
elif force_exclude_clipped:
values = values[~self._spikes_df["clipped"].values]
return values
def is_spike_feature_affected_by_clipping(self, key):
return key in self._affected_by_clipping
def spike_feature_keys(self):
"""Get list of every available spike feature."""
return self._spikes_df.columns.values.tolist()
def sweep_feature(self, key, allow_missing=False):
"""Get sweep-level feature (`key`).
Parameters
----------
key : name of sweep-level feature
allow_missing : return np.nan if key is missing for sweep (default False)
Returns
-------
sweep_feature : sweep-level feature value
"""
on_request_dispatch = {
"v_baseline": self._get_baseline_voltage,
"tau": self.estimate_time_constant,
"sag": self.estimate_sag,
"peak_deflect": self.voltage_deflection,
"stim_amp": self.stimulus_amplitude,
}
if allow_missing and key not in self._sweep_features and key not in on_request_dispatch:
return np.nan
elif key not in self._sweep_features and key not in on_request_dispatch:
raise KeyError("requested feature '{:s}' not available".format(key))
if key not in self._sweep_features and key in on_request_dispatch:
fn = on_request_dispatch[key]
if fn is not None:
self._sweep_features[key] = fn()
else:
raise KeyError("requested feature '{:s}' not defined".format(key))
return self._sweep_features[key]
def process_new_spike_feature(self, feature_name, feature_func, affected_by_clipping=False):
"""Add new spike-level feature calculation function
The function should take this sweep extractor as its argument. Its results
can be accessed by calling the method spike_feature(<feature_name>).
"""
if feature_name in self._spikes_df.columns:
raise KeyError("Feature {:s} already exists for sweep".format(feature_name))
self._spikes_df[feature_name] = feature_func(self)
if affected_by_clipping:
self._affected_by_clipping.append(feature_name)
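# Sketch of registering a custom spike-level feature (spike_height is a hypothetical helper;
# it relies only on columns created by process_spikes()):
#
#     def spike_height(swp):
#         return swp.spike_feature("peak_v") - swp.spike_feature("threshold_v")
#
#     sweep_ext.process_new_spike_feature("height", spike_height)
#     heights = sweep_ext.spike_feature("height")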
def process_new_sweep_feature(self, feature_name, feature_func):
"""Add new sweep-level feature calculation function
The function should take this sweep extractor as its argument. Its results
can be accessed by calling the method sweep_feature(<feature_name>).
"""
if feature_name in self._sweep_features:
raise KeyError("Feature {:s} already exists for sweep".format(feature_name))
self._sweep_features[feature_name] = feature_func(self)
def set_stimulus_amplitude_calculator(self, function):
self.stimulus_amplitude_calculator = function
def sweep_feature_keys(self):
"""Get list of every available sweep-level feature."""
return self._sweep_features.keys()
def as_dict(self):
"""Create dict of features and spikes."""
output_dict = self._sweep_features.copy()
output_dict["spikes"] = self.spikes()
if self.id is not None:
output_dict["id"] = self.id
return output_dict
class EphysSweepSetFeatureExtractor:
def __init__(self, t_set=None, v_set=None, i_set=None, start=None, end=None,
filter=10., dv_cutoff=20., max_interval=0.005, min_height=2.,
min_peak=-30., thresh_frac=0.05, baseline_interval=0.1,
baseline_detect_thresh=0.3, id_set=None):
"""Initialize EphysSweepSetFeatureExtractor object.
Parameters
----------
t_set : list of ndarray of times in seconds
v_set : list of ndarray of voltages in mV
i_set : list of ndarray of currents in pA
start : start of time window for feature analysis (optional, can be list)
end : end of time window for feature analysis (optional, can be list)
filter : cutoff frequency for 4-pole low-pass Bessel filter in kHz (optional, default 10)
dv_cutoff : minimum dV/dt to qualify as a spike in V/s (optional, default 20)
max_interval : maximum acceptable time between start of spike and time of peak in sec (optional, default 0.005)
min_height : minimum acceptable height from threshold to peak in mV (optional, default 2)
min_peak : minimum acceptable absolute peak level in mV (optional, default -30)
thresh_frac : fraction of average upstroke for threshold calculation (optional, default 0.05)
baseline_interval: interval length for baseline voltage calculation (before start if start is defined, default 0.1)
baseline_detect_thresh : dV/dt threshold for evaluating flatness of baseline region (optional, default 0.3)
"""
if t_set is not None and v_set is not None:
self._set_sweeps(t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval,
min_height, min_peak, thresh_frac, baseline_interval,
baseline_detect_thresh, id_set)
else:
self._sweeps = None
@classmethod
def from_sweeps(cls, sweep_list):
"""Initialize EphysSweepSetFeatureExtractor object with a list of pre-existing
sweep feature extractor objects.
"""
obj = cls()
obj._sweeps = sweep_list
return obj
def _set_sweeps(self, t_set, v_set, i_set, start, end, filter, dv_cutoff, max_interval,
min_height, min_peak, thresh_frac, baseline_interval,
baseline_detect_thresh, id_set):
if type(t_set) != list:
raise ValueError("t_set must be a list")
if type(v_set) != list:
raise ValueError("v_set must be a list")
if i_set is not None and type(i_set) != list:
raise ValueError("i_set must be a list")
if len(t_set) != len(v_set):
raise ValueError("t_set and v_set must have the same number of items")
if i_set and len(t_set) != len(i_set):
raise ValueError("t_set and i_set must have the same number of items")
if id_set is None:
id_set = range(len(t_set))
if len(id_set) != len(t_set):
raise ValueError("t_set and id_set must have the same number of items")
sweeps = []
if i_set is None:
i_set = [None] * len(t_set)
if type(start) is not list:
start = [start] * len(t_set)
end = [end] * len(t_set)
sweeps = [ EphysSweepFeatureExtractor(t, v, i, start, end,
filter=filter, dv_cutoff=dv_cutoff,
max_interval=max_interval,
min_height=min_height, min_peak=min_peak,
thresh_frac=thresh_frac,
baseline_interval=baseline_interval,
baseline_detect_thresh=baseline_detect_thresh,
id=sid) \
for t, v, i, start, end, sid in zip(t_set, v_set, i_set, start, end, id_set) ]
self._sweeps = sweeps
def sweeps(self):
"""Get list of EphysSweepFeatureExtractor objects."""
return self._sweeps
def process_spikes(self):
"""Analyze spike features for all sweeps."""
for sweep in self._sweeps:
sweep.process_spikes()
def sweep_features(self, key, allow_missing=False):
"""Get nparray of sweep-level feature (`key`) for all sweeps
Parameters
----------
key : name of sweep-level feature
allow_missing : return np.nan if key is missing for sweep (default False)
Returns
-------
sweep_feature : nparray of sweep-level feature values
"""
return np.array([swp.sweep_feature(key, allow_missing) for swp in self._sweeps])
def spike_feature_averages(self, key):
"""Get nparray of average spike-level feature (`key`) for all sweeps"""
return np.array([swp.spike_feature(key).mean() for swp in self._sweeps])
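# A minimal usage sketch for a whole stimulus set (assumption: t_list, v_list and i_list are
# lists of 1-D NumPy arrays, one entry per sweep):
#
#     sweep_set = EphysSweepSetFeatureExtractor(t_list, v_list, i_list,
#                                               start=LONG_SQUARES_START, end=LONG_SQUARES_END)
#     sweep_set.process_spikes()
#     rates = sweep_set.sweep_features("avg_rate")
#     mean_widths = sweep_set.spike_feature_averages("width")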
class EphysCellFeatureExtractor:
# Class constants for specific processing
SUBTHRESH_MAX_AMP = 0
SAG_TARGET = -100.
def __init__(self, ramps_ext, short_squares_ext, long_squares_ext, subthresh_min_amp=-100):
"""Initialize EphysCellFeatureExtractor object from EphysSweepSetExtractors for
ramp, short square, and long square sweeps.
Parameters
----------
ramps_ext : EphysSweepSetFeatureExtractor prepared with ramp sweeps
short_squares_ext : EphysSweepSetFeatureExtractor prepared with short square sweeps
long_squares_ext : EphysSweepSetFeatureExtractor prepared with long square sweeps
subthresh_min_amp : minimum stimulus amplitude of subthreshold sweeps used for analysis (default -100)
"""
self._ramps_ext = ramps_ext
self._short_squares_ext = short_squares_ext
self._long_squares_ext = long_squares_ext
self._subthresh_min_amp = subthresh_min_amp
self._features = {
"ramps": {},
"short_squares": {},
"long_squares": {},
}
self._spiking_long_squares_ext = None
self._subthreshold_long_squares_ext = None
self._subthreshold_membrane_property_ext = None
def process(self, keys=None):
"""Processes features. Can take a specific key (or set of keys) to do a subset of processing."""
dispatch = {
"ramps": self._analyze_ramps,
"short_squares": self._analyze_short_squares,
"long_squares": self._analyze_long_squares,
"long_squares_spiking": self._analyze_long_squares_spiking,
}
if keys is None:
keys = dispatch.keys()
if type(keys) is not list:
keys = [keys]
for k in [j for j in keys if j in dispatch.keys()]:
dispatch[k]()
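# A minimal usage sketch (assumption: ramps_ext, short_squares_ext and long_squares_ext are
# EphysSweepSetFeatureExtractor objects built from the corresponding stimulus sweeps):
#
#     cell_ext = EphysCellFeatureExtractor(ramps_ext, short_squares_ext, long_squares_ext)
#     cell_ext.process()                  # or, e.g., cell_ext.process(keys="long_squares")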
def _analyze_ramps(self):
ext = self._ramps_ext
ext.process_spikes()
self._all_ramps_ext = ext
# pull out the spiking sweeps
spiking_sweeps = [ sweep for sweep in self._ramps_ext.sweeps() if sweep.sweep_feature("avg_rate") > 0 ]
ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps)
self._ramps_ext = ext
self._features["ramps"]["spiking_sweeps"] = ext.sweeps()
def ramps_features(self, all=False):
if all:
return self._all_ramps_ext
else:
return self._ramps_ext
def _analyze_short_squares(self):
ext = self._short_squares_ext
ext.process_spikes()
# Need to count how many had spikes at each amplitude; find most; ties go to lower amplitude
spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0]
if len(spiking_sweeps) == 0:
raise ft.FeatureError("No spiking short square sweeps, cannot compute cell features.")
most_common = Counter(map(_short_step_stim_amp, spiking_sweeps)).most_common()
common_amp, common_count = most_common[0]
for c in most_common[1:]:
if c[1] < common_count:
break
if c[0] < common_amp:
common_amp = c[0]
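# Illustrative example (hypothetical counts): if most_common == [(50., 3), (40., 3), (20., 1)],
# the loop above keeps the count of 3 but lowers the amplitude from 50. to 40., so ties are
# resolved toward the lower stimulus amplitude.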
self._features["short_squares"]["stimulus_amplitude"] = common_amp
ext = EphysSweepSetFeatureExtractor.from_sweeps([sweep for sweep in spiking_sweeps if _short_step_stim_amp(sweep) == common_amp])
self._short_squares_ext = ext
self._features["short_squares"]["common_amp_sweeps"] = ext.sweeps()
for s in self._features["short_squares"]["common_amp_sweeps"]:
s.set_stimulus_amplitude_calculator(_short_step_stim_amp)
def short_squares_features(self):
return self._short_squares_ext
def _analyze_long_squares(self):
self._analyze_long_squares_spiking()
self._analyze_long_squares_subthreshold()
def _analyze_long_squares_spiking(self, force_reprocess=False):
if not force_reprocess and self._spiking_long_squares_ext:
return
ext = self._long_squares_ext
ext.process_spikes()
self._features["long_squares"]["sweeps"] = ext.sweeps()
for s in self._features["long_squares"]["sweeps"]:
s.set_stimulus_amplitude_calculator(_step_stim_amp)
spiking_indexes = np.flatnonzero(ext.sweep_features("avg_rate"))
if len(spiking_indexes) == 0:
raise ft.FeatureError("No spiking long square sweeps, cannot compute cell features.")
amps = ext.sweep_features("stim_amp")#self.long_squares_stim_amps()
min_index = np.argmin(amps[spiking_indexes])
rheobase_index = spiking_indexes[min_index]
rheobase_i = _step_stim_amp(ext.sweeps()[rheobase_index])
self._features["long_squares"]["rheobase_extractor_index"] = rheobase_index
self._features["long_squares"]["rheobase_i"] = rheobase_i
self._features["long_squares"]["rheobase_sweep"] = ext.sweeps()[rheobase_index]
spiking_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") > 0]
self._spiking_long_squares_ext = EphysSweepSetFeatureExtractor.from_sweeps(spiking_sweeps)
self._features["long_squares"]["spiking_sweeps"] = self._spiking_long_squares_ext.sweeps()
self._features["long_squares"]["fi_fit_slope"] = fit_fi_slope(self._spiking_long_squares_ext)
def _analyze_long_squares_subthreshold(self):
ext = self._long_squares_ext
subthresh_sweeps = [sweep for sweep in ext.sweeps() if sweep.sweep_feature("avg_rate") == 0]
subthresh_ext = EphysSweepSetFeatureExtractor.from_sweeps(subthresh_sweeps)
self._subthreshold_long_squares_ext = subthresh_ext
if len(subthresh_ext.sweeps()) == 0:
raise ft.FeatureError("No subthreshold long square sweeps, cannot evaluate cell features.")
peaks = subthresh_ext.sweep_features("peak_deflect")
sags = subthresh_ext.sweep_features("sag")
sag_eval_levels = np.array([sweep.voltage_deflection()[0] for sweep in subthresh_ext.sweeps()])
target_level = self.SAG_TARGET
closest_index = np.argmin(np.abs(sag_eval_levels - target_level))
self._features["long_squares"]["sag"] = sags[closest_index]
self._features["long_squares"]["vm_for_sag"] = sag_eval_levels[closest_index]
self._features["long_squares"]["subthreshold_sweeps"] = subthresh_ext.sweeps()
for s in self._features["long_squares"]["subthreshold_sweeps"]:
s.set_stimulus_amplitude_calculator(_step_stim_amp)
logging.debug("subthresh_sweeps: %d", len(subthresh_sweeps))
calc_subthresh_sweeps = [sweep for sweep in subthresh_sweeps if