import io
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import streamlit as st
import plotly.graph_objects as go
import mne
from scipy.signal import hilbert

try:
    import community as community_louvain
    import networkx as nx

    LOUVAIN_AVAILABLE = True
except ImportError:
    # Louvain clustering is optional. The warning is emitted below, after
    # st.set_page_config, because Streamlit (at least in older releases)
    # rejects set_page_config once any other st.* call has already run.
    LOUVAIN_AVAILABLE = False

from loader import (
    pick_set_fdt,
    load_eeglab_tc_from_bytes,
    load_mat_candidates,
)
import metrics

st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")

if not LOUVAIN_AVAILABLE:
    st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。")


# ============================================================
# Preprocess config
# ============================================================
@dataclass(frozen=True)
class PreprocessConfig:
    """Parameters of the bandpass + Hilbert preprocessing pipeline."""

    fs: float  # sampling rate (Hz)
    f_low: float  # bandpass lower edge (Hz)
    f_high: float  # bandpass upper edge (Hz)


# ============================================================
# Helpers
# ============================================================
def ensure_tc(x: np.ndarray) -> np.ndarray:
    """Return *x* as a 2-D (T, C) array.

    Accepts (T,), (T, C) or (C, T). A 2-D input whose first axis is short
    (<= 256) and shorter than the second axis is assumed to be (C, T) and is
    transposed. This is a heuristic and can misfire on very short recordings.

    Raises:
        ValueError: if *x* is neither 1-D nor 2-D.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        return x[:, None]
    if x.ndim != 2:
        raise ValueError(f"2次元配列のみ対応です: shape={x.shape}")
    T, C = x.shape
    if T <= 256 and C > T:
        # heuristic transpose: few rows and many columns looks like (C, T)
        x = x.T
    return x


def _quad_bezier_points(p0, p1, c, n=20):
    """Sample a quadratic Bezier curve (endpoints p0, p1; control c) at n points.

    Returns:
        (n, 2) array of points from p0 to p1.
    """
    ts = np.linspace(0, 1, n)
    pts = (
        (1 - ts)[:, None] ** 2 * p0
        + 2 * (1 - ts)[:, None] * ts[:, None] * c
        + ts[:, None] ** 2 * p1
    )
    return pts  # shape (n, 2)


def _quad_bezier_point_and_tangent(p0, p1, c, t):
    """Return the point B(t) and the derivative B'(t) of a quadratic Bezier curve."""
    # B(t) = (1-t)^2 p0 + 2(1-t)t c + t^2 p1
    pt = (1 - t) ** 2 * p0 + 2 * (1 - t) * t * c + t ** 2 * p1
    # B'(t) = 2(1-t)(c-p0) + 2t(p1-c)
    tan = 2 * (1 - t) * (c - p0) + 2 * t * (p1 - c)
    return pt, tan


# ============================================================
# Signal processing
# ============================================================
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Bandpass filter every channel with MNE. Input/Output: (T, C), float32 output."""
    info = mne.create_info(
        ch_names=[f"ch{i}" for i in range(x_tc.shape[1])],
        sfreq=float(cfg.fs),
        ch_types="eeg",
    )
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)  # MNE expects (C, T)
    raw_filt = raw.copy().filter(l_freq=cfg.f_low, h_freq=cfg.f_high, verbose=False)
    return raw_filt.get_data().T.astype(np.float32)


def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Hilbert envelope (|analytic signal|) per channel. Input/Output: (T, C)."""
    analytic = hilbert(x_tc, axis=0)
    return np.abs(analytic).astype(np.float32)


def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
    """Hilbert phase (angle of the analytic signal, radians) per channel. Input/Output: (T, C)."""
    analytic = hilbert(x_tc, axis=0)
    return np.angle(analytic).astype(np.float32)


def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
    """Run raw (T, C) data through bandpass + Hilbert and return every stage.

    Returns:
        dict with keys "fs", "raw", "filtered", "envelope", "amplitude"
        (alias of "envelope") and "phase"; arrays are float32 (T, C).
    """
    x_tc = ensure_tc(x_tc).astype(np.float32)
    x_filt = bandpass_tc(x_tc, cfg)
    # One analytic signal serves both envelope and phase (same values as
    # hilbert_envelope_tc / hilbert_phase_tc, without computing the FFT twice).
    analytic = hilbert(x_filt, axis=0)
    env = np.abs(analytic).astype(np.float32)
    phase = np.angle(analytic).astype(np.float32)
    return {
        "fs": float(cfg.fs),
        "raw": x_tc,
        "filtered": x_filt,
        "envelope": env,
        "amplitude": env,  # alias of "envelope"
        "phase": phase,
    }


@st.cache_data(show_spinner=False)
def preprocess_all_eeglab(
    set_bytes: bytes,
    fdt_bytes: bytes,
    set_name: str,
    fdt_name: str,
    f_low: float,
    f_high: float,
) -> dict:
    """Load EEGLAB bytes and auto-preprocess them (bandpass + Hilbert).

    The sampling rate comes from the loaded data. Electrode positions, when the
    loader returns them, are attached as "electrode_pos" (2-D) and
    "electrode_pos_3d" (3-D).
    """
    x_tc, fs, electrode_pos_2d, electrode_pos_3d = load_eeglab_tc_from_bytes(
        set_bytes=set_bytes,
        set_name=set_name,
        fdt_bytes=fdt_bytes,
        fdt_name=fdt_name,
    )
    cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
    result = preprocess_tc(x_tc, cfg)

    if electrode_pos_2d is not None:
        result["electrode_pos"] = electrode_pos_2d
    if electrode_pos_3d is not None:
        result["electrode_pos_3d"] = electrode_pos_3d
    return result


@st.cache_data(show_spinner=False)
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
    """Cache the MAT candidate scan so UI interactions do not re-read the file."""
    return load_mat_candidates(mat_bytes)


# ============================================================
# Viewer
# ============================================================
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
    """Return rows [start_idx, end_idx) of X_tc taking every `decim`-th sample.

    start_idx is clamped to [0, T-1], end_idx to [start_idx+1, T], decim to >= 1.
    """
    start_idx = max(0, min(start_idx, X_tc.shape[0] - 1))
    end_idx = max(start_idx + 1, min(end_idx, X_tc.shape[0]))
    decim = max(1, int(decim))
    return X_tc[start_idx:end_idx:decim, :]


def make_timeseries_figure(
    X_tc: np.ndarray,
    selected_channels: List[int],
    fs: float,
    start_sec: float,
    win_sec: float,
    decim: int,
    offset_mode: bool,
    show_rangeslider: bool,
    signal_type: str = "filtered",
) -> go.Figure:
    """Plot the selected channels of a (T, C) signal over one time window.

    Args:
        X_tc: (T, C) signal.
        selected_channels: column indices to draw.
        fs: sampling rate (Hz).
        start_sec, win_sec: window start and length (seconds).
        decim: keep every decim-th sample.
        offset_mode: stack channels vertically (ignored for phase data).
        show_rangeslider: show Plotly's x-axis rangeslider.
        signal_type: name of the signal; "phase" fixes the y-range to about [-pi, pi].
    """
    decim = max(1, int(decim))
    start_idx = int(round(start_sec * fs))
    end_idx = int(round((start_sec + win_sec) * fs))
    Xw = window_slice(X_tc, start_idx, end_idx, decim)
    # window_slice clamps the start index; mirror it so the time axis matches the data
    start_idx = max(0, min(start_idx, X_tc.shape[0] - 1))
    t = (np.arange(Xw.shape[0]) * decim + start_idx) / fs

    fig = go.Figure()
    if not selected_channels:
        fig.update_layout(
            title="Timeseries (no channel selected)",
            height=450,
            xaxis_title="time (s)",
            yaxis_title="amplitude",
        )
        return fig

    # Phase data is bounded and wrapped, so it is never offset-stacked
    is_phase = signal_type == "phase"

    if offset_mode and len(selected_channels) > 1 and not is_phase:
        med_std = float(np.median(np.std(Xw[:, selected_channels], axis=0)))
        base = med_std if np.isfinite(med_std) and med_std > 0 else 1.0
        offset = 5.0 * base
        for k, ch in enumerate(selected_channels):
            y = Xw[:, ch] + k * offset
            fig.add_trace(go.Scatter(x=t, y=y, mode="lines", name=f"ch{ch}", line=dict(width=1)))
        ylab = "amplitude (offset)"
    else:
        for ch in selected_channels:
            fig.add_trace(go.Scatter(x=t, y=Xw[:, ch], mode="lines", name=f"ch{ch}", line=dict(width=1)))
        ylab = "phase (rad)" if is_phase else "amplitude"

    # Leave room below the plot for the rangeslider when it is shown
    plot_height = 550 if show_rangeslider else 450
    bottom_margin = 150 if show_rangeslider else 80

    title_text = f"Timeseries: {signal_type} (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})"

    fig.update_layout(
        title=title_text,
        height=plot_height,
        xaxis_title="time (s)",
        yaxis_title=ylab,
        legend=dict(orientation="h"),
        margin=dict(l=60, r=20, t=80, b=bottom_margin),
    )

    if is_phase:
        fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])

    if show_rangeslider:
        fig.update_xaxes(rangeslider=dict(visible=True, thickness=0.05))
    else:
        fig.update_xaxes(rangeslider=dict(visible=False))

    return fig


# ============================================================
# Network (multiple methods) + export
# ============================================================
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation between channel envelopes.

    Args:
        X_tc: (T, C) envelope (Hilbert amplitude) data.

    Returns:
        (C, C) float32 |r| matrix with zero diagonal; NaN/inf (for example from
        a constant channel) are replaced by 0.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    # np.corrcoef returns a 0-d array for a single channel; keep it (C, C)
    corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    W = np.abs(corr)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
    """Circular correlation between channel phases.

    Uses metrics.circular_correlation (Jammalamadaka–Sengupta circular
    correlation) for every pair i < j and mirrors it, so W is symmetric with a
    zero diagonal.

    Args:
        X_tc: (T, C) phase data in radians.

    Returns:
        (C, C) float32 matrix; NaN/inf replaced by 0.
    """
    T, C = X_tc.shape
    W = np.zeros((C, C), dtype=np.float32)
    for i in range(C):
        for j in range(i + 1, C):
            corr = metrics.circular_correlation(X_tc[:, i], X_tc[:, j])
            W[i, j] = W[j, i] = corr
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_phase_PLV(X_tc: np.ndarray, progress) -> np.ndarray:
    """Phase Locking Value between every pair of channels.

        PLV_ij = | < exp(i * (theta_i - theta_j)) >_t |

    Computed as |Z^T conj(Z)| / T with Z = exp(i * theta). This needs only C*T
    complex exponentials instead of one per pair and sample. Time is processed
    in chunks so the temporary phasor buffer stays around 16 MB.

    Args:
        X_tc: (T, C) phase data in radians.
        progress: object with a .progress(fraction) method (e.g. st.progress).

    Returns:
        (C, C) float32 symmetric PLV matrix with zero diagonal.
    """
    T, C = X_tc.shape
    W = np.zeros((C, C), dtype=np.float32)
    if T == 0 or C < 2:
        progress.progress(1.0)
        return W

    # S[i, j] = sum_t exp(i*theta_i) * conj(exp(i*theta_j)) = sum_t exp(i*(theta_i - theta_j))
    S = np.zeros((C, C), dtype=np.complex128)
    chunk = max(1, (1 << 20) // C)
    for start in range(0, T, chunk):
        Z = np.exp(1j * X_tc[start:start + chunk].astype(np.float64))
        S += Z.T @ Z.conj()
        progress.progress(min(1.0, (start + chunk) / T))

    W = (np.abs(S) / T).astype(np.float32)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_pac_tort(X_tc1, X_tc2, progress):
    """Directed phase–amplitude coupling via the Modulation Index (Tort et al., 2010).

    W[i, j] = metrics.modulation_index(phase of channel i, amplitude of channel j)
    for i != j; the diagonal stays 0.

    Args:
        X_tc1: (T, C) low-band phase data (radians).
        X_tc2: (T, C) high-band envelope data.
        progress: object with a .progress(fraction) method; reaches 1.0 at the end.

    Returns:
        (C, C) float32 matrix, generally asymmetric (directed); NaN/inf replaced by 0.

    Raises:
        ValueError: if the two inputs differ in shape.
    """
    if X_tc1.shape != X_tc2.shape:
        raise ValueError(f"phase/amplitude shape mismatch: {X_tc1.shape} vs {X_tc2.shape}")
    T, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)

    for i in range(C):
        for j in range(C):
            if i != j:
                W[i, j] = metrics.modulation_index(X_tc1[:, i], X_tc2[:, j])
        # One update per source channel: fewer UI messages, and it reaches 1.0
        progress.progress((i + 1) / C)

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_pac_chatterjee(X_tc1, X_tc2, progress):
    """Directed phase–amplitude coupling via Chatterjee's correlation.

    W[i, j] = metrics.chatterjee_phase_to_amp(phase of channel i, amplitude of
    channel j) for i != j; the diagonal stays 0.

    Args:
        X_tc1: (T, C) low-band phase data (radians).
        X_tc2: (T, C) high-band envelope data.
        progress: object with a .progress(fraction) method; reaches 1.0 at the end.

    Returns:
        (C, C) float32 matrix, generally asymmetric (directed); NaN/inf replaced by 0.

    Raises:
        ValueError: if the two inputs differ in shape.
    """
    if X_tc1.shape != X_tc2.shape:
        raise ValueError(f"phase/amplitude shape mismatch: {X_tc1.shape} vs {X_tc2.shape}")
    T, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)

    for i in range(C):
        for j in range(C):
            if i != j:
                W[i, j] = metrics.chatterjee_phase_to_amp(X_tc1[:, i], X_tc2[:, j])
        # One update per source channel: fewer UI messages, and it reaches 1.0
        progress.progress((i + 1) / C)

    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation (dummy estimator kept for backward compatibility)."""
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    # np.corrcoef returns a 0-d array for a single channel; keep it (C, C)
    corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    W = np.abs(corr)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)


def threshold_edges(
    W: np.ndarray,
    thr: float,
) -> List[Tuple[int, int, float]]:
    """Extract the edges whose weight is >= thr.

    - If W is (exactly) symmetric, the graph is treated as undirected and only
      pairs i < j are returned.
    - Otherwise it is treated as directed and every off-diagonal i -> j is returned.

    Returns:
        List of (i, j, w), sorted by descending weight; ties keep row-major order.
    """
    W = np.asarray(W)
    C = W.shape[0]
    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
    if is_symmetric:
        mask = np.triu(np.ones((C, C), dtype=bool), k=1)  # undirected: upper triangle
    else:
        mask = ~np.eye(C, dtype=bool)  # directed: all off-diagonal entries
    # Compare in float64, like the scalar float(W[i, j]) >= thr comparison would
    Wd = W.astype(np.float64)
    mask &= Wd >= float(thr)

    rows, cols = np.nonzero(mask)  # row-major order
    edges: List[Tuple[int, int, float]] = [
        (int(i), int(j), float(Wd[i, j])) for i, j in zip(rows, cols)
    ]
    edges.sort(key=lambda e: e[2], reverse=True)  # stable: ties keep row-major order
    return edges


def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
    """Adjacency matrix at threshold thr with a zero diagonal.

    weighted=True keeps the weights >= thr (float); otherwise returns a 0/1 int matrix.
    """
    if weighted:
        A = W.copy()
        A[A < thr] = 0.0
        np.fill_diagonal(A, 0.0)
        return A
    A = (W >= thr).astype(int)
    np.fill_diagonal(A, 0)
    return A


def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
    """Cluster channels with the Louvain method.

    Args:
        W: (C, C) weight matrix; an asymmetric W is symmetrised per pair with max().
        thr: edges with weight below this are dropped.

    Returns:
        (C,) int array of cluster ids. If Louvain is unavailable, every node is in
        cluster 0.
    """
    if not LOUVAIN_AVAILABLE:
        return np.zeros(W.shape[0], dtype=int)

    G = nx.Graph()
    C = W.shape[0]
    G.add_nodes_from(range(C))

    for i in range(C):
        for j in range(C):
            # Self-loops and non-positive weights add nothing to modularity. At
            # thr <= 0 they could leave python-louvain with total edge weight 0,
            # which makes it divide by zero.
            if i == j or W[i, j] < thr:
                continue
            w = float(max(W[i, j], W[j, i]))
            if w > 0:
                G.add_edge(i, j, weight=w)

    partition = community_louvain.best_partition(G, weight='weight')
    return np.array([partition[i] for i in range(C)])


def get_cluster_colors(clusters: np.ndarray) -> List[str]:
    """Return one 'rgb(...)' color per node, spreading hues evenly over the cluster count."""
    import colorsys

    n_clusters = max(len(np.unique(clusters)), 1)
    colors = []
    for cluster_id in clusters:
        r, g, b = colorsys.hsv_to_rgb(cluster_id / n_clusters, 0.8, 0.95)
        colors.append(f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})')
    return colors


def get_electrode_positions(prep: dict) -> np.ndarray:
    """Return (C, 2) electrode positions from prep, or a unit-circle layout if absent."""
    if "electrode_pos" in prep:
        return prep["electrode_pos"]

    C = prep["raw"].shape[1]
    angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
    return np.column_stack([np.cos(angles), np.sin(angles)])


def _edge_weight_scale(edges):
    """Return (min_w, max_w, weight_range) used to normalise edge weights to [0, 1].

    weight_range is 1.0 when there are no edges or all weights are equal, so the
    normalisation never divides by zero.
    """
    if not edges:
        return 0, 1, 1.0
    weights = [w for _, _, w in edges]
    min_w, max_w = min(weights), max(weights)
    return min_w, max_w, (max_w - min_w if max_w > min_w else 1.0)


def _rainbow_color(norm_val, alpha=None):
    """Map a normalised weight in [0, 1] to blue (0) -> green/yellow -> red (1).

    Returns 'rgb(...)' when alpha is None, otherwise 'rgba(..., alpha)'.
    """
    import colorsys

    hue = (1.0 - norm_val) * 0.67  # HSV hue from 240 deg (blue, 0.67 ~ 240/360) down to 0 deg (red)
    r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
    if alpha is None:
        return f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})'
    return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, {alpha})'


def _node_style(W: np.ndarray, thr: float, use_louvain: bool):
    """Return (node_colors, clusters, title_suffix) for drawing the network nodes."""
    C = W.shape[0]
    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        n_clusters = len(np.unique(clusters))
        return get_cluster_colors(clusters), clusters, f" | Louvain clusters: {n_clusters}"
    return ['#FFD700'] * C, np.zeros(C, dtype=int), ""  # default gold


def make_network_figure_3d(
    W: np.ndarray,
    thr: float,
    electrode_pos_3d: np.ndarray,
    use_louvain: bool = True,
) -> go.Figure:
    """3-D network plot (rotatable by dragging) at threshold thr.

    Args:
        W: (C, C) weight matrix.
        thr: edge threshold.
        electrode_pos_3d: (C, 3) node coordinates.
        use_louvain: color nodes by Louvain cluster when available.
    """
    C = W.shape[0]
    xs = electrode_pos_3d[:, 0]
    ys = electrode_pos_3d[:, 1]
    zs = electrode_pos_3d[:, 2]

    edges = threshold_edges(W, thr)
    fig = go.Figure()
    min_w, max_w, weight_range = _edge_weight_scale(edges)

    for (i, j, w) in edges:
        norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
        fig.add_trace(go.Scatter3d(
            x=[xs[i], xs[j], None],
            y=[ys[i], ys[j], None],
            z=[zs[i], zs[j], None],
            mode='lines',
            line=dict(color=_rainbow_color(norm_w), width=1 + 4 * norm_w),
            hoverinfo='skip',
            showlegend=False,
        ))

    node_colors, clusters, title_suffix = _node_style(W, thr, use_louvain)

    fig.add_trace(go.Scatter3d(
        x=xs, y=ys, z=zs,
        mode='markers+text',
        text=[f"{k}" for k in range(C)],
        textposition='top center',
        textfont=dict(size=8),
        marker=dict(
            size=8,
            color=node_colors,
            line=dict(color='white', width=1),
        ),
        hoverinfo='text',
        hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
        showlegend=False,
    ))

    fig.update_layout(
        title=f"3D Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
        height=700,
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            bgcolor='rgba(0,0,0,0.9)',
        ),
        paper_bgcolor='rgba(0,0,0,0.9)',
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig


def make_network_figure(
    W: np.ndarray,
    thr: float,
    use_louvain: bool = True,
    electrode_pos: np.ndarray = None,
) -> tuple[go.Figure, int]:
    """2-D network plot at threshold thr.

    A symmetric W is drawn as straight undirected edges. An asymmetric W is
    drawn as curved directed edges with an arrowhead near the target. Nodes use
    electrode_pos when it has one row per channel, otherwise a unit circle.

    Returns:
        (figure, number of edges drawn by threshold_edges).
    """
    C = W.shape[0]

    if electrode_pos is None or electrode_pos.shape[0] != C:
        angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
        xs = np.cos(angles)
        ys = np.sin(angles)
    else:
        xs = electrode_pos[:, 0]
        ys = electrode_pos[:, 1]

    edges = threshold_edges(W, thr)
    fig = go.Figure()
    min_w, max_w, weight_range = _edge_weight_scale(edges)

    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
    if not is_symmetric:
        curve_strength = 0.1  # how far the curve bows out (tunable)
        node_radius = 0.08  # stop the edge this far before the node centre (tunable)
        bezier_n = 18  # points per curve (more = smoother)
        t_arrow = 0.90  # arrowhead position along the curve (0..1)

        for (i, j, w) in edges:
            norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
            color = _rainbow_color(norm_w, alpha=0.7)
            line_width = 0.5 + 3.5 * norm_w

            p0 = np.array([xs[i], ys[i]], dtype=float)
            p1 = np.array([xs[j], ys[j]], dtype=float)
            v = p1 - p0
            dist = np.hypot(v[0], v[1])
            if dist < 1e-9:
                continue
            u = v / dist

            # Shorten both ends so the curve does not overlap the node markers
            p0s = p0 + u * node_radius
            p1s = p1 - u * node_radius

            # Bend every edge to the left of its own direction, so i->j and j->i
            # bow to opposite sides and do not overlap.
            normal = np.array([-u[1], u[0]])
            ctrl = 0.5 * (p0s + p1s) + curve_strength * dist * normal

            pts = _quad_bezier_points(p0s, p1s, ctrl, n=bezier_n)
            fig.add_trace(go.Scatter(
                x=pts[:, 0],
                y=pts[:, 1],
                mode="lines",
                hoverinfo="text",
                hovertext=f"ch{i} → ch{j}<br>weight: {w:.4f}",
                line=dict(width=line_width, color=color),
                showlegend=False,
            ))

            # Arrowhead aligned with the curve tangent
            pt, tan = _quad_bezier_point_and_tangent(p0s, p1s, ctrl, t_arrow)
            tx, ty = float(tan[0]), float(tan[1])
            if tx * tx + ty * ty < 1e-18:
                # degenerate tangent: fall back to the chord direction
                tx, ty = float(p1s[0] - p0s[0]), float(p1s[1] - p0s[1])

            theta = np.degrees(np.arctan2(ty, tx))  # tangent angle from +x, counter-clockwise
            ANGLE_OFFSET = -90.0  # rotate 'triangle-up' (which points up) onto the tangent
            ang = (theta + ANGLE_OFFSET) % 360

            fig.add_trace(go.Scatter(
                x=[pt[0]],
                y=[pt[1]],
                mode="markers",
                hoverinfo="skip",
                marker=dict(
                    symbol="triangle-up",
                    size=10,
                    angle=-ang,  # Plotly marker angles are clockwise
                    angleref="up",
                    color=color,
                    line=dict(width=0),
                ),
                showlegend=False,
            ))
    else:
        for (i, j, w) in edges:
            # weak (blue, thin) -> medium (green/yellow) -> strong (red, thick)
            norm_w = (w - min_w) / weight_range if weight_range > 0 else 0.5
            fig.add_trace(
                go.Scatter(
                    x=[xs[i], xs[j]],
                    y=[ys[i], ys[j]],
                    mode="lines",
                    hoverinfo="text",
                    hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
                    line=dict(width=0.5 + 3.5 * norm_w, color=_rainbow_color(norm_w, alpha=0.7)),
                    showlegend=False,
                )
            )

    node_colors, clusters, title_suffix = _node_style(W, thr, use_louvain)

    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text",
            text=[f"{k}" for k in range(C)],
            textposition="bottom center",
            textfont=dict(size=8),
            marker=dict(
                size=14,
                color=node_colors,
                line=dict(width=2, color='white'),
            ),
            hoverinfo="text",
            hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
            showlegend=False,
        )
    )

    fig.update_layout(
        title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=10, r=10, t=50, b=50),
        paper_bgcolor='rgba(0,0,0,0.9)',
        plot_bgcolor='rgba(0,0,0,0.9)',
    )
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Colour-scale legend in lieu of a colour bar
    if edges:
        fig.add_annotation(
            text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {min_w:.3f} - {max_w:.3f}",
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=10, color='white'),
            xanchor='center',
        )

    return fig, len(edges)


def make_edgecount_curve(W: np.ndarray) -> go.Figure:
    """Plot how many edges threshold_edges returns over 120 thresholds.

    The candidate weights are the same entries threshold_edges inspects: the
    upper triangle for a symmetric W, every off-diagonal entry for a directed W.
    """
    W = np.asarray(W)
    C = W.shape[0]
    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
    Wd = W.astype(np.float64)
    if is_symmetric:
        vals = Wd[np.triu_indices(C, k=1)]
    else:
        vals = Wd[~np.eye(C, dtype=bool)]
    vals = np.sort(vals)

    if vals.size:
        thr_grid = np.linspace(float(vals.max()), float(vals.min()), 120)
    else:
        thr_grid = np.array([0.0])
    # #(vals >= thr) == size - #(vals < thr); searchsorted(side="left") counts the latter
    counts = vals.size - np.searchsorted(vals, thr_grid, side="left")

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thr_grid, y=counts, mode="lines"))
    fig.update_layout(
        title="Edge count vs threshold (lower thr => more edges)",
        xaxis_title="threshold",
        yaxis_title="edge count",
        height=300,
    )
    return fig


def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
    """Serialise a matrix as UTF-8 CSV using np.savetxt with the given number format."""
    buf = io.StringIO()
    np.savetxt(buf, mat, delimiter=",", fmt=fmt)
    return buf.getvalue().encode("utf-8")


def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
    """Serialise an edge list as UTF-8 CSV with header source,target,weight."""
    buf = io.StringIO()
    buf.write("source,target,weight\n")
    for i, j, w in edges:
        buf.write(f"{i},{j},{w:.6f}\n")
    return buf.getvalue().encode("utf-8")


# ============================================================
# Sidebar UI
# ============================================================
st.sidebar.header("Input format")
input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)

st.sidebar.header("Preprocess (auto)")
f_low_src = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=4.0, step=1.0, key="low_src")
f_high_src = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=8.0, step=1.0, key="high_src")

st.sidebar.header("if you use CFC+PAC:")
f_low_tgt = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=25.0, step=1.0, key="low_tgt")
f_high_tgt = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=40.0, step=1.0, key="high_tgt")

st.sidebar.header("Viewer controls")
win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
decim = st.sidebar.selectbox("Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1)
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)

signal_view = st.sidebar.radio(
    "表示する信号",
    ["raw", "filtered", "amplitude", "phase"],
    index=1,
    help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相",
)

st.title("EEG timeseries viewer + network estimation")


# ============================================================
# Load + preprocess (EEGLAB / MAT)
# ============================================================
def _bytes_digest(data: bytes) -> str:
    """Content fingerprint used to tell uploaded files apart across reruns (not security)."""
    import hashlib

    return hashlib.sha1(data).hexdigest()


def _store_preprocessed(prep_src: dict, prep_tgt: dict, data_key: tuple) -> None:
    """Save the source/target preprocessed bands in session state.

    Streamlit re-runs this script on every widget interaction. The cached network
    "W" is therefore invalidated only when data_key (file content + parameters)
    changes. Resetting it unconditionally would re-estimate the network every
    time the threshold slider moves.
    """
    if st.session_state.get("data_key") != data_key:
        st.session_state["W"] = None
        st.session_state["data_key"] = data_key
    st.session_state["prep"] = prep_src
    st.session_state["prep_tgt"] = prep_tgt


def _clear_preprocessed() -> None:
    """Drop all preprocessed data and the cached network after a load/preprocess failure."""
    for name in ("prep", "prep_tgt", "data_key"):
        st.session_state.pop(name, None)
    st.session_state["W"] = None


band_params = (float(f_low_src), float(f_high_src), float(f_low_tgt), float(f_high_tgt))

if input_mode.startswith("EEGLAB"):
    st.sidebar.header("Upload (.set + .fdt)")
    uploaded_files = st.sidebar.file_uploader(
        "Upload EEGLAB files",
        type=["set", "fdt"],
        accept_multiple_files=True,
    )

    if uploaded_files:
        set_file, fdt_file = pick_set_fdt(uploaded_files)
        if set_file is None or fdt_file is None:
            st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
        else:
            try:
                set_bytes = set_file.getvalue()
                fdt_bytes = fdt_file.getvalue()
                with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
                    prep_src = preprocess_all_eeglab(
                        set_bytes=set_bytes,
                        fdt_bytes=fdt_bytes,
                        set_name=set_file.name,
                        fdt_name=fdt_file.name,
                        f_low=float(f_low_src),
                        f_high=float(f_high_src),
                    )
                    prep_tgt = preprocess_all_eeglab(
                        set_bytes=set_bytes,
                        fdt_bytes=fdt_bytes,
                        set_name=set_file.name,
                        fdt_name=fdt_file.name,
                        f_low=float(f_low_tgt),
                        f_high=float(f_high_tgt),
                    )
                data_key = ("eeglab", _bytes_digest(set_bytes), _bytes_digest(fdt_bytes), *band_params)
                _store_preprocessed(prep_src, prep_tgt, data_key)
                st.success(f"Loaded & preprocessed. (T,C)={prep_src['raw'].shape} fs={prep_src['fs']:.2f}Hz")
            except Exception as e:
                _clear_preprocessed()
                st.error(f"読み込み/前処理エラー: {e}")

else:
    st.sidebar.header("Upload (.mat)")
    mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])

    if mat_file is not None:
        mat_bytes = mat_file.getvalue()
        try:
            cands = load_mat_candidates_cached(mat_bytes)
            if not cands:
                st.error("数値の1D/2D配列が見つかりませんでした。")
                st.info("MATファイルの構造を確認しています...")

                # Debug: list the variables of a classic (pre-v7.3) MAT file
                try:
                    from scipy.io import loadmat

                    mat_data = loadmat(io.BytesIO(mat_bytes))
                    st.write("**MATファイルに含まれる変数:**")
                    for k, v in mat_data.items():
                        if k.startswith('__'):
                            continue
                        if isinstance(v, np.ndarray):
                            st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
                        else:
                            st.write(f"- `{k}`: type={type(v).__name__}")
                except Exception as e:
                    st.write(f"デバッグ情報の取得に失敗: {e}")

                # Debug: also try HDF5 (MATLAB v7.3). h5py reads file-like objects
                # directly, so no temporary file is needed (it used to leak if
                # h5py.File raised).
                try:
                    import h5py

                    st.write("**HDF5形式として読み込み中...**")
                    with h5py.File(io.BytesIO(mat_bytes), 'r') as f:

                        def show_structure(name, obj):
                            if isinstance(obj, h5py.Dataset):
                                st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")

                        f.visititems(show_structure)
                except Exception as e2:
                    st.write(f"HDF5としても読み込めませんでした: {e2}")
            else:
                key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
                fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)

                # Preprocess automatically once a variable is selected
                if key:
                    x = cands[key]
                    st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
                    data_key = ("mat", _bytes_digest(mat_bytes), key, float(fs_mat), *band_params)
                    try:
                        if st.session_state.get("data_key") == data_key and "prep" in st.session_state:
                            # Same file, variable and settings as the previous run:
                            # reuse the stored result instead of re-filtering.
                            prep = st.session_state["prep"]
                        else:
                            with st.spinner("Preprocessing (bandpass + hilbert)..."):
                                cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_src), f_high=float(f_high_src))
                                prep = preprocess_tc(x, cfg)
                                cfg_tgt = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_tgt), f_high=float(f_high_tgt))
                                prep_tgt = preprocess_tc(x, cfg_tgt)
                            _store_preprocessed(prep, prep_tgt, data_key)
                        st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
                    except Exception as e:
                        _clear_preprocessed()
                        st.error(f"前処理エラー: {e}")
                        import traceback
                        st.code(traceback.format_exc())
        except Exception as e:
            _clear_preprocessed()
            st.error(f".mat 読み込みエラー: {e}")
            import traceback
            st.code(traceback.format_exc())

if "prep" not in st.session_state:
    st.info("左のサイドバーからデータをアップロードしてください。")
    st.stop()


# ============================================================
# Viewer
# ============================================================
prep = st.session_state["prep"]
fs = float(prep["fs"])
X_tc = prep[signal_view]
T, C = X_tc.shape

duration_sec = (T - 1) / fs if T > 1 else 0.0
max_start = max(0.0, float(duration_sec - win_sec))

start_sec = st.sidebar.slider(
    "Start time (sec)",
    min_value=0.0,
    max_value=float(max_start),
    value=0.0,
    step=float(max(0.01, win_sec / 200)),
)

st.sidebar.header("Channels")

# Quick select / deselect-all buttons
col_ch1, col_ch2 = st.sidebar.columns(2)
with col_ch1:
    select_all = st.button("全選択")
with col_ch2:
    deselect_all = st.button("全解除")

# Select a contiguous channel range
with st.sidebar.expander("📊 範囲で選択"):
    range_start = st.number_input("開始ch", min_value=0, max_value=C-1, value=0, step=1)
    range_end = st.number_input("終了ch", min_value=0, max_value=C-1, value=min(C-1, 7), step=1)
    if st.button("範囲を選択"):
        st.session_state["selected_channels"] = list(range(int(range_start), int(range_end) + 1))

# Presets
with st.sidebar.expander("⚡ プリセット"):
    preset_col1, preset_col2 = st.columns(2)
    with preset_col1:
        if st.button("前頭部 (0-15)"):
            st.session_state["selected_channels"] = list(range(min(16, C)))
    with preset_col2:
        if st.button("頭頂部 (16-31)"):
            st.session_state["selected_channels"] = list(range(16, min(32, C)))

    preset_col3, preset_col4 = st.columns(2)
    with preset_col3:
        if st.button("側頭部 (32-47)"):
            st.session_state["selected_channels"] = list(range(32, min(48, C)))
    with preset_col4:
        if st.button("後頭部 (48-63)"):
            st.session_state["selected_channels"] = list(range(48, min(64, C)))

if "selected_channels" not in st.session_state:
    st.session_state["selected_channels"] = list(range(min(C, 8)))
# Drop indices that no longer exist (e.g. after loading a file with fewer
# channels); st.multiselect rejects a default that is not among its options.
st.session_state["selected_channels"] = [
    ch for ch in st.session_state["selected_channels"] if 0 <= ch < C
]

if select_all:
    st.session_state["selected_channels"] = list(range(C))
if deselect_all:
    st.session_state["selected_channels"] = []

# Use a multiselect only for small montages; it becomes unwieldy with many channels
max_display = 20

if C <= max_display:
    selected_channels = st.sidebar.multiselect(
        f"表示するチャンネル(全{C}ch)",
        options=list(range(C)),
        default=st.session_state["selected_channels"],
        key="ch_select",
    )
    st.session_state["selected_channels"] = selected_channels
else:
    # Many channels: show a summary and add/remove one channel at a time
    st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")

    add_ch = st.sidebar.number_input(
        "チャンネルを追加",
        min_value=0,
        max_value=C-1,
        value=0,
        step=1,
        key="add_ch_input",
    )

    col_add, col_remove = st.sidebar.columns(2)
    with col_add:
        if st.button("➕ 追加"):
            if add_ch not in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].append(int(add_ch))
                st.session_state["selected_channels"].sort()
    with col_remove:
        if st.button("➖ 削除"):
            if add_ch in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].remove(int(add_ch))

    if st.session_state["selected_channels"]:
        selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
        if len(st.session_state["selected_channels"]) > 10:
            selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
        st.sidebar.text(f"選択済み: {selected_str}")

    selected_channels = st.session_state["selected_channels"]

col1, col2 = st.columns([2, 1])

with col1:
    fig_ts = make_timeseries_figure(
        X_tc=X_tc,
        selected_channels=selected_channels,
        fs=fs,
        start_sec=float(start_sec),
        win_sec=float(win_sec),
        decim=int(decim),
        offset_mode=bool(offset_mode),
        show_rangeslider=bool(show_rangeslider),
        signal_type=signal_view,
    )
    st.plotly_chart(fig_ts)

with col2:
    st.subheader("Data info")
    signal_desc = {
        "raw": "生信号(前処理なし)",
        "filtered": f"バンドパスフィルタ後 ({f_low_src}-{f_high_src} Hz)",
        "amplitude": "Hilbert振幅 (envelope)",
        "phase": "Hilbert位相 (-π ~ π)",
    }
    st.write(f"- view: **{signal_view}** ({signal_desc.get(signal_view, '')})")
    st.write(f"- fs: **{fs:.2f} Hz**")
    st.write(f"- T: {T} samples")
    st.write(f"- C: {C} channels")
    st.write(f"- duration: {duration_sec:.2f} sec")
    if signal_view == "phase":
        st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")
    st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")

st.divider()


# ============================================================
# Estimation
# ============================================================
st.subheader("Network estimation")

estimation_method = st.radio(
    "推定手法を選択",
    options=[
        "envelope_corr",
        "phase_PLV",
        "phase_corr",
        "pac_tort",
        "pac_chatterjee",
    ],
    format_func=lambda x: {
        "envelope_corr": "Envelope Pearson correlation (振幅の相関)",
        "phase_PLV": "PLV(位相同期, PLV)",
        "phase_corr": "Circular correlation",
        "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
        "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
    }[x],
    horizontal=True,
    help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_PLV: 位相のPhase Locking Value | phase_corr: 位相の相関係数 | pac_tort: Modulation index | pac_chatterjee: 位相から振幅へのChatterjee相関",
)

method_info = {
    "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
    "phase_PLV": "**PLV**: 位相間のPhase locking valueを計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "phase_corr": "**Circular correlation**: 位相間の相関係数を計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
    "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
}
st.info(method_info[estimation_method])

# W is reset to None by the loader only when the data or band parameters change
last_method = st.session_state.get("last_estimation_method")
W = st.session_state.get("W")

need_estimation = (W is None) or (last_method != estimation_method)

if need_estimation:
    progress = st.progress(0.0)
    with st.spinner(f"推定中... ({estimation_method})"):
        if estimation_method == "envelope_corr":
            W = estimate_network_envelope_corr(prep["amplitude"])
        elif estimation_method == "phase_PLV":
            W = estimate_network_phase_PLV(prep["phase"], progress)
        elif estimation_method == "phase_corr":
            W = estimate_network_phase_corr(prep["phase"])
        elif estimation_method == "pac_tort":
            # low-band phase (source band) -> high-band amplitude (target band)
            prep_tgt = st.session_state["prep_tgt"]
            W = estimate_network_pac_tort(prep["phase"], prep_tgt["amplitude"], progress)
        elif estimation_method == "pac_chatterjee":
            prep_tgt = st.session_state["prep_tgt"]
            W = estimate_network_pac_chatterjee(prep["phase"], prep_tgt["amplitude"], progress)
        else:
            st.error("未知の推定手法です")
            st.stop()
    progress.progress(1.0)

    st.session_state["W"] = W
    st.session_state["last_estimation_method"] = estimation_method
    st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
else:
    st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")

# From here on W always exists
w_peak = float(np.max(W))
# Clamp to >= 0: the slider starts at 0, so a negative maximum would make value < min_value
wmax = max(w_peak, 0.0) if np.isfinite(w_peak) else 1.0

col_thr1, col_thr2 = st.columns([3, 1])
with col_thr1:
    thr = st.slider(
        "閾値 (threshold) ※下げるほどエッジが増えます",
        min_value=0.0,
        max_value=max(0.0001, wmax),
        value=wmax/2,
        step=max(wmax / 200, 0.001),
    )
with col_thr2:
    use_louvain = st.checkbox(
        "Louvainクラスタ",
        value=True,
        disabled=not LOUVAIN_AVAILABLE,
        help="ノードの色をコミュニティ検出結果で塗り分けます",
    )

electrode_pos = prep.get("electrode_pos", None)

# Rotate the 2-D layout 90 degrees counter-clockwise so the front of the head points up
if electrode_pos is not None:
    electrode_pos = np.asarray(electrode_pos, dtype=np.float32)
    if electrode_pos.ndim == 2 and electrode_pos.shape[1] >= 2:
        pos2 = electrode_pos[:, :2]
        electrode_pos = np.column_stack([-pos2[:, 1], pos2[:, 0]])

electrode_pos_3d = prep.get("electrode_pos_3d", None)

if electrode_pos is not None:
    st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
else:
    st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")

if electrode_pos_3d is not None:
    st.success(f"✓ 3D電極座標を取得しました ({electrode_pos_3d.shape[0]} channels) - 下部に3Dビューアを表示します")
else:
    st.info("ℹ️ 3D電極座標が取得できませんでした - 2D表示のみです")

net_col1, net_col2 = st.columns([2, 1])

with net_col1:
    fig_net, edge_n = make_network_figure(
        W,
        float(thr),
        use_louvain=use_louvain,
        electrode_pos=electrode_pos,
    )
    st.plotly_chart(fig_net)

    # 3-D network view, only when valid (C, 3) coordinates are available
    if electrode_pos_3d is not None:
        electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)
        if electrode_pos_3d.ndim == 2 and electrode_pos_3d.shape[0] == W.shape[0] and electrode_pos_3d.shape[1] == 3:
            st.subheader("3D Viewer")
            fig_3d = make_network_figure_3d(
                W=W,
                thr=float(thr),
                electrode_pos_3d=electrode_pos_3d,
                use_louvain=use_louvain,
            )
            st.plotly_chart(
                fig_3d,
                width="stretch",
                config={"displayModeBar": True, "scrollZoom": True},
            )
        else:
            st.warning(f"3D座標のshapeが不正です: {electrode_pos_3d.shape}(期待: (C,3), C={W.shape[0]})")

with net_col2:
    st.metric("Edges", edge_n)
    st.plotly_chart(make_edgecount_curve(W))

st.write("# Hypothesis testing")
st.write("Coming soon ...")