Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add MEGnet to make MNE-ICALabel work on MEG data #207

Open
wants to merge 37 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
ec28e4f
add megnet
colehank Oct 18, 2024
6f272b1
add megnet
colehank Oct 18, 2024
b3433c8
double check
colehank Oct 18, 2024
34c2f31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
5b7dc9c
bug fix
colehank Oct 18, 2024
96ed02d
Merge branch 'megnet' of https://github.com/colehank/mne-icalabel int…
colehank Oct 18, 2024
989cb40
topomaps plot modify & bug fix
colehank Oct 23, 2024
bc64aa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
8af24b1
bug fix
colehank Oct 24, 2024
8f5e0e6
bug fix
colehank Oct 24, 2024
5aabe3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
bd7f8cc
bug fix
colehank Oct 24, 2024
59aedfb
bug fix
colehank Oct 24, 2024
067849c
Merge branch 'main' into megnet
colehank Oct 25, 2024
a0da5ee
bug fix
colehank Oct 25, 2024
143df13
:q!Merge branch 'megnet' of https://github.com/colehank/mne-icalabel …
colehank Oct 25, 2024
58a719a
more validation of raw obejct
colehank Oct 29, 2024
b89c864
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
8465017
bug fix
colehank Oct 29, 2024
a0e526d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
40b1074
bug fix
colehank Oct 29, 2024
49f39d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
4f2d43a
fix model path discovery and include assets in package
mscheltienne Nov 7, 2024
bbce3cc
improve docstrings
mscheltienne Nov 7, 2024
19e0260
simplify and test validation of line noise
mscheltienne Nov 7, 2024
c47d582
clean-up utils
mscheltienne Nov 7, 2024
686fda1
test scripts update
colehank Nov 14, 2024
c5897e1
Merge branch 'main' into megnet
colehank Nov 14, 2024
88d0619
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
7d0a7e2
Merge branch 'main' into megnet
colehank Nov 14, 2024
25824fb
bug fix
colehank Nov 14, 2024
036599b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
30cc94a
bug fix
colehank Nov 14, 2024
7922f2f
test file fixing
colehank Dec 27, 2024
ce6e82c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2024
4c672b2
add test fuc's doctring
colehank Dec 27, 2024
401740d
Merge branch 'main' into megnet
colehank Dec 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mne_icalabel/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def pytest_configure(config):
ignore:Python 3\.14 will, by default, filter extracted tar.*:DeprecationWarning
# onnxruntime on windows runners
ignore:Unsupported Windows version.*:UserWarning
# Matplotlib deprecation issued in VSCode test debugger
ignore:.*interactive_bk.*:matplotlib._api.deprecation.MatplotlibDeprecationWarning
"""
for warning_line in warnings_lines.split("\n"):
warning_line = warning_line.strip()
Expand Down
Empty file added mne_icalabel/megnet/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions mne_icalabel/megnet/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np
from numpy.typing import NDArray


def _cart2sph(x, y, z):
xy = np.sqrt(x * x + y * y)
r = np.sqrt(x * x + y * y + z * z)
theta = np.arctan2(y, x)
phi = np.arctan2(z, xy)
return r, theta, phi


def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple) -> dict:
"""Generate head outlines for topomap plotting.

This is a modified version of mne.viz.topomap._make_head_outlines.
The difference between this function and the original one is that
head_x and head_y here are scaled by a factor of 1.01 to make topomap
fit the 120x120 pixel size.
Also, removed the ear and nose outlines for not needed in MEGnet.

Parameters
----------
sphere : NDArray
The sphere parameters (x, y, z, radius).
pos : NDArray
The 2D sensor positions.
clip_origin : tuple
The origin of the clipping circle.

Returns
-------
dict
Dictionary containing the head outlines and mask positions.

"""
x, y, _, radius = sphere
ll = np.linspace(0, 2 * np.pi, 101)
head_x = np.cos(ll) * radius * 1.01 + x
head_y = np.sin(ll) * radius * 1.01 + y

mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
clip_radius = radius * mask_scale

outlines_dict = {
"head": (head_x, head_y),
"mask_pos": (mask_scale * head_x, mask_scale * head_y),
"clip_radius": (clip_radius,) * 2,
"clip_origin": clip_origin,
}
return outlines_dict
Binary file added mne_icalabel/megnet/assets/megnet.onnx
Binary file not shown.
217 changes: 217 additions & 0 deletions mne_icalabel/megnet/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import io

import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.io import BaseRaw
from mne.preprocessing import ICA
from mne.utils import _validate_type, warn
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate
from scipy.spatial import ConvexHull

from mne_icalabel.iclabel._utils import _pol2cart

from ._utils import _cart2sph, _make_head_outlines


def get_megnet_features(raw: BaseRaw, ica: ICA):
"""Extract time series and topomaps for each ICA component.

MEGNet uses topomaps from BrainStorm exported as 120x120x3 RGB images.
Thus, we need to replicate the 'appearance'/'look' of a BrainStorm topomap.

Parameters
----------
raw : Raw
Raw MEG recording used to fit the ICA decomposition.
The raw instance should be bandpass filtered between
1 and 100 Hz and notch filtered at 50 or 60 Hz to
remove line noise, and downsampled to 250 Hz.
ica : ICA
ICA decomposition of the provided instance.
The ICA decomposition should use the infomax method.

Returns
-------
time_series : array of shape (n_components, n_samples)
The time series for each ICA component.
topomaps : array of shape (n_components, 120, 120, 3)
The topomap RGB images for each ICA component.
"""
_validate_type(raw, BaseRaw, "raw")
_validate_type(ica, ICA, "ica")
if not any(
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
):
raise RuntimeError(
"Could not find MEG channels in the provided Raw instance."
"The MEGnet model was fitted on MEG data and is not"
"suited for other types of channels."
)
if (n_samples := raw.get_data().shape[1]) < 15000:
raise RuntimeError(
f"The provided raw instance has {n_samples} points. "
"MEGnet was designed to classify features extracted "
"from an MEG dataset at least 60 seconds long @ 250 Hz,"
"corresponding to at least. 15 000 samples."
)
if not np.isclose(raw.info["sfreq"], 250, atol=1e-1):
warn(
"The provided raw instance is not sampled at 250 Hz "
f"(sfreq={raw.info['sfreq']} Hz). "
"MEGnet was designed to classify features extracted from"
"an MEG dataset sampled at 250 Hz "
"(see the 'resample()' method for Raw instances). "
"The classification performance might be negatively impacted."
)
if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100:
warn(
"The provided raw instance is not filtered between 1 and 100 Hz. "
"MEGnet was designed to classify features extracted from an MEG "
"dataset bandpass filtered between 1 and 100 Hz"
" (see the 'filter()' method for Raw instances)."
" The classification performance might be negatively impacted."
)
if _check_line_noise(raw):
warn(
"Line noise detected in 50/60 Hz. MEGnet was trained on"
"MEG data without line noise. Please remove line noise"
"before using MEGnet (see the 'notch_filter()' method"
"for Raw instances)."
)
if ica.method != "infomax":
warn(
f"The provided ICA instance was fitted with '{ica.method}'."
"MEGnet was designed with infomax method."
"To use the it, set mne.preprocessing.ICA instance with "
"the arguments ICA(method='infomax')."
)
if ica.n_components != 20:
warn(
f"The provided ICA instance has {ica.n_components} components. "
"MEGnet was designed with 20 components. "
"use mne.preprocessing.ICA instance with "
"the arguments ICA(n_components=20)."
)

mscheltienne marked this conversation as resolved.
Show resolved Hide resolved
pos_new, outlines = _get_topomaps_data(ica)
topomaps = _get_topomaps(ica, pos_new, outlines)
time_series = ica.get_sources(raw).get_data()
return time_series, topomaps


def _get_topomaps_data(ica: ICA):
"""Prepare 2D sensor positions and outlines for topomap plotting."""
mags = mne.pick_types(ica.info, meg="mag")
channel_info = ica.info["chs"]
loc_3d = [channel_info[i]["loc"][0:3] for i in mags]
channel_locations_3d = np.array(loc_3d)

# Convert to spherical and then to 2D
sph_coords = np.transpose(
_cart2sph(
channel_locations_3d[:, 0],
channel_locations_3d[:, 1],
channel_locations_3d[:, 2],
)
)
TH, PHI = sph_coords[:, 1], sph_coords[:, 2]
newR = 1 - PHI / np.pi * 2
channel_locations_2d = np.transpose(_pol2cart(TH, newR))

# Adjust coordinates with convex hull interpolation
hull = ConvexHull(channel_locations_2d)
border_indices = hull.vertices
Dborder = 1 / newR[border_indices]

funcTh = np.hstack(
[
TH[border_indices] - 2 * np.pi,
TH[border_indices],
TH[border_indices] + 2 * np.pi,
]
)
funcD = np.hstack((Dborder, Dborder, Dborder))
interp_func = interpolate.interp1d(funcTh, funcD)
D = interp_func(TH)

adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))])
Xnew, Ynew = _pol2cart(TH, adjusted_R)
pos_new = np.vstack((Xnew, Ynew)).T

outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
return pos_new, outlines


def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
"""Generate topomap images for each ICA component."""
topomaps = []
data_picks = mne.pick_types(ica.info, meg="mag")
components = ica.get_components()

for comp in range(ica.n_components_):
data = components[data_picks, comp]
fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
ax = fig.add_subplot(111)
mnefig, _ = mne.viz.plot_topomap(
data,
pos_new,
sensors=False,
outlines=outlines,
extrapolate="head",
sphere=[0, 0, 0, 1],
contours=0,
res=120,
axes=ax,
show=False,
cmap="bwr",
)
img_buf = io.BytesIO()
mnefig.figure.savefig(
img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0
)
img_buf.seek(0)
rgba_image = Image.open(img_buf)
rgb_image = rgba_image.convert("RGB")
img_buf.close()
plt.close(fig)
mscheltienne marked this conversation as resolved.
Show resolved Hide resolved

topomaps.append(np.array(rgb_image))

return np.array(topomaps)


def _check_line_noise(
raw: BaseRaw, *, neighbor_width: int = 4, threshold_factor: int = 10
) -> bool:
"""Check if line noise is present in the MEG/EEG data."""
# we don't know the line frequency
if raw.info.get("line_freq", None) is None:
return False
# validate the primary and first harmonic frequencies
nyquist_freq = raw.info["sfreq"] / 2.0
line_freqs = [raw.info["line_freq"], 2 * raw.info["line_freq"]]
if any(nyquist_freq < lf for lf in line_freqs):
# not raising because if we get here,
# it means that someone provided a raw with
# a sampling rate extremely low (100 Hz?) and (1)
# either they missed all of the previous warnings
# encountered or (2) they know what they are doing.
warn("The sampling rate raw.info['sfreq'] is too low" "to estimate line niose.")
return False
# compute the power spectrum and retrieve the frequencies of interest
spectrum = raw.compute_psd(picks="meg", exclude="bads")
data, freqs = spectrum.get_data(
fmin=raw.info["line_freq"] - neighbor_width,
fmax=raw.info["line_freq"] + neighbor_width,
return_freqs=True,
) # array of shape (n_good_channel, n_freqs)
idx = np.argmin(np.abs(freqs - raw.info["line_freq"]))
mask = np.ones(data.shape[1], dtype=bool)
mask[idx] = False
background_mean = np.mean(data[:, mask], axis=1)
background_std = np.std(data[:, mask], axis=1)
threshold = background_mean + threshold_factor * background_std
return np.any(data[:, idx] > threshold)
107 changes: 107 additions & 0 deletions mne_icalabel/megnet/label_components.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from importlib.resources import files

import numpy as np
import onnxruntime as ort
from mne.io import BaseRaw
from mne.preprocessing import ICA
from numpy.typing import NDArray

from .features import get_megnet_features

_MODEL_PATH: str = files("mne_icalabel.megnet") / "assets" / "megnet.onnx"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line causes issues.

See:

network_file = files("mne_icalabel.iclabel.network") / "assets" / "ICLabelNet.onnx"

Let us know if that doesn't work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think he modified anything from iclabel, that should not be failing.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes nvm. I misread the error.

Hmm it seems the file is somehow not found.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I found the issue! After I added the __init__.py file in the mne_icalabel/megnet directory, it was able to correctly find megnet.onnx. (but I remember that python(at least v3.8) don't require init.py to treat a folder as a package.)



def megnet_label_components(raw: BaseRaw, ica: ICA) -> NDArray:
"""Label the provided ICA components with the MEGnet neural network.

Parameters
----------
raw : Raw
Raw MEG recording used to fit the ICA decomposition.
The raw instance should be bandpass filtered between 1 and 100 Hz
and notch filtered at 50 or 60 Hz to remove line noise,
and downsampled to 250 Hz.
ica : ICA
ICA decomposition of the provided instance.
The ICA decomposition should use the infomax method.

Returns
-------
labels_pred_proba : numpy.ndarray of shape (n_components, n_classes)
The estimated corresponding predicted probabilities of output classes
for each independent component. Columns are ordered with
'brain/other', 'eye movement', 'heart beat', 'eye blink',

"""
time_series, topomaps = get_megnet_features(raw, ica)

# sanity-checks
# number of time-series <-> topos
assert time_series.shape[0] == topomaps.shape[0]
# topos are images of shape 120x120x3
assert topomaps.shape[1:] == (120, 120, 3)
# minimum time-series length
assert 15000 <= time_series.shape[1]

session = ort.InferenceSession(_MODEL_PATH)
labels_pred_proba = _chunk_predicting(session, time_series, topomaps)
return labels_pred_proba[:, 0, :]


def _chunk_predicting(
session: ort.InferenceSession,
time_series: NDArray,
spatial_maps: NDArray,
chunk_len=15000,
overlap_len=3750,
) -> NDArray:
"""MEGnet's chunk volte algorithm."""
predction_vote = []

for comp_series, comp_map in zip(time_series, spatial_maps):
time_len = comp_series.shape[0]
start_times = _get_chunk_start(time_len, chunk_len, overlap_len)

if start_times[-1] + chunk_len <= time_len:
start_times.append(time_len - chunk_len)

chunk_votes = {start: 0 for start in start_times}
for t in range(time_len):
in_chunks = [start <= t < start + chunk_len for start in start_times]
# how many chunks the time point is in
num_chunks = np.sum(in_chunks)
for start_time, is_in_chunk in zip(start_times, in_chunks):
if is_in_chunk:
chunk_votes[start_time] += 1.0 / num_chunks

weighted_predictions = {}
for start_time in chunk_votes.keys():
onnx_inputs = {
session.get_inputs()[0].name: np.expand_dims(comp_map, 0).astype(
np.float32
),
session.get_inputs()[1].name: np.expand_dims(
np.expand_dims(comp_series[start_time : start_time + chunk_len], 0),
-1,
).astype(np.float32),
}
prediction = session.run(None, onnx_inputs)[0]
weighted_predictions[start_time] = prediction * chunk_votes[start_time]

comp_prediction = np.stack(list(weighted_predictions.values())).mean(axis=0)
comp_prediction /= comp_prediction.sum()
predction_vote.append(comp_prediction)

return np.stack(predction_vote)


def _get_chunk_start(
input_len: int, chunk_len: int = 15000, overlap_len: int = 3750
) -> list:
"""Calculate start times for time series chunks with overlap."""
start_times = []
start_time = 0
while start_time + chunk_len <= input_len:
start_times.append(start_time)
start_time += chunk_len - overlap_len
return start_times
Empty file.
Loading