From 34c2f31a8a6e688599e65b42c4e5d27798e138df Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 08:02:28 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mne_icalabel/megnet/features.py | 15 ++++++----- mne_icalabel/megnet/label_componets.py | 35 ++++++++++---------------- 2 files changed, 20 insertions(+), 30 deletions(-) diff --git a/mne_icalabel/megnet/features.py b/mne_icalabel/megnet/features.py index 62831e48c..608ab0b0e 100644 --- a/mne_icalabel/megnet/features.py +++ b/mne_icalabel/megnet/features.py @@ -1,17 +1,16 @@ import io import matplotlib.pyplot as plt -import mne # type: ignore +import mne # type: ignore import numpy as np -from mne.io import BaseRaw # type: ignore -from mne.preprocessing import ICA # type: ignore -from mne.utils import warn # type: ignore +from _utils import cart2sph, pol2cart +from mne.io import BaseRaw # type: ignore +from mne.preprocessing import ICA # type: ignore +from mne.utils import warn # type: ignore from numpy.typing import NDArray from PIL import Image -from scipy import interpolate # type: ignore -from scipy.spatial import ConvexHull # type: ignore - -from _utils import cart2sph, pol2cart +from scipy import interpolate # type: ignore +from scipy.spatial import ConvexHull # type: ignore def get_megnet_features(raw: BaseRaw, ica: ICA): diff --git a/mne_icalabel/megnet/label_componets.py b/mne_icalabel/megnet/label_componets.py index 3ca59f9f8..dd1257a7d 100644 --- a/mne_icalabel/megnet/label_componets.py +++ b/mne_icalabel/megnet/label_componets.py @@ -67,7 +67,8 @@ def _chunk_predicting( overlap_len=3750, ) -> NDArray: """Predict the labels for each component using - MEGnet's chunk volte algorithm.""" + MEGnet's chunk volte algorithm. + """ predction_vote = [] for comp_series, comp_map in zip(time_series, spatial_maps): @@ -79,9 +80,7 @@ def _chunk_predicting( chunk_votes = {start: 0 for start in start_times} for t in range(time_len): - in_chunks = ( - [start <= t < start + chunk_len for start in start_times] - ) + in_chunks = [start <= t < start + chunk_len for start in start_times] # how many chunks the time point is in num_chunks = np.sum(in_chunks) for start_time, is_in_chunk in zip(start_times, in_chunks): @@ -91,22 +90,18 @@ def _chunk_predicting( weighted_predictions = {} for start_time in chunk_votes.keys(): onnx_inputs = { - session.get_inputs()[0].name: - np.expand_dims(comp_map, 0).astype(np.float32), - session.get_inputs()[1].name: - np.expand_dims( - np.expand_dims( - comp_series[start_time: start_time + chunk_len], 0 - ), -1).astype(np.float32), + session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype( + np.float32 + ), + session.get_inputs()[1].name: np.expand_dims( + np.expand_dims(comp_series[start_time : start_time + chunk_len], 0), + -1, + ).astype(np.float32), } prediction = session.run(None, onnx_inputs)[0] - weighted_predictions[start_time] = ( - prediction * chunk_votes[start_time] - ) + weighted_predictions[start_time] = prediction * chunk_votes[start_time] - comp_prediction = np.stack( - list(weighted_predictions.values()) - ).mean(axis=0) + comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0) comp_prediction /= comp_prediction.sum() predction_vote.append(comp_prediction) @@ -129,7 +124,6 @@ def _get_chunk_start( if __name__ == "__main__": import mne - from features import get_megnet_features sample_dir = mne.datasets.sample.data_path() @@ -137,13 +131,10 @@ def _get_chunk_start( raw = mne.io.read_raw_fif(sample_fname).pick("mag") raw.resample(250) raw.filter(1, 100) - ica = mne.preprocessing.ICA( - n_components=20, max_iter="auto", method="infomax" - ) + ica = mne.preprocessing.ICA(n_components=20, max_iter="auto", method="infomax") ica.fit(raw) res = megnet_label_components(raw, ica) print(res) ica.exclude = [i for i, label in enumerate(res["labels"]) if label != "brain/other"] ica.plot_components() -