Spaces:
Sleeping
Sleeping
File size: 1,463 Bytes
b11ec91 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | # preprocessing.py
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import mne
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable settings consumed by the preprocessing functions in this module.

    Frozen, so instances are hashable and cannot be mutated after creation.

    Attributes:
        fs: Sampling rate, forwarded to ``mne.create_info`` as ``sfreq``.
        f_low: Lower pass-band edge, first frequency argument of ``Raw.filter``.
        f_high: Upper pass-band edge, second frequency argument of ``Raw.filter``.
    """

    fs: float  # sampling rate (MNE ``sfreq``; presumably Hz)
    f_low: float  # lower band edge handed to Raw.filter
    f_high: float  # upper band edge handed to Raw.filter
def to_time_channel(x: np.ndarray, *, max_channels: int = 256) -> np.ndarray:
    """Return *x* laid out as a 2-D ``(time, channel)`` array.

    * 1-D input is treated as a single channel and gets a trailing
      singleton channel axis, giving shape ``(T, 1)``.
    * 2-D input is assumed to be ``(time, channel)`` already, unless it looks
      channel-major: when the first axis has at most ``max_channels`` entries
      and the second axis is longer, the array is transposed.

    The 2-D rule is a shape heuristic only. A genuinely short recording with
    more channels than samples will also be transposed.

    Args:
        x: Signal as an ndarray or any array-like (e.g. nested lists).
        max_channels: Largest first-axis length for which the
            ``(channel, time)`` heuristic is applied. Defaults to 256, the
            previously hard-coded value.

    Returns:
        A 2-D array, possibly a view of ``np.asarray(x)``.

    Raises:
        ValueError: If *x* is neither 1-D nor 2-D.
    """
    # Accept array-likes; previously a list failed with AttributeError on .ndim.
    x = np.asarray(x)
    if x.ndim == 1:
        return x[:, None]
    if x.ndim != 2:
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    n_rows, n_cols = x.shape
    # Few rows and many columns most likely means (channel, time): flip it.
    if n_rows <= max_channels and n_cols > n_rows:
        x = x.T
    return x
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter a ``(time, channel)`` array using MNE.

    The signal is transposed to ``(channel, time)`` and wrapped in an
    ``mne.io.RawArray``:

    * channels are named ``ch0``, ``ch1``, ... and typed ``"eeg"``;
    * the sampling rate is ``cfg.fs``.

    A copy of that wrapper is then filtered between ``cfg.f_low`` and
    ``cfg.f_high`` with MNE's default filter settings. The result is
    transposed back to ``(time, channel)``.

    Args:
        x_tc: 2-D signal laid out as ``(time, channel)``.
        cfg: Sampling rate and pass-band edges.

    Returns:
        The filtered signal, shape ``(time, channel)``.
    """
    channel_names = [f"ch{idx}" for idx in range(x_tc.shape[1])]
    info = mne.create_info(ch_names=channel_names, sfreq=cfg.fs, ch_types="eeg")
    wrapped = mne.io.RawArray(x_tc.T, info, verbose=False)
    # NOTE(review): keep the copy. Per the MNE docs, Raw.filter modifies its
    # instance in place, and RawArray may share memory with float64 input.
    # Filtering `wrapped` directly could therefore overwrite the caller's x_tc.
    filtered = wrapped.copy().filter(cfg.f_low, cfg.f_high, verbose=False)
    channel_major = filtered.get_data()
    return channel_major.T
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the amplitude envelope of *x_tc* along axis 0 (time).

    The analytic signal is built in the frequency domain, the same
    construction ``scipy.signal.hilbert`` uses:

    1. Take the FFT along axis 0.
    2. Keep the DC bin (and, for even lengths, the Nyquist bin) unchanged.
    3. Double the positive-frequency bins and zero the negative-frequency bins.
    4. Take the magnitude of the inverse FFT.

    Args:
        x_tc: Signal with time on axis 0, usually ``(time, channel)``.
            1-D ``(time,)`` and higher-rank inputs are also handled; every
            trailing position is processed independently.

    Returns:
        float32 array with the same shape as *x_tc*.
    """
    spectrum = np.fft.fft(x_tc, axis=0)
    n = spectrum.shape[0]
    weights = np.zeros(n)
    weights[0] = 1
    if n % 2 == 0:
        weights[n // 2] = 1
        weights[1:n // 2] = 2
    else:
        weights[1:(n + 1) // 2] = 2
    # Shape the weights as (n, 1, 1, ...) so they broadcast along axis 0 for
    # any rank. The old h[:, None] turned a 1-D input into an (n, n) product.
    weights = weights.reshape((n,) + (1,) * (spectrum.ndim - 1))
    env = np.abs(np.fft.ifft(spectrum * weights, axis=0))
    return env.astype(np.float32)
def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig):
    """Run the preprocessing chain and return every intermediate stage.

    The stages are:

    1. Lay *x* out as ``(time, channel)`` with ``to_time_channel``.
    2. Band-pass it with ``bandpass_tc``.
    3. Take the amplitude envelope of the filtered signal with
       ``hilbert_envelope_tc``.

    Args:
        x: Input signal, 1-D or 2-D.
        cfg: Sampling rate and pass-band settings.

    Returns:
        dict with keys, in this order:

        * ``"raw"``: the time-major input;
        * ``"filtered"``: the band-passed signal;
        * ``"envelope"``: the envelope of the filtered signal.
    """
    stages = {"raw": to_time_channel(x)}
    stages["filtered"] = bandpass_tc(stages["raw"], cfg)
    stages["envelope"] = hilbert_envelope_tc(stages["filtered"])
    return stages
|