Skip to content

Commit

Permalink
Merge branch 'megnet' of https://github.com/colehank/mne-icalabel int…
Browse files Browse the repository at this point in the history
…o megnet
  • Loading branch information
colehank committed Oct 18, 2024
2 parents 5b7dc9c + 34c2f31 commit 96ed02d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 29 deletions.
14 changes: 7 additions & 7 deletions mne_icalabel/megnet/features.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import io

import matplotlib.pyplot as plt
import mne # type: ignore
import mne # type: ignore
import numpy as np
from mne.io import BaseRaw # type: ignore
from mne.preprocessing import ICA # type: ignore
from mne.utils import warn # type: ignore
from _utils import cart2sph, pol2cart
from mne.io import BaseRaw # type: ignore
from mne.preprocessing import ICA # type: ignore
from mne.utils import warn # type: ignore
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate # type: ignore
from scipy.spatial import ConvexHull # type: ignore

from scipy import interpolate # type: ignore
from scipy.spatial import ConvexHull # type: ignore
from ._utils import cart2sph, pol2cart


Expand Down
35 changes: 13 additions & 22 deletions mne_icalabel/megnet/label_componets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def _chunk_predicting(
overlap_len=3750,
) -> NDArray:
"""Predict the labels for each component using
MEGnet's chunk volte algorithm."""
MEGnet's chunk volte algorithm.
"""
predction_vote = []

for comp_series, comp_map in zip(time_series, spatial_maps):
Expand All @@ -79,9 +80,7 @@ def _chunk_predicting(

chunk_votes = {start: 0 for start in start_times}
for t in range(time_len):
in_chunks = (
[start <= t < start + chunk_len for start in start_times]
)
in_chunks = [start <= t < start + chunk_len for start in start_times]
# how many chunks the time point is in
num_chunks = np.sum(in_chunks)
for start_time, is_in_chunk in zip(start_times, in_chunks):
Expand All @@ -91,22 +90,18 @@ def _chunk_predicting(
weighted_predictions = {}
for start_time in chunk_votes.keys():
onnx_inputs = {
session.get_inputs()[0].name:
np.expand_dims(comp_map, 0).astype(np.float32),
session.get_inputs()[1].name:
np.expand_dims(
np.expand_dims(
comp_series[start_time: start_time + chunk_len], 0
), -1).astype(np.float32),
session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype(
np.float32
),
session.get_inputs()[1].name: np.expand_dims(
np.expand_dims(comp_series[start_time : start_time + chunk_len], 0),
-1,
).astype(np.float32),
}
prediction = session.run(None, onnx_inputs)[0]
weighted_predictions[start_time] = (
prediction * chunk_votes[start_time]
)
weighted_predictions[start_time] = prediction * chunk_votes[start_time]

comp_prediction = np.stack(
list(weighted_predictions.values())
).mean(axis=0)
comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0)
comp_prediction /= comp_prediction.sum()
predction_vote.append(comp_prediction)

Expand All @@ -129,21 +124,17 @@ def _get_chunk_start(

if __name__ == "__main__":
import mne

from features import get_megnet_features

sample_dir = mne.datasets.sample.data_path()
sample_fname = sample_dir / "MEG" / "sample" / "sample_audvis_raw.fif"
raw = mne.io.read_raw_fif(sample_fname).pick("mag")
raw.resample(250)
raw.filter(1, 100)
ica = mne.preprocessing.ICA(
n_components=20, max_iter="auto", method="infomax"
)
ica = mne.preprocessing.ICA(n_components=20, max_iter="auto", method="infomax")
ica.fit(raw)

res = megnet_label_components(raw, ica)
print(res)
ica.exclude = [i for i, label in enumerate(res["labels"]) if label != "brain/other"]
ica.plot_components()

0 comments on commit 96ed02d

Please sign in to comment.