-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add MEGnet to make MNE-ICALabel work on MEG data #207
base: main
Are you sure you want to change the base?
Changes from 22 commits
ec28e4f
6f272b1
b3433c8
34c2f31
5b7dc9c
96ed02d
989cb40
bc64aa2
8af24b1
8f5e0e6
5aabe3f
bd7f8cc
59aedfb
067849c
a0da5ee
143df13
58a719a
b89c864
8465017
a0e526d
40b1074
49f39d1
4f2d43a
bbce3cc
19e0260
c47d582
686fda1
c5897e1
88d0619
7d0a7e2
25824fb
036599b
30cc94a
7922f2f
ce6e82c
4c672b2
401740d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# %% | ||
import numpy as np | ||
|
||
|
||
# Conversion functions | ||
def cart2sph(x, y, z): | ||
xy = np.sqrt(x * x + y * y) | ||
r = np.sqrt(x * x + y * y + z * z) | ||
theta = np.arctan2(y, x) | ||
phi = np.arctan2(z, xy) | ||
return r, theta, phi | ||
|
||
|
||
def pol2cart(rho, phi): | ||
x = rho * np.cos(phi) | ||
y = rho * np.sin(phi) | ||
return x, y | ||
|
||
|
||
def make_head_outlines(sphere, pos, outlines, clip_origin): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you write a docstring for the function? |
||
assert isinstance(sphere, np.ndarray) | ||
x, y, _, radius = sphere | ||
del sphere | ||
|
||
ll = np.linspace(0, 2 * np.pi, 101) | ||
head_x = np.cos(ll) * radius * 1.01 + x | ||
head_y = np.sin(ll) * radius * 1.01 + y | ||
dx = np.exp(np.arccos(np.deg2rad(12)) * 1j) | ||
dx, _ = dx.real, dx.imag | ||
|
||
outlines_dict = dict(head=(head_x, head_y)) | ||
|
||
mask_scale = 1.0 | ||
max_norm = np.linalg.norm(pos, axis=1).max() | ||
mask_scale = max(mask_scale, max_norm * 1.01 / radius) | ||
|
||
outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y) | ||
clip_radius = radius * mask_scale | ||
outlines_dict["clip_radius"] = (clip_radius,) * 2 | ||
outlines_dict["clip_origin"] = clip_origin | ||
|
||
outlines = outlines_dict | ||
|
||
return outlines |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
import io | ||
|
||
import matplotlib.pyplot as plt | ||
import mne # type: ignore | ||
import numpy as np | ||
from mne.io import BaseRaw # type: ignore | ||
from mne.preprocessing import ICA # type: ignore | ||
from mne.utils import _validate_type, warn # type: ignore | ||
from numpy.typing import NDArray | ||
from PIL import Image | ||
from scipy import interpolate # type: ignore | ||
from scipy.spatial import ConvexHull # type: ignore | ||
|
||
from ._utils import cart2sph, pol2cart | ||
|
||
|
||
def get_megnet_features(raw: BaseRaw, ica: ICA): | ||
"""Extract time series and topomaps for each ICA component. | ||
|
||
the main work is focused on making BrainStorm-like topomaps | ||
which trained the MEGnet. | ||
|
||
Parameters | ||
---------- | ||
raw : BaseRaw | ||
The raw MEG data. The raw instance should have 250 Hz | ||
sampling frequency and more than 60 seconds. | ||
ica : ICA | ||
The ICA object containing the independent components. | ||
|
||
Returns | ||
------- | ||
time_series : np.ndarray | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The time series for each ICA component. | ||
topomaps : np.ndarray | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
The topomaps for each ICA component | ||
|
||
""" | ||
_validate_type(raw, BaseRaw, "raw") | ||
|
||
if not any( | ||
ch_type in ["mag", "grad"] for ch_type in raw.get_channel_types(unique=True) | ||
): | ||
raise RuntimeError( | ||
"Could not find MEG channels in the provided " | ||
"Raw instance. The MEGnet model was fitted on" | ||
"MEG data and is not suited for other types of channels." | ||
) | ||
|
||
if raw.times[-1] < 60: | ||
raise RuntimeError( | ||
f"The provided raw instance has {raw.times[-1]} seconds. " | ||
"MEGnet was designed to classify features extracted from " | ||
"an MEG datasetat least 60 seconds long. " | ||
) | ||
|
||
if _check_notch(raw, "mag"): | ||
raise RuntimeError( | ||
"Line noise detected in 50/60 Hz. " | ||
"MEGnet was trained on MEG data without line noise. " | ||
"Please remove line noise before using MEGnet." | ||
"see the 'notch_filter()' method for Raw instances." | ||
) | ||
|
||
if not np.isclose(raw.info["sfreq"], 250, atol=1e-1): | ||
warn( | ||
"The provided raw instance is not sampled at 250 Hz" | ||
f"(sfreq={raw.info['sfreq']} Hz). " | ||
"MEGnet was designed to classify features extracted from" | ||
"an MEG dataset sampled at 250 Hz" | ||
"(see the 'resample()' method for raw)." | ||
"The classification performance might be negatively impacted." | ||
) | ||
|
||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if raw.info["highpass"] != 1 or raw.info["lowpass"] != 100: | ||
warn( | ||
"The provided raw instance is not filtered between 1 and 100 Hz. " | ||
"MEGnet was designed to classify features extracted from an MEG dataset " | ||
"bandpass filtered between 1 and 100 Hz (see the 'filter()' method for Raw." | ||
) | ||
|
||
if ica.method != "infomax": | ||
warn( | ||
f"The provided ICA instance was fitted with a '{ica.method}' algorithm. " | ||
"MEGnet was designed with infomax ICA decompositions. To use the " | ||
"infomax algorithm, use the 'mne.preprocessing.ICA' instance with " | ||
"the arguments 'ICA(method='infomax')" | ||
) | ||
|
||
pos_new, outlines = _get_topomaps_data(ica) | ||
topomaps = _get_topomaps(ica, pos_new, outlines) | ||
time_series = ica.get_sources(raw)._data | ||
|
||
return time_series, topomaps | ||
|
||
|
||
def _make_head_outlines(sphere: NDArray, pos: NDArray, clip_origin: tuple): | ||
"""Generate head outlines and mask positions for the topomap plot.""" | ||
x, y, _, radius = sphere | ||
ll = np.linspace(0, 2 * np.pi, 101) | ||
head_x = np.cos(ll) * radius * 1.01 + x | ||
head_y = np.sin(ll) * radius * 1.01 + y | ||
|
||
mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius) | ||
clip_radius = radius * mask_scale | ||
|
||
outlines_dict = { | ||
"head": (head_x, head_y), | ||
"mask_pos": (mask_scale * head_x, mask_scale * head_y), | ||
"clip_radius": (clip_radius,) * 2, | ||
"clip_origin": clip_origin, | ||
} | ||
return outlines_dict | ||
|
||
|
||
def _get_topomaps_data(ica: ICA): | ||
"""Prepare 2D sensor positions and outlines for topomap plotting.""" | ||
mags = mne.pick_types(ica.info, meg="mag") | ||
channel_info = ica.info["chs"] | ||
loc_3d = [channel_info[i]["loc"][0:3] for i in mags] | ||
channel_locations_3d = np.array(loc_3d) | ||
|
||
# Convert to spherical and then to 2D | ||
sph_coords = np.transpose( | ||
cart2sph( | ||
channel_locations_3d[:, 0], | ||
channel_locations_3d[:, 1], | ||
channel_locations_3d[:, 2], | ||
) | ||
) | ||
TH, PHI = sph_coords[:, 1], sph_coords[:, 2] | ||
newR = 1 - PHI / np.pi * 2 | ||
channel_locations_2d = np.transpose(pol2cart(newR, TH)) | ||
|
||
# Adjust coordinates with convex hull interpolation | ||
hull = ConvexHull(channel_locations_2d) | ||
border_indices = hull.vertices | ||
Dborder = 1 / newR[border_indices] | ||
|
||
funcTh = np.hstack( | ||
[ | ||
TH[border_indices] - 2 * np.pi, | ||
TH[border_indices], | ||
TH[border_indices] + 2 * np.pi, | ||
] | ||
) | ||
funcD = np.hstack((Dborder, Dborder, Dborder)) | ||
interp_func = interpolate.interp1d(funcTh, funcD) | ||
D = interp_func(TH) | ||
|
||
adjusted_R = np.array([min(newR[i] * D[i], 1) for i in range(len(mags))]) | ||
Xnew, Ynew = pol2cart(adjusted_R, TH) | ||
pos_new = np.vstack((Xnew, Ynew)).T | ||
|
||
outlines = _make_head_outlines(np.array([0, 0, 0, 1]), pos_new, (0, 0)) | ||
return pos_new, outlines | ||
|
||
|
||
def _get_topomaps(ica: ICA, pos_new: NDArray, outlines: dict): | ||
"""Generate topomap images for each ICA component.""" | ||
topomaps = [] | ||
data_picks = mne.pick_types(ica.info, meg="mag") | ||
components = ica.get_components() | ||
|
||
for comp in range(ica.n_components_): | ||
data = components[data_picks, comp] | ||
fig = plt.figure(figsize=(1.3, 1.3), dpi=100, facecolor="black") | ||
ax = fig.add_subplot(111) | ||
mnefig, _ = mne.viz.plot_topomap( | ||
data, | ||
pos_new, | ||
sensors=False, | ||
outlines=outlines, | ||
extrapolate="head", | ||
sphere=[0, 0, 0, 1], | ||
contours=0, | ||
res=120, | ||
axes=ax, | ||
show=False, | ||
cmap="bwr", | ||
) | ||
img_buf = io.BytesIO() | ||
mnefig.figure.savefig( | ||
img_buf, format="png", dpi=120, bbox_inches="tight", pad_inches=0 | ||
) | ||
img_buf.seek(0) | ||
rgba_image = Image.open(img_buf) | ||
rgb_image = rgba_image.convert("RGB") | ||
img_buf.close() | ||
plt.close(fig) | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
topomaps.append(np.array(rgb_image)) | ||
|
||
return np.array(topomaps) | ||
|
||
|
||
def _line_noise_channel( | ||
raw: BaseRaw, | ||
picks: str, | ||
fline: float = 50.0, | ||
neighbor_width: float = 2.0, | ||
threshold_factor: float = 3, | ||
show: bool = False, | ||
): | ||
"""Detect line noise in MEG/EEG data. | ||
|
||
Parameters | ||
---------- | ||
raw : mne.io.Raw | ||
The raw MEG/EEG data. | ||
picks : str or list, optional | ||
Channels to include in the analysis. | ||
fline : float, optional | ||
The base frequency of the line noise to detect. | ||
neighbor_width : float, optional | ||
Width of the frequency neighborhood around each harmonic | ||
for calculating background noise, in Hz. Default is 2.0 Hz. | ||
threshold_factor : float, optional | ||
Multiplicative factor for setting the detection threshold. | ||
The threshold is set as the mean of the neighboring frequencies plus | ||
`threshold_factor` times the standard deviation. Default is 3. | ||
show : bool, optional | ||
Whether to plot the PSD for channels affected by line noise. | ||
|
||
Returns | ||
------- | ||
bool | ||
mscheltienne marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Returns True if line noise is detected in any channel, otherwise False. | ||
|
||
""" | ||
psd = raw.compute_psd(picks=picks) | ||
freqs = psd.freqs | ||
psds = psd.get_data() | ||
ch_names = psd.ch_names | ||
|
||
# Compute Nyquist frequency and determine maximum harmonic | ||
nyquist_freq = raw.info["sfreq"] / 2.0 | ||
max_harmonic = int(nyquist_freq // fline) | ||
|
||
# Generate list of harmonic frequencies based on the fundamental frequency | ||
line_freqs = np.arange(fline, fline * (max_harmonic + 1), fline) | ||
freq_res = freqs[1] - freqs[0] | ||
n_neighbors = int(np.ceil(neighbor_width / freq_res)) | ||
|
||
line_noise_detected = [] | ||
for ch_idx in range(psds.shape[0]): | ||
psd_ch = psds[ch_idx, :] # PSD for the current channel | ||
for lf in line_freqs: | ||
# Find the frequency index closest to the current harmonic | ||
idx = np.argmin(np.abs(freqs - lf)) | ||
# Get index range for neighboring frequencies, | ||
# excluding the harmonic frequency itself | ||
idx_range = np.arange( | ||
max(0, idx - n_neighbors), min(len(freqs), idx + n_neighbors + 1) | ||
) | ||
idx_neighbors = idx_range[idx_range != idx] | ||
# Calculate mean and standard deviation of neighboring frequencies | ||
neighbor_mean = np.mean(psd_ch[idx_neighbors]) | ||
neighbor_std = np.std(psd_ch[idx_neighbors]) | ||
|
||
threshold = neighbor_mean + threshold_factor * neighbor_std | ||
if psd_ch[idx] > threshold: | ||
line_noise_detected.append( | ||
{"channel": ch_names[ch_idx], "frequency": lf} | ||
) | ||
|
||
if show and line_noise_detected: | ||
affected_channels = set([item["channel"] for item in line_noise_detected]) | ||
plt.figure(figsize=(12, 3)) | ||
for ch_name in affected_channels: | ||
ch_idx = ch_names.index(ch_name) | ||
plt.semilogy(freqs, psds[ch_idx, :], label=ch_name) | ||
plt.axvline(fline, color="k", linestyle="--", lw=3, alpha=0.3) | ||
plt.xlabel("Frequency (Hz)") | ||
plt.ylabel("Power Spectral Density (PSD)") | ||
plt.legend() | ||
plt.show() | ||
|
||
return line_noise_detected | ||
|
||
|
||
def _check_notch( | ||
raw: BaseRaw, picks: str, neighbor_width: float = 2.0, threshold_factor: float = 3 | ||
) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm wondering why we don't just notch filter for them at the Would this impact the results? I guess I can imagine a user constructing the Raw object via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Elekta/MEGIN system do store the line noise frequency, but I don't know if this is the case for other systems. I would keep it is as it is now, with the responsibility to apply the notch filter left to the user. |
||
"""Return True if line noise find in raw.""" | ||
check_result = False | ||
for fline in [50, 60]: | ||
if _line_noise_channel(raw, picks, fline, neighbor_width, threshold_factor): | ||
check_result = True | ||
break | ||
return check_result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These functions are also defined for ICLabel, so I wonder if we can pull them out for a general utility functions related to geometry.
mne-icalabel/mne_icalabel/iclabel/_utils.py
Line 97 in b2fc448
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pol2cart
is indeed a duplicate,cart2sph
returns the element in a different order and it's a bit annoying to change the order. We can keep code de-duplication for a future PR.