# preprocessing.py from __future__ import annotations from dataclasses import dataclass import numpy as np import mne @dataclass(frozen=True) class PreprocessConfig: fs: float f_low: float f_high: float def to_time_channel(x: np.ndarray) -> np.ndarray: if x.ndim == 1: return x[:, None] if x.ndim != 2: raise ValueError(f"Expected 1D or 2D array, got {x.shape}") T, C = x.shape if T <= 256 and C > T: x = x.T return x def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray: info = mne.create_info( ch_names=[f"ch{i}" for i in range(x_tc.shape[1])], sfreq=cfg.fs, ch_types="eeg", ) raw = mne.io.RawArray(x_tc.T, info, verbose=False) raw_filt = raw.copy().filter(cfg.f_low, cfg.f_high, verbose=False) return raw_filt.get_data().T def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray: Xf = np.fft.fft(x_tc, axis=0) N = Xf.shape[0] h = np.zeros(N) if N % 2 == 0: h[0] = h[N // 2] = 1 h[1:N // 2] = 2 else: h[0] = 1 h[1:(N + 1) // 2] = 2 env = np.abs(np.fft.ifft(Xf * h[:, None], axis=0)) return env.astype(np.float32) def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig): x_tc = to_time_channel(x) x_filt = bandpass_tc(x_tc, cfg) env = hilbert_envelope_tc(x_filt) return { "raw": x_tc, "filtered": x_filt, "envelope": env, }