add MEGnet to make MNE-ICALabel work on MEG data #207

Open
wants to merge 37 commits into
base: main
Changes from 22 commits
Commits
37 commits
ec28e4f
add megnet
colehank Oct 18, 2024
6f272b1
add megnet
colehank Oct 18, 2024
b3433c8
double check
colehank Oct 18, 2024
34c2f31
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 18, 2024
5b7dc9c
bug fix
colehank Oct 18, 2024
96ed02d
Merge branch 'megnet' of https://github.com/colehank/mne-icalabel int…
colehank Oct 18, 2024
989cb40
topomaps plot modify & bug fix
colehank Oct 23, 2024
bc64aa2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
8af24b1
bug fix
colehank Oct 24, 2024
8f5e0e6
bug fix
colehank Oct 24, 2024
5aabe3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2024
bd7f8cc
bug fix
colehank Oct 24, 2024
59aedfb
bug fix
colehank Oct 24, 2024
067849c
Merge branch 'main' into megnet
colehank Oct 25, 2024
a0da5ee
bug fix
colehank Oct 25, 2024
143df13
:q!Merge branch 'megnet' of https://github.com/colehank/mne-icalabel …
colehank Oct 25, 2024
58a719a
more validation of raw obejct
colehank Oct 29, 2024
b89c864
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
8465017
bug fix
colehank Oct 29, 2024
a0e526d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
40b1074
bug fix
colehank Oct 29, 2024
49f39d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 29, 2024
4f2d43a
fix model path discovery and include assets in package
mscheltienne Nov 7, 2024
bbce3cc
improve docstrings
mscheltienne Nov 7, 2024
19e0260
simplify and test validation of line noise
mscheltienne Nov 7, 2024
c47d582
clean-up utils
mscheltienne Nov 7, 2024
686fda1
test scripts update
colehank Nov 14, 2024
c5897e1
Merge branch 'main' into megnet
colehank Nov 14, 2024
88d0619
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
7d0a7e2
Merge branch 'main' into megnet
colehank Nov 14, 2024
25824fb
bug fix
colehank Nov 14, 2024
036599b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 14, 2024
30cc94a
bug fix
colehank Nov 14, 2024
7922f2f
test file fixing
colehank Dec 27, 2024
ce6e82c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2024
4c672b2
add test fuc's doctring
colehank Dec 27, 2024
401740d
Merge branch 'main' into megnet
colehank Dec 27, 2024
44 changes: 44 additions & 0 deletions mne_icalabel/megnet/_utils.py
@@ -0,0 +1,44 @@
# %%
import numpy as np


# Conversion functions
def cart2sph(x, y, z):
    xy = np.sqrt(x * x + y * y)
    r = np.sqrt(x * x + y * y + z * z)
    theta = np.arctan2(y, x)
    phi = np.arctan2(z, xy)
    return r, theta, phi


def pol2cart(rho, phi):
    x = rho * np.cos(phi)
    y = rho * np.sin(phi)
    return x, y
Member

These functions are also defined for ICLabel, so I wonder if we can pull them out into a general utility module for geometry-related functions.

Member

pol2cart is indeed a duplicate, but cart2sph returns its elements in a different order, which is a bit annoying to change. We can leave code de-duplication for a future PR.
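
One way the de-duplication discussed above could look in a later PR is a thin adapter around a single shared implementation. The sketch below is illustrative only: it works on scalars with the standard `math` module instead of NumPy arrays, and `cart2sph_reordered` with its `(theta, phi, r)` return order is a hypothetical name and ordering, not the actual ICLabel helper.

```python
import math


def cart2sph(x, y, z):
    """Same computation and return order as the helper in this PR: (r, theta, phi)."""
    xy = math.hypot(x, y)
    r = math.sqrt(x * x + y * y + z * z)
    theta = math.atan2(y, x)  # azimuth in the x-y plane
    phi = math.atan2(z, xy)  # elevation above the x-y plane
    return r, theta, phi


def cart2sph_reordered(x, y, z):
    """Hypothetical adapter for callers that expect (theta, phi, r)."""
    r, theta, phi = cart2sph(x, y, z)
    return theta, phi, r
```

With an adapter like this, only one function carries the math, and each caller keeps the return order it already relies on.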



def make_head_outlines(sphere, pos, outlines, clip_origin):
Member

Can you write a docstring for the function?

    assert isinstance(sphere, np.ndarray)
    x, y, _, radius = sphere
    del sphere

    ll = np.linspace(0, 2 * np.pi, 101)
    head_x = np.cos(ll) * radius * 1.01 + x
    head_y = np.sin(ll) * radius * 1.01 + y
    dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
    dx, _ = dx.real, dx.imag

    outlines_dict = dict(head=(head_x, head_y))

    mask_scale = 1.0
    max_norm = np.linalg.norm(pos, axis=1).max()
    mask_scale = max(mask_scale, max_norm * 1.01 / radius)

    outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y)
    clip_radius = radius * mask_scale
    outlines_dict["clip_radius"] = (clip_radius,) * 2
    outlines_dict["clip_origin"] = clip_origin

    outlines = outlines_dict

    return outlines
Binary file added mne_icalabel/megnet/assets/network/megnet.onnx
Binary file not shown.
291 changes: 291 additions & 0 deletions mne_icalabel/megnet/features.py
@@ -0,0 +1,291 @@
import io

import matplotlib.pyplot as plt
import mne # type: ignore
import numpy as np
from mne.io import BaseRaw # type: ignore
from mne.preprocessing import ICA # type: ignore
from mne.utils import _validate_type, warn # type: ignore
from numpy.typing import NDArray
from PIL import Image
from scipy import interpolate # type: ignore
from scipy.spatial import ConvexHull # type: ignore

from ._utils import cart2sph, pol2cart


def get_megnet_features(raw: BaseRaw, ica: ICA):
    """Extract time series and topomaps for each ICA component.

    Most of the work goes into producing BrainStorm-like topomaps,
    which are the kind of topomaps MEGnet was trained on.

    Parameters
    ----------
    raw : BaseRaw
        The raw MEG data. The raw instance should be sampled at 250 Hz
        and last at least 60 seconds.
    ica : ICA
        The ICA object containing the independent components.

    Returns
    -------
    time_series : np.ndarray
        The time series for each ICA component.
    topomaps : np.ndarray
        The topomaps for each ICA component.

    """
    _validate_type(raw, BaseRaw, "raw")

    if not any(
        ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True)
    ):
        raise RuntimeError(
            "Could not find MEG channels in the provided "
            "Raw instance. The MEGnet model was fitted on "
            "MEG data and is not suited for other types of channels."
        )

    if raw.times[-1] < 60:
        raise RuntimeError(
            f"The provided raw instance has {raw.times[-1]} seconds. "
            "MEGnet was designed to classify features extracted from "
            "an MEG dataset at least 60 seconds long."
        )

    if _check_notch(raw, "mag"):
        raise RuntimeError(
            "Line noise detected at 50/60 Hz. "
            "MEGnet was trained on MEG data without line noise. "
            "Please remove line noise before using MEGnet "
            "(see the 'notch_filter()' method for Raw instances)."
        )

    if not np.isclose(raw.info["sfreq"], 250, atol=1e-1):
        warn(
            "The provided raw instance is not sampled at 250 Hz "
            f"(sfreq={raw.info['sfreq']} Hz). "
            "MEGnet was designed to classify features extracted from "
            "an MEG dataset sampled at 250 Hz "
            "(see the 'resample()' method for Raw instances). "
            "The classification performance might be negatively impacted."
        )

    if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100:
        warn(
            "The provided raw instance is not filtered between 1 and 100 Hz. "
            "MEGnet was designed to classify features extracted from an MEG dataset "
            "bandpass filtered between 1 and 100 Hz (see the 'filter()' method for "
            "Raw instances)."
        )

    if ica.method != "infomax":
        warn(
            f"The provided ICA instance was fitted with a '{ica.method}' algorithm. "
            "MEGnet was designed with infomax ICA decompositions. To use the "
            "infomax algorithm, create the 'mne.preprocessing.ICA' instance with "
            "the argument 'method=\"infomax\"'."
        )

    pos_new, outlines = _get_topomaps_data(ica)
    topomaps = _get_topomaps(ica, pos_new, outlines)
    time_series = ica.get_sources(raw)._data

    return time_series, topomaps


def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple):
    """Generate head outlines and mask positions for the topomap plot."""
    x, y, _, radius = sphere
    ll = np.linspace(0, 2 * np.pi, 101)
    head_x = np.cos(ll) * radius * 1.01 + x
    head_y = np.sin(ll) * radius * 1.01 + y

    mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
    clip_radius = radius * mask_scale

    outlines_dict = {
        "head": (head_x, head_y),
        "mask_pos": (mask_scale * head_x, mask_scale * head_y),
        "clip_radius": (clip_radius,) * 2,
        "clip_origin": clip_origin,
    }
    return outlines_dict


def _get_topomaps_data(ica: ICA):
    """Prepare 2D sensor positions and outlines for topomap plotting."""
    mags = mne.pick_types(ica.info, meg="mag")
    channel_info = ica.info["chs"]
    loc_3d = [channel_info[i]["loc"][0:3] for i in mags]
    channel_locations_3d = np.array(loc_3d)

    # Convert to spherical and then to 2D
    sph_coords = np.transpose(
        cart2sph(
            channel_locations_3d[:, 0],
            channel_locations_3d[:, 1],
            channel_locations_3d[:, 2],
        )
    )
    TH, PHI = sph_coords[:, 1], sph_coords[:, 2]
    newR = 1 - PHI / np.pi * 2
    channel_locations_2d = np.transpose(pol2cart(newR, TH))

    # Adjust coordinates with convex hull interpolation
    hull = ConvexHull(channel_locations_2d)
    border_indices = hull.vertices
    Dborder = 1 / newR[border_indices]

    funcTh = np.hstack(
        [
            TH[border_indices] - 2 * np.pi,
            TH[border_indices],
            TH[border_indices] + 2 * np.pi,
        ]
    )
    funcD = np.hstack((Dborder, Dborder, Dborder))
    interp_func = interpolate.interp1d(funcTh, funcD)
    D = interp_func(TH)

    adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))])
    Xnew, Ynew = pol2cart(adjusted_R, TH)
    pos_new = np.vstack((Xnew, Ynew)).T

    outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0))
    return pos_new, outlines
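
To make the first projection step in `_get_topomaps_data` concrete, here is a minimal scalar sketch of the mapping `newR = 1 - PHI / np.pi * 2`, written with the standard `math` module. The helper name `project_sensor` and the sample coordinates are made up for illustration, and the convex-hull rescaling that follows in the real function is deliberately left out.

```python
import math


def project_sensor(x, y, z):
    """Map a 3D sensor position to 2D, before the convex-hull rescaling step."""
    theta = math.atan2(y, x)  # azimuth
    phi = math.atan2(z, math.hypot(x, y))  # elevation
    new_r = 1 - phi / math.pi * 2  # 0 at the zenith, 1 on the z = 0 plane
    return new_r * math.cos(theta), new_r * math.sin(theta)


# Made-up positions: one sensor straight above the origin, one on the +x axis.
top = project_sensor(0.0, 0.0, 0.1)
side = project_sensor(0.1, 0.0, 0.0)
```

A sensor directly above the origin lands at the center of the disc, and a sensor in the z = 0 plane lands on the unit circle. Sensors below that plane get a radius above 1 at this stage, which is why the function later clips the adjusted radius to 1.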


def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict):
    """Generate topomap images for each ICA component."""
    topomaps = []
    data_picks = mne.pick_types(ica.info, meg="mag")
    components = ica.get_components()

    for comp in range(ica.n_components_):
        data = components[data_picks, comp]
        fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black")
        ax = fig.add_subplot(111)
        mnefig, _ = mne.viz.plot_topomap(
            data,
            pos_new,
            sensors=False,
            outlines=outlines,
            extrapolate="head",
            sphere=[0, 0, 0, 1],
            contours=0,
            res=120,
            axes=ax,
            show=False,
            cmap="bwr",
        )
        img_buf = io.BytesIO()
        mnefig.figure.savefig(
            img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0
        )
        img_buf.seek(0)
        rgba_image = Image.open(img_buf)
        rgb_image = rgba_image.convert("RGB")
        img_buf.close()
        plt.close(fig)

        topomaps.append(np.array(rgb_image))

    return np.array(topomaps)


def _line_noise_channel(
    raw: BaseRaw,
    picks: str,
    fline: float = 50.0,
    neighbor_width: float = 2.0,
    threshold_factor: float = 3,
    show: bool = False,
):
    """Detect line noise in MEG/EEG data.

    Parameters
    ----------
    raw : BaseRaw
        The raw MEG/EEG data.
    picks : str
        Channels to include in the analysis.
    fline : float, optional
        The base frequency of the line noise to detect. Default is 50.0 Hz.
    neighbor_width : float, optional
        Width of the frequency neighborhood around each harmonic
        for calculating background noise, in Hz. Default is 2.0 Hz.
    threshold_factor : float, optional
        Multiplicative factor for setting the detection threshold.
        The threshold is set as the mean of the neighboring frequencies plus
        `threshold_factor` times the standard deviation. Default is 3.
    show : bool, optional
        Whether to plot the PSD for channels affected by line noise.

    Returns
    -------
    line_noise_detected : list of dict
        One entry per (channel, harmonic) pair in which line noise was detected,
        with keys ``"channel"`` and ``"frequency"``. The list is empty if no
        line noise is detected.

    """
    psd = raw.compute_psd(picks=picks)
    freqs = psd.freqs
    psds = psd.get_data()
    ch_names = psd.ch_names

    # Compute Nyquist frequency and determine maximum harmonic
    nyquist_freq = raw.info["sfreq"] / 2.0
    max_harmonic = int(nyquist_freq // fline)

    # Generate list of harmonic frequencies based on the fundamental frequency
    line_freqs = np.arange(fline, fline * (max_harmonic + 1), fline)
    freq_res = freqs[1] - freqs[0]
    n_neighbors = int(np.ceil(neighbor_width / freq_res))

    line_noise_detected = []
    for ch_idx in range(psds.shape[0]):
        psd_ch = psds[ch_idx, :]  # PSD for the current channel
        for lf in line_freqs:
            # Find the frequency index closest to the current harmonic
            idx = np.argmin(np.abs(freqs - lf))
            # Get index range for neighboring frequencies,
            # excluding the harmonic frequency itself
            idx_range = np.arange(
                max(0, idx - n_neighbors), min(len(freqs), idx + n_neighbors + 1)
            )
            idx_neighbors = idx_range[idx_range != idx]
            # Calculate mean and standard deviation of neighboring frequencies
            neighbor_mean = np.mean(psd_ch[idx_neighbors])
            neighbor_std = np.std(psd_ch[idx_neighbors])

            threshold = neighbor_mean + threshold_factor * neighbor_std
            if psd_ch[idx] > threshold:
                line_noise_detected.append(
                    {"channel": ch_names[ch_idx], "frequency": lf}
                )

    if show and line_noise_detected:
        affected_channels = {item["channel"] for item in line_noise_detected}
        plt.figure(figsize=(12, 3))
        for ch_name in affected_channels:
            ch_idx = ch_names.index(ch_name)
            plt.semilogy(freqs, psds[ch_idx, :], label=ch_name)
        plt.axvline(fline, color="k", linestyle="--", lw=3, alpha=0.3)
        plt.xlabel("Frequency (Hz)")
        plt.ylabel("Power Spectral Density (PSD)")
        plt.legend()
        plt.show()

    return line_noise_detected
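
The detection rule in `_line_noise_channel` (flag a bin when it exceeds the neighbor mean plus `threshold_factor` times the neighbor standard deviation) can be tried on a toy spectrum without MNE. This is a simplified sketch, not the PR's code: `exceeds_neighbors` is a hypothetical helper, the spectrum values are made up, and `statistics.pstdev` stands in for `np.std`, which also computes the population standard deviation by default.

```python
from statistics import fmean, pstdev


def exceeds_neighbors(psd, idx, n_neighbors, threshold_factor=3.0):
    """Return True if psd[idx] exceeds mean + factor * std of its neighbors."""
    lo = max(0, idx - n_neighbors)
    hi = min(len(psd), idx + n_neighbors + 1)
    # Neighborhood around idx, excluding the candidate bin itself
    neighbors = [psd[i] for i in range(lo, hi) if i != idx]
    threshold = fmean(neighbors) + threshold_factor * pstdev(neighbors)
    return psd[idx] > threshold


# Toy spectrum (made-up values): mild fluctuations with one sharp peak at index 5.
toy_psd = [1.0, 1.2, 0.9, 1.1, 1.0, 9.0, 1.0, 0.8, 1.1, 1.0, 0.9]
peak_found = exceeds_neighbors(toy_psd, idx=5, n_neighbors=4)
flat_found = exceeds_neighbors(toy_psd, idx=2, n_neighbors=2)
```

Because the candidate bin is excluded from its own neighborhood, a sharp peak does not inflate the threshold it is compared against.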


def _check_notch(
    raw: BaseRaw, picks: str, neighbor_width: float = 2.0, threshold_factor: float = 3
) -> bool:
Member

I'm wondering why we don't just apply a notch filter for them at the line noise frequency and its harmonics.

Would this impact the results?

I guess I can imagine a user constructing the Raw object via RawArray, which wouldn't need them to specify the line noise frequency, so maybe this wouldn't work all the time.

Member

Elekta/MEGIN systems do store the line noise frequency, but I don't know if this is the case for other systems. I would keep it as it is now, with the responsibility to apply the notch filter left to the user.

    """Return True if line noise is found in raw."""
    check_result = False
    for fline in [50, 60]:
        if _line_noise_channel(raw, picks, fline, neighbor_width, threshold_factor):
            check_result = True
            break
    return check_result
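
Since the notch filtering is left to the user, a caller needs the list of frequencies to remove. The sketch below is one possible way to build it, under stated assumptions: `line_noise_harmonics` is a hypothetical helper, not part of this PR, and unlike the `np.arange` construction in `_line_noise_channel` it keeps only harmonics strictly below the Nyquist frequency. The MNE call is shown as a comment because it needs a real `Raw` instance.

```python
def line_noise_harmonics(sfreq, fline):
    """List fline and its harmonics that lie strictly below the Nyquist frequency."""
    nyquist = sfreq / 2.0
    harmonics = []
    k = 1
    while k * fline < nyquist:
        harmonics.append(k * fline)
        k += 1
    return harmonics


# For data resampled to 250 Hz, as MEGnet expects:
freqs_50 = line_noise_harmonics(250.0, 50.0)
freqs_60 = line_noise_harmonics(250.0, 60.0)

# Assuming `raw` is an mne.io.Raw instance recorded on a 50 Hz mains grid:
# raw.notch_filter(freqs=freqs_50)
```

The user still has to know which mains frequency applies to the recording, which matches the reviewers' conclusion that the library cannot reliably infer it for every system.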