File size: 1,463 Bytes
b11ec91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# preprocessing.py
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
import mne


@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable settings shared by the preprocessing steps in this module.

    Instances are frozen (fields cannot be reassigned after construction),
    which also makes them hashable and safe to share between calls.
    """

    # Sampling rate, handed to ``mne.create_info`` as ``sfreq`` (Hz in MNE).
    fs: float
    # Lower filter edge, passed as the first positional argument
    # (``l_freq``) of MNE's ``Raw.filter``.
    f_low: float
    # Upper filter edge, passed as the second positional argument
    # (``h_freq``) of MNE's ``Raw.filter``.
    f_high: float


def to_time_channel(x: np.ndarray, *, max_channels: int = 256) -> np.ndarray:
    """Return *x* as a 2-D ``(time, channel)`` array.

    A 1-D input is treated as a single channel and returned with shape
    ``(T, 1)``. A 2-D input may arrive as either ``(time, channel)`` or
    ``(channel, time)``. It is transposed when its first axis is at most
    ``max_channels`` long and its second axis is longer than the first,
    i.e. when it looks like a ``(channel, time)`` layout.

    Args:
        x: 1-D or 2-D array-like signal. Lists are accepted, and ndarray
            subclasses are preserved.
        max_channels: Largest first-axis length that may be interpreted as a
            channel axis. Defaults to 256, the previously hard-coded value.

    Returns:
        The data in ``(time, channel)`` layout. This may be a view of *x*
        (slice or transpose) rather than a copy.

    Raises:
        ValueError: If *x* is neither 1-D nor 2-D.
    """
    x = np.asanyarray(x)
    if x.ndim == 1:
        return x[:, None]
    if x.ndim != 2:
        raise ValueError(f"Expected 1D or 2D array, got {x.shape}")
    n_first, n_second = x.shape
    # Layout heuristic: a short first axis followed by a longer second axis is
    # read as (channel, time). NOTE(review): ambiguous for very short recordings
    # (fewer samples than channels); callers that know the layout should pass
    # (time, channel) directly.
    if n_first <= max_channels and n_second > n_first:
        return x.T
    return x


def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Filter a ``(time, channel)`` array between ``cfg.f_low`` and ``cfg.f_high``.

    The data are wrapped in a temporary ``mne.io.RawArray`` at sampling rate
    ``cfg.fs`` and filtered with MNE's default ``Raw.filter`` settings. The
    result comes back from ``get_data()`` and is returned in ``(time, channel)``
    layout.
    """
    n_channels = x_tc.shape[1]
    # Synthetic channel names ("ch0", "ch1", ...); every channel is declared EEG.
    names = [f"ch{idx}" for idx in range(n_channels)]
    info = mne.create_info(ch_names=names, sfreq=cfg.fs, ch_types="eeg")
    # MNE stores data as (channel, time), hence the transposes in and out.
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Raw.filter modifies its object in place, and RawArray may keep the
    # caller's buffer without copying it (dtype/MNE-version dependent), so the
    # filter is applied to a copy to avoid mutating the input.
    filtered = raw.copy().filter(cfg.f_low, cfg.f_high, verbose=False)
    return filtered.get_data().T


def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the analytic-signal (Hilbert) amplitude envelope along axis 0.

    The analytic signal is built in the frequency domain:
    - the DC bin (and, for even lengths, the Nyquist bin) is kept;
    - positive-frequency bins are doubled;
    - negative-frequency bins are zeroed.

    The magnitude of the inverse FFT is the envelope. This is the same
    construction as ``np.abs(scipy.signal.hilbert(x_tc, axis=0))``.

    Args:
        x_tc: Signal with time on axis 0, typically ``(time, channel)``.
            Inputs of any rank >= 1 are accepted, including 1-D signals.

    Returns:
        A ``float32`` array with the same shape as *x_tc*. A zero-length
        time axis yields an empty array of that shape.
    """
    x_tc = np.asarray(x_tc)
    if x_tc.ndim >= 1 and x_tc.shape[0] == 0:
        # np.fft.fft rejects zero-length transforms; an empty signal simply has
        # an empty envelope.
        return np.zeros(x_tc.shape, dtype=np.float32)

    spectrum = np.fft.fft(x_tc, axis=0)
    n = spectrum.shape[0]
    weights = np.zeros(n)
    weights[0] = 1  # DC component is kept as-is
    if n % 2 == 0:
        weights[n // 2] = 1  # Nyquist bin belongs to both halves; keep once
        weights[1:n // 2] = 2
    else:
        weights[1:(n + 1) // 2] = 2
    # Shape the weights as (n, 1, ..., 1) so they broadcast along axis 0 for
    # any input rank. A fixed weights[:, None] would instead form an (n, n)
    # outer product for 1-D input.
    weights = weights.reshape((n,) + (1,) * (spectrum.ndim - 1))
    envelope = np.abs(np.fft.ifft(spectrum * weights, axis=0))
    return envelope.astype(np.float32)


def preprocess_pipeline(x: np.ndarray, cfg: PreprocessConfig) -> dict[str, np.ndarray]:
    """Run layout normalization, band-pass filtering and envelope extraction.

    Steps, in order: ``to_time_channel`` -> ``bandpass_tc`` ->
    ``hilbert_envelope_tc``.

    Returns:
        A dict with three entries:
        - ``"raw"``: the input as returned by ``to_time_channel`` (possibly a
          view of *x*);
        - ``"filtered"``: the band-passed ``(time, channel)`` data;
        - ``"envelope"``: the float32 Hilbert envelope of the filtered data.
    """
    time_major = to_time_channel(x)
    band_limited = bandpass_tc(time_major, cfg)
    return dict(
        raw=time_major,
        filtered=band_limited,
        envelope=hilbert_envelope_tc(band_limited),
    )