Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ulation into main
  • Loading branch information
timonmerk committed Aug 14, 2024
2 parents fbf12d3 + 88d15aa commit 10f2945
Show file tree
Hide file tree
Showing 37 changed files with 558 additions and 243 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ coverage.xml
*.py,cover
.hypothesis/
.pytest_cache/
/test_data

# Translations
*.mo
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_0_first_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def generate_random_walk(NUM_CHANNELS, TIME_DATA_SAMPLES):
line_noise=50,
)

features = stream.run(data)
features = stream.run(data, save_csv=True)

# %%
# Feature Analysis
Expand Down
13 changes: 6 additions & 7 deletions examples/plot_1_example_BIDS.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@
line_noise,
coord_list,
coord_names,
) = nm_IO.read_BIDS_data(
PATH_RUN=PATH_RUN
)
) = nm_IO.read_BIDS_data(PATH_RUN=PATH_RUN)

nm_channels = nm_define_nmchannels.set_channels(
ch_names=raw.ch_names,
Expand Down Expand Up @@ -94,9 +92,9 @@
settings.features.sharpwave_analysis = True
settings.features.coherence = True

settings.coherence.channels = [("LFP_RIGHT_0", "ECOG_RIGHT_0")]
settings.coherence.channels = [("LFP_RIGHT_0", "ECOG_RIGHT_0")]

settings.coherence.frequency_bands = ["high beta", "low gamma"]
settings.coherence.frequency_bands = ["high_beta", "low_gamma"]
settings.sharpwave_analysis_settings.estimator["mean"] = []
settings.sharpwave_analysis_settings.sharpwave_features.enable_all()
for sw_feature in settings.sharpwave_analysis_settings.sharpwave_features.list_all():
Expand All @@ -118,6 +116,7 @@
data=data,
out_path_root=PATH_OUT,
folder_name=RUN_NAME,
save_csv=True,
)

# %%
Expand Down Expand Up @@ -168,9 +167,9 @@
# %%
nm_plots.plot_corr_matrix(
feature=feature_reader.feature_arr.filter(regex="ECOG_RIGHT_0"),
ch_name="ECOG_RIGHT_0-avgref",
ch_name="ECOG_RIGHT_0_avgref",
feature_names=list(
feature_reader.feature_arr.filter(regex="ECOG_RIGHT_0-avgref").columns
feature_reader.feature_arr.filter(regex="ECOG_RIGHT_0_avgref").columns
),
feature_file=feature_reader.feature_file,
show_plot=True,
Expand Down
6 changes: 3 additions & 3 deletions examples/plot_3_example_sharpwave_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,16 +310,16 @@
verbose=True,
)

df_features = stream.run(data=data[:, :30000])
df_features = stream.run(data=data[:, :30000], save_csv=True)

# %%
# We can then plot two exemplary features, prominence and interval, and see that the movement amplitude can be clustered with those two features alone:

plt.figure(figsize=(5, 3), dpi=300)
print(df_features.columns)
plt.scatter(
df_features["ECOG_RIGHT_0-avgref_Sharpwave_Max_prominence_range_5_80"],
df_features["ECOG_RIGHT_5-avgref_Sharpwave_Mean_interval_range_5_80"],
df_features["ECOG_RIGHT_0_avgref_Sharpwave_Max_prominence_range_5_80"],
df_features["ECOG_RIGHT_5_avgref_Sharpwave_Mean_interval_range_5_80"],
c=df_features.MOV_RIGHT,
alpha=0.8,
s=30,
Expand Down
1 change: 1 addition & 0 deletions examples/plot_4_example_gridPointProjection.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
data=data[:, :int(sfreq*5)],
out_path_root=PATH_OUT,
folder_name=RUN_NAME,
save_csv=True,
)

# %%
Expand Down
62 changes: 43 additions & 19 deletions py_neuromodulation/nm_IO.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from py_neuromodulation import logger, PYNM_DIR

if TYPE_CHECKING:
from py_neuromodulation.nm_settings import NMSettings
from mne_bids import BIDSPath
from mne import io as mne_io

Expand Down Expand Up @@ -78,11 +77,12 @@ def read_BIDS_data(
coord_names,
)


def read_mne_data(
PATH_RUN: "_PathLike | BIDSPath",
line_noise: int = 50,
):
"""Read data in the mne.io.read_raw supported format.
"""Read data in the mne.io.read_raw supported format.
Parameters
----------
Expand Down Expand Up @@ -118,9 +118,10 @@ def read_mne_data(
logger.info(
f"Line noise is not available in the data, using value of {line_noise} Hz."
)

return raw_arr.get_data(), sfreq, ch_names, ch_types, bads


def get_coord_list(
raw: "mne_io.BaseRaw",
) -> tuple[list, list] | tuple[None, None]:
Expand Down Expand Up @@ -228,6 +229,7 @@ def read_plot_modules(
z_stn,
)


def write_csv(df, path_out):
"""
Function to save Pandas dataframes to disk as CSV using
Expand All @@ -239,40 +241,46 @@ def write_csv(df, path_out):

csv.write_csv(Table.from_pandas(df), path_out)


def save_nm_channels(
nmchannels: pd.DataFrame,
path_out: _PathLike,
folder_name: str = "",
out_dir: _PathLike,
prefix: str = "",
) -> None:
if folder_name:
path_out = PurePath(path_out, folder_name, folder_name + "_nm_channels.csv")
filename = f"{prefix}_nm_channels.csv" if prefix else "nm_channels.csv"
path_out = PurePath(out_dir, filename)
write_csv(nmchannels, path_out)
logger.info(f"nm_channels.csv saved to {path_out}")


def save_features(
df_features: pd.DataFrame,
path_out: _PathLike,
folder_name: str = "",
out_dir: _PathLike,
prefix: str = "",
) -> None:
if folder_name:
path_out = PurePath(path_out, folder_name, folder_name + "_FEATURES.csv")
write_csv(df_features, path_out)
logger.info(f"FEATURES.csv saved to {str(path_out)}")
filename = f"{prefix}_FEATURES.csv" if prefix else "_FEATURES.csv"
out_dir = PurePath(out_dir, filename)
write_csv(df_features, out_dir)
logger.info(f"FEATURES.csv saved to {str(out_dir)}")


def save_sidecar(sidecar: dict, path_out: _PathLike, folder_name: str = "") -> None:
save_general_dict(sidecar, path_out, "_SIDECAR.json", folder_name)
def save_sidecar(
sidecar: dict,
out_dir: _PathLike,
prefix: str = "",
) -> None:
save_general_dict(sidecar, out_dir, prefix, "_SIDECAR.json")


def save_general_dict(
dict_: dict,
path_out: _PathLike,
out_dir: _PathLike,
prefix: str = "",
str_add: str = "",
folder_name: str = "",
) -> None:
if folder_name:
path_out = PurePath(path_out, folder_name, folder_name + str_add)
# We should change this to a proper experiment name

path_out = PurePath(out_dir, f"{prefix}{str_add}")

with open(path_out, "w") as f:
json.dump(
Expand Down Expand Up @@ -387,3 +395,19 @@ def _todict(matobj) -> dict:
else:
dict[strg] = elem
return dict


def generate_unique_filename(path: _PathLike):
path = Path(path)

dir = path.parent
filename = path.stem
extension = path.suffix

counter = 1
while True:
new_filename = f"{filename}_{counter}{extension}"
new_file_path = dir / new_filename
if not new_file_path.exists():
return Path(new_file_path)
counter += 1
3 changes: 1 addition & 2 deletions py_neuromodulation/nm_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,6 @@ def plot_all_features(

@staticmethod
def get_performace_sub_strip(performance_sub: dict, plt_grid: bool = False):

ecog_strip_performance = []
ecog_coords_strip = []
cortex_grid = []
Expand Down Expand Up @@ -967,8 +966,8 @@ def set_score(
nm_IO.save_general_dict(
dict_=performance_dict,
path_out=PATH_OUT,
prefix=folder_name,
str_add=str_add,
folder_name=folder_name,
)
return performance_dict

Expand Down
17 changes: 11 additions & 6 deletions py_neuromodulation/nm_bispectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def test_range(cls, filter_range):
), f"second frequency range value needs to be higher than first one, got {filter_range}"
return filter_range

@field_validator("frequency_bands")
def fbands_spaces_to_underscores(cls, frequency_bands):
return [f.replace(" ", "_") for f in frequency_bands]


FEATURE_DICT: dict[str, Callable] = {
"mean": np.nanmean,
Expand Down Expand Up @@ -83,7 +87,7 @@ def __init__(
self.max_freq = max(
self.settings.f1s.frequency_high_hz, self.settings.f2s.frequency_high_hz
)

# self.freqs: np.ndarray = np.array([]) # In case we pre-computed this

def calc_feature(self, data: np.ndarray) -> dict:
Expand Down Expand Up @@ -115,16 +119,17 @@ def calc_feature(self, data: np.ndarray) -> dict:
sampling_freq=self.sfreq,
verbose=False,
)

waveshape.compute(
f1s=tuple(self.settings.f1s), # type: ignore
f2s=tuple(self.settings.f2s), # type: ignore
)

feature_results = {}
for ch_idx, ch_name in enumerate(self.ch_names):

bispectrum = waveshape._bicoherence[ch_idx] # Same as waveshape.results._data, skips a copy
bispectrum = waveshape._bicoherence[
ch_idx
] # Same as waveshape.results._data, skips a copy

for component in self.settings.components.get_enabled():
spectrum_ch = COMPONENT_DICT[component](bispectrum)
Expand All @@ -146,4 +151,4 @@ def calc_feature(self, data: np.ndarray) -> dict:
f"{ch_name}_Bispectrum_{component}_{bispectrum_feature}_whole_fband_range"
] = FEATURE_DICT[bispectrum_feature](spectrum_ch)

return feature_results
return feature_results
25 changes: 15 additions & 10 deletions py_neuromodulation/nm_bursts.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np

if np.__version__ >= "2.0.0":
from numpy.lib._function_base_impl import _quantile as np_quantile # type:ignore
else:
from numpy.lib.function_base import _quantile as np_quantile # type:ignore
from collections.abc import Sequence
from itertools import product

from pydantic import Field
from pydantic import Field, field_validator
from py_neuromodulation.nm_types import BoolSelector, NMBaseModel
from py_neuromodulation.nm_features import NMFeature

Expand All @@ -20,7 +21,6 @@


def get_label_pos(burst_labels, valid_labels):

max_label = np.max(burst_labels, axis=2).flatten()
min_label = np.min(
burst_labels, axis=2, initial=LARGE_NUM, where=burst_labels != 0
Expand Down Expand Up @@ -48,9 +48,13 @@ class BurstFeatures(BoolSelector):
class BurstSettings(NMBaseModel):
threshold: float = Field(default=75, ge=0, le=100)
time_duration_s: float = Field(default=30, ge=0)
frequency_bands: list[str] = ["low beta", "high beta", "low gamma"]
frequency_bands: list[str] = ["low_beta", "high_beta", "low_gamma"]
burst_features: BurstFeatures = BurstFeatures()

@field_validator("frequency_bands")
def fbands_spaces_to_underscores(cls, frequency_bands):
return [f.replace(" ", "_") for f in frequency_bands]


class Burst(NMFeature):
def __init__(
Expand Down Expand Up @@ -182,12 +186,11 @@ def calc_feature(self, data: np.ndarray) -> dict:
]

# Find (channel, band) coordinates for each valid label and get an array that maps each valid label to its channel/band
# Channel band coordinate is flattened to a 1D array of length (n_channels x n_fbands)
# Channel band coordinate is flattened to a 1D array of length (n_channels x n_fbands)
label_positions = get_label_pos(burst_labels, valid_labels)

# Now we're ready to calculate features


if "duration" in self.used_features or "burst_rate_per_s" in self.used_features:
# Handle division by zero using np.divide. Where num_bursts is 0, the result is 0
self.burst_duration_mean = (
Expand All @@ -202,7 +205,9 @@ def calc_feature(self, data: np.ndarray) -> dict:

if "duration" in self.used_features:
# First get burst length for each valid burst
burst_lengths = label_sum(bursts, burst_labels, index=valid_labels) / self.sfreq
burst_lengths = (
label_sum(bursts, burst_labels, index=valid_labels) / self.sfreq
)

# Now the max needs to be calculated per channel/band
# For that, loop over channels/bands, get the corresponding burst lengths, and get the max
Expand All @@ -221,7 +226,7 @@ def calc_feature(self, data: np.ndarray) -> dict:
if "amplitude" in self.used_features:
# Max amplitude is just the max of the filtered data where there is a burst
self.burst_amplitude_max = (filtered_data * bursts).max(axis=2)

# The mean is actually a mean of means, so we need the mean for each individual burst
label_means = label_mean(filtered_data, burst_labels, index=valid_labels)
# Now, loop over channels/bands, get the corresponding burst means, and calculate the mean of means
Expand Down Expand Up @@ -266,9 +271,9 @@ def store_duration(
def store_amplitude(
self, feature_results: dict, ch_i: int, ch: str, fb_i: int, fb: str
):
feature_results[f"{ch}_bursts_{fb}_amplitude_mean"] = (
self.burst_amplitude_mean[ch_i, fb_i]
)
feature_results[f"{ch}_bursts_{fb}_amplitude_mean"] = self.burst_amplitude_mean[
ch_i, fb_i
]
feature_results[f"{ch}_bursts_{fb}_amplitude_max"] = self.burst_amplitude_max[
ch_i, fb_i
]
Expand Down
Loading

0 comments on commit 10f2945

Please sign in to comment.