Spaces:
Sleeping
Sleeping
| import io | |
| from dataclasses import dataclass | |
| from typing import List, Tuple | |
| import numpy as np | |
| import streamlit as st | |
| import plotly.graph_objects as go | |
| import mne | |
| from scipy.signal import hilbert | |
| try: | |
| import community as community_louvain | |
| import networkx as nx | |
| LOUVAIN_AVAILABLE = True | |
| except ImportError: | |
| LOUVAIN_AVAILABLE = False | |
| st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。") | |
| from loader import ( | |
| pick_set_fdt, | |
| load_eeglab_tc_from_bytes, | |
| load_mat_candidates, | |
| ) | |
| import metrics | |
# Page-wide settings: browser tab title and the wide (full-width) layout.
# NOTE(review): the optional-import fallback above may call st.warning before
# this line runs; older Streamlit versions require set_page_config to be the
# first st.* call of a run -- confirm against the deployed Streamlit version.
st.set_page_config(
    layout="wide",
    page_title="EEG Viewer + Network Estimation",
)
| # ============================================================ | |
| # Preprocess config | |
| # ============================================================ | |
@dataclass
class PreprocessConfig:
    """Settings for the automatic bandpass + Hilbert preprocessing.

    The ``@dataclass`` decorator is required: callers construct this with
    keyword arguments (``PreprocessConfig(fs=..., f_low=..., f_high=...)``),
    which a plain annotated class does not accept.
    """

    fs: float  # sampling rate in Hz
    f_low: float  # lower bandpass edge in Hz
    f_high: float  # upper bandpass edge in Hz
| # ============================================================ | |
| # Helpers | |
| # ============================================================ | |
def ensure_tc(x: np.ndarray) -> np.ndarray:
    """Coerce *x* into a 2-D ``(T, C)`` (time x channel) array.

    * 1-D input becomes a single column ``(T, 1)``.
    * 2-D input is transposed when it looks channel-major: at most 256 rows
      and more columns than rows.
    * Anything else raises ``ValueError``.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        return arr[:, None]
    if arr.ndim != 2:
        raise ValueError(f"2次元配列のみ対応です: shape={arr.shape}")
    n_rows, n_cols = arr.shape
    # Heuristic: few rows and many columns is taken as (C, T) layout.
    looks_channel_major = n_rows <= 256 and n_cols > n_rows
    return arr.T if looks_channel_major else arr
| def _quad_bezier_points(p0, p1, c, n=20): | |
| """2次Bezierを点列にして返す (n点)""" | |
| ts = np.linspace(0, 1, n) | |
| pts = (1-ts)[:,None]**2 * p0 + 2*(1-ts)[:,None]*ts[:,None]*c + ts[:,None]**2 * p1 | |
| return pts # shape (n,2) | |
| def _quad_bezier_point_and_tangent(p0, p1, c, t): | |
| """2次Bezierの点と接線ベクトル(微分)を返す""" | |
| # B(t) = (1-t)^2 p0 + 2(1-t)t c + t^2 p1 | |
| pt = (1-t)**2 * p0 + 2*(1-t)*t * c + t**2 * p1 | |
| # B'(t) = 2(1-t)(c-p0) + 2t(p1-c) | |
| tan = 2*(1-t)*(c-p0) + 2*t*(p1-c) | |
| return pt, tan | |
| # ============================================================ | |
| # Signal processing | |
| # ============================================================ | |
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Bandpass-filter every channel with MNE.

    The (T, C) input is wrapped in an MNE ``RawArray`` (which stores (C, T)),
    channels named ``ch0..``, typed ``eeg``, at ``cfg.fs`` Hz.  It is filtered
    between ``cfg.f_low`` and ``cfg.f_high`` Hz, and returned as a float32
    (T, C) array.
    """
    n_channels = x_tc.shape[1]
    info = mne.create_info(
        ch_names=[f"ch{k}" for k in range(n_channels)],
        sfreq=float(cfg.fs),
        ch_types="eeg",
    )
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)  # MNE expects (C, T)
    filtered = raw.copy().filter(
        l_freq=cfg.f_low,
        h_freq=cfg.f_high,
        verbose=False,
    )
    data_ct = filtered.get_data()
    return data_ct.T.astype(np.float32)
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the per-channel Hilbert envelope |analytic signal| of a (T, C) array as float32."""
    return np.abs(hilbert(x_tc, axis=0)).astype(np.float32)
def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
    """Return the per-channel Hilbert phase (angle in radians, -pi..pi) of a (T, C) array as float32."""
    return np.angle(hilbert(x_tc, axis=0)).astype(np.float32)
def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
    """Run the full preprocessing chain on a raw recording.

    The input is coerced to float32 (T, C) via ``ensure_tc``, bandpass-filtered,
    and Hilbert-transformed.  The result dict holds:
    ``fs``, ``raw``, ``filtered``, ``envelope``, ``amplitude`` (the same
    array object as ``envelope``) and ``phase``.
    """
    raw_tc = ensure_tc(x_tc).astype(np.float32)
    filtered_tc = bandpass_tc(raw_tc, cfg)
    envelope_tc = hilbert_envelope_tc(filtered_tc)
    phase_tc = hilbert_phase_tc(filtered_tc)

    result = {"fs": float(cfg.fs), "raw": raw_tc, "filtered": filtered_tc}
    result["envelope"] = envelope_tc
    result["amplitude"] = envelope_tc  # alias of "envelope"
    result["phase"] = phase_tc
    return result
def preprocess_all_eeglab(
    set_bytes: bytes,
    fdt_bytes: bytes,
    set_name: str,
    fdt_name: str,
    f_low: float,
    f_high: float,
) -> dict:
    """Load an EEGLAB .set/.fdt pair from bytes and preprocess it.

    The sampling rate comes from the loaded file, not from the UI.  The
    returned dict is the one from ``preprocess_tc``.  It additionally carries
    ``electrode_pos`` (2-D) and/or ``electrode_pos_3d`` when the loader
    provided them (i.e. they are not None).
    """
    x_tc, fs, pos_2d, pos_3d = load_eeglab_tc_from_bytes(
        set_bytes=set_bytes,
        set_name=set_name,
        fdt_bytes=fdt_bytes,
        fdt_name=fdt_name,
    )
    cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
    result = preprocess_tc(x_tc, cfg)

    # Attach electrode positions only when the loader found them.
    optional_positions = {"electrode_pos": pos_2d, "electrode_pos_3d": pos_3d}
    for field_name, positions in optional_positions.items():
        if positions is not None:
            result[field_name] = positions
    return result
@st.cache_data(show_spinner=False)
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
    """Return the numeric-array candidates of a .mat file, memoized on its bytes.

    Streamlit re-executes the script on every widget interaction.  Without
    memoization, the uploaded .mat would be re-parsed each time, which is
    what this wrapper's name promises to avoid.  ``st.cache_data`` keys on
    ``mat_bytes`` and returns a copy of the cached value to each caller.

    NOTE(review): st.cache_data pickles the result.  This assumes
    ``load_mat_candidates`` returns a dict of picklable values (presumably
    name -> numpy array) -- confirm in loader.py.
    """
    return load_mat_candidates(mat_bytes)
| # ============================================================ | |
| # Viewer | |
| # ============================================================ | |
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
    """Return rows ``start_idx:end_idx`` of *X_tc*, taking every ``decim``-th row.

    The start is clamped to ``[0, T-1]``.  The stop is clamped so that at least
    one row is returned (for T >= 1).  ``decim`` below 1 is treated as 1.
    """
    last_row = X_tc.shape[0] - 1
    first = max(0, min(start_idx, last_row))
    stop = max(first + 1, min(end_idx, last_row + 1))
    step = max(1, int(decim))
    return X_tc[first:stop:step, :]
def make_timeseries_figure(
    X_tc: np.ndarray,
    selected_channels: List[int],
    fs: float,
    start_sec: float,
    win_sec: float,
    decim: int,
    offset_mode: bool,
    show_rangeslider: bool,
    signal_type: str = "filtered",
) -> go.Figure:
    """Plot the selected channels of a (T, C) signal over one time window.

    Args:
        X_tc: signal indexed as ``X_tc[sample, channel]``.
        selected_channels: column indices to draw.  An empty list yields an
            empty figure titled "Timeseries (no channel selected)".
        fs: sampling rate in Hz, used for seconds <-> samples conversion.
        start_sec: window start in seconds.
        win_sec: window length in seconds.
        decim: keep every ``decim``-th sample (values < 1 act as 1).
        offset_mode: when True and more than one non-phase channel is shown,
            stack traces ``5 x median(per-channel std)`` apart.
        show_rangeslider: add a Plotly range slider under the x axis.
        signal_type: shown in the title.  ``"phase"`` also switches the y label
            and fixes the y range to roughly [-pi, pi].

    Returns:
        The Plotly figure.
    """
    n_samples = X_tc.shape[0]
    start_idx = int(round(start_sec * fs))
    end_idx = int(round((start_sec + win_sec) * fs))
    Xw = window_slice(X_tc, start_idx, end_idx, decim)

    # window_slice clamps the start index and the decimation step.  Use the
    # same clamped values so the time axis matches the samples actually drawn.
    first_idx = max(0, min(start_idx, n_samples - 1))
    step = max(1, int(decim))
    t = (np.arange(Xw.shape[0]) * step + first_idx) / fs

    fig = go.Figure()
    if not selected_channels:
        fig.update_layout(
            title="Timeseries (no channel selected)",
            height=450,
            xaxis_title="time (s)",
            yaxis_title="amplitude",
        )
        return fig

    # Phase traces are never offset and get a fixed y range below.
    is_phase = signal_type == "phase"
    if offset_mode and len(selected_channels) > 1 and not is_phase:
        # Space traces by 5x the median channel std so they don't overlap.
        # Fall back to 1.0 when that spread is zero or non-finite.
        med_std = float(np.median(np.std(Xw[:, selected_channels], axis=0)))
        base = med_std if np.isfinite(med_std) and med_std > 0 else 1.0
        offset = 5.0 * base
        for k, ch in enumerate(selected_channels):
            y = Xw[:, ch] + k * offset
            fig.add_trace(go.Scatter(x=t, y=y, mode="lines", name=f"ch{ch}", line=dict(width=1)))
        ylab = "amplitude (offset)"
    else:
        for ch in selected_channels:
            fig.add_trace(go.Scatter(x=t, y=Xw[:, ch], mode="lines", name=f"ch{ch}", line=dict(width=1)))
        ylab = "phase (rad)" if is_phase else "amplitude"

    # Leave room below the plot for the range slider when it is shown.
    plot_height = 550 if show_rangeslider else 450
    bottom_margin = 150 if show_rangeslider else 80
    title_text = f"Timeseries: {signal_type} (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})"
    fig.update_layout(
        title=title_text,
        height=plot_height,
        xaxis_title="time (s)",
        yaxis_title=ylab,
        legend=dict(orientation="h"),
        margin=dict(l=60, r=20, t=80, b=bottom_margin),
    )
    # Phase lives in [-pi, pi]; pin the axis with a small margin.
    if is_phase:
        fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])

    if show_rangeslider:
        fig.update_xaxes(
            rangeslider=dict(
                visible=True,
                thickness=0.05,
            )
        )
    else:
        fig.update_xaxes(rangeslider=dict(visible=False))
    return fig
| # ============================================================ | |
| # Network (multiple methods) + export | |
| # ============================================================ | |
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
    """Undirected network from |Pearson correlation| between channel envelopes.

    Input:  X_tc (T, C) - envelope (amplitude) data
    Output: W (C, C) float32 - absolute correlation coefficients, diagonal 0.
            NaN/inf (e.g. from constant channels) become 0.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    # Constant channels make corrcoef divide by zero.  Those NaNs are zeroed
    # below, so silence the floating-point warnings they would raise.
    with np.errstate(invalid="ignore", divide="ignore"):
        # For a single channel np.corrcoef returns a 0-d value; atleast_2d keeps
        # the result (C, C) so fill_diagonal below does not fail.
        corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    W = np.abs(corr)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
    """Undirected network from pairwise circular correlation of channel phases.

    Input:  X_tc (T, C) - phase data (radians)
    Output: W (C, C) float32, symmetric, diagonal 0.  W[i, j] is
            ``metrics.circular_correlation(phase_i, phase_j)`` (presumably the
            Jammalamadaka-Sengupta coefficient -- see metrics.py).  No absolute
            value is taken, so negative values are kept.  NaN/inf become 0.
    """
    _, n_ch = X_tc.shape
    W = np.zeros((n_ch, n_ch), dtype=np.float32)
    upper_pairs = ((a, b) for a in range(n_ch) for b in range(a + 1, n_ch))
    for a, b in upper_pairs:
        # Jammalamadaka-Sengupta circular correlation, mirrored to both halves.
        W[a, b] = W[b, a] = metrics.circular_correlation(X_tc[:, a], X_tc[:, b])
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_phase_PLV(X_tc: np.ndarray, progress) -> np.ndarray:
    """Undirected network of phase-locking values between channel pairs.

    Input:  X_tc (T, C) - phase data (radians)
            progress - object with a ``progress(fraction)`` method.  It is
            called once per unordered pair with the completed fraction.
    Output: W (C, C) float32, symmetric, diagonal 0, where
            r_ij = |<exp(i*(theta_i - theta_j))>_t|
    """
    _, n_ch = X_tc.shape
    plv = np.zeros((n_ch, n_ch), dtype=np.float32)
    n_pairs = n_ch * (n_ch - 1) // 2
    done = 0
    for a in range(n_ch):
        for b in range(a + 1, n_ch):
            value = np.abs(np.exp(1j * (X_tc[:, a] - X_tc[:, b])).mean())
            plv[a, b] = plv[b, a] = value
            done += 1
            progress.progress(done / n_pairs)
    return np.nan_to_num(plv, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_pac_tort(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling network using Tort's Modulation Index.

    Input:  X_tc1 (T, C) - phase data (radians)
            X_tc2 (T, C) - envelope data
            progress - object with a ``progress(fraction)`` method.  It is
            called after each of the C*(C-1) ordered pairs, ending at 1.0.
    Output: W (C, C) float32.  W[i, j] is
            ``metrics.modulation_index(phase of channel i, envelope of
            channel j)`` (Tort et al., 2010).  The diagonal stays 0 and
            NaN/inf become 0.
    Raises AssertionError if the two inputs differ in shape.
    """
    assert X_tc1.shape == X_tc2.shape
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)
    # Only ordered off-diagonal pairs are evaluated, so that is the progress
    # total.  Using C*C would leave the bar stuck at (C-1)/C.
    n_pairs = C * (C - 1)
    done = 0
    for i in range(C):
        for j in range(C):
            if i == j:
                continue
            # Modulation Index from Tort et al. (2010)
            W[i, j] = metrics.modulation_index(X_tc1[:, i], X_tc2[:, j])
            done += 1
            progress.progress(done / n_pairs)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_pac_chatterjee(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling network using Chatterjee's correlation.

    Input:  X_tc1 (T, C) - phase data (radians)
            X_tc2 (T, C) - envelope data
            progress - object with a ``progress(fraction)`` method.  It is
            called after each of the C*(C-1) ordered pairs, ending at 1.0.
    Output: W (C, C) float32.  W[i, j] is
            ``metrics.chatterjee_phase_to_amp(phase of channel i, envelope of
            channel j)``.  The diagonal stays 0 and NaN/inf become 0.
    Raises AssertionError if the two inputs differ in shape.
    """
    assert X_tc1.shape == X_tc2.shape
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)
    # Only ordered off-diagonal pairs are evaluated, so that is the progress
    # total.  Using C*C would leave the bar stuck at (C-1)/C.
    n_pairs = C * (C - 1)
    done = 0
    for i in range(C):
        for j in range(C):
            if i == j:
                continue
            # Chatterjee correlation from phase (channel i) to amplitude (channel j)
            W[i, j] = metrics.chatterjee_phase_to_amp(X_tc1[:, i], X_tc2[:, j])
            done += 1
            progress.progress(done / n_pairs)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
    """Placeholder estimator: |Pearson correlation| between channels.

    Kept for backward compatibility.  Input X_tc (T, C).  Output W (C, C)
    float32 with diagonal 0; NaN/inf (e.g. constant channels) become 0.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    # Constant channels produce NaN correlations that are zeroed below.
    with np.errstate(invalid="ignore", divide="ignore"):
        # A single channel makes np.corrcoef return a 0-d value; keep it (1, 1).
        corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    W = np.abs(corr)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def threshold_edges(
    W: np.ndarray,
    thr: float,
) -> List[Tuple[int, int, float]]:
    """Extract the edges of weight >= ``thr`` from a weight matrix.

    * Symmetric ``W`` (exactly, atol=1e-12) is treated as undirected: only the
      strict upper triangle (i < j) is considered.
    * Otherwise it is treated as directed: every off-diagonal i -> j is
      considered.

    Returns ``(i, j, w)`` tuples (Python int, int, float) sorted by weight,
    largest first.  Ties keep row-major (i, then j) order.
    """
    n = W.shape[0]
    if np.allclose(W, W.T, atol=1e-12, rtol=0):
        candidates = np.triu(np.ones((n, n), dtype=bool), k=1)
    else:
        candidates = ~np.eye(n, dtype=bool)
    # Compare in float64 so the test matches float(W[i, j]) >= thr exactly.
    passing = candidates & (np.asarray(W, dtype=np.float64) >= thr)
    rows, cols = np.nonzero(passing)  # row-major order, like nested loops
    edges: List[Tuple[int, int, float]] = [
        (int(i), int(j), float(W[i, j])) for i, j in zip(rows, cols)
    ]
    edges.sort(key=lambda edge: edge[2], reverse=True)  # stable sort
    return edges
def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
    """Adjacency matrix of ``W`` at threshold ``thr`` with a zero diagonal.

    Binary (int 0/1, entry = W >= thr) when ``weighted`` is False.  Otherwise
    a copy of ``W`` with entries below ``thr`` set to 0.
    """
    if not weighted:
        binary = (W >= thr).astype(int)
        np.fill_diagonal(binary, 0)
        return binary
    kept = W.copy()
    kept[kept < thr] = 0.0
    np.fill_diagonal(kept, 0.0)
    return kept
def compute_louvain_clusters(W: np.ndarray, thr: float) -> np.ndarray:
    """Cluster channels with the Louvain method.

    Args:
        W: weight matrix (C, C).  It may be asymmetric; a pair (i, j) is
           connected when either direction reaches ``thr``, using the larger
           of W[i, j] and W[j, i] as the undirected edge weight.
        thr: edges whose weights in both directions are below this are dropped.

    Returns:
        clusters: cluster-id array (C,) from
        ``community_louvain.best_partition``, or all zeros when python-louvain
        is not installed.
        NOTE(review): best_partition may randomize node order, so ids can
        differ between calls -- confirm with the installed version.
    """
    C = W.shape[0]
    if not LOUVAIN_AVAILABLE:
        # Without Louvain every node goes into one cluster.
        return np.zeros(C, dtype=int)

    G = nx.Graph()
    G.add_nodes_from(range(C))
    # Visit each unordered pair once and never add self-loops.  Previously a
    # diagonal >= thr (e.g. thr <= 0) created self-loop edges.
    for i in range(C):
        for j in range(i + 1, C):
            if W[i, j] >= thr or W[j, i] >= thr:
                G.add_edge(i, j, weight=max(W[i, j], W[j, i]))

    partition = community_louvain.best_partition(G, weight='weight')
    return np.array([partition[i] for i in range(C)])
def get_cluster_colors(clusters: np.ndarray) -> List[str]:
    """Map cluster ids to evenly spaced HSV hues.

    Each distinct id is first replaced by its rank among the sorted unique ids,
    so hue = rank / n_clusters stays in [0, 1).  Using the raw id made
    non-contiguous ids collide: with ids {0, 2}, hue 2/2 = 1.0 wraps back to
    the same red as hue 0.  For contiguous ids 0..n-1 the colours are
    unchanged.

    Args:
        clusters: cluster-id array (C,)
    Returns:
        colors: one ``'rgb(r, g, b)'`` string per entry of ``clusters``
    """
    import colorsys

    unique_ids, ranks = np.unique(np.asarray(clusters), return_inverse=True)
    n_clusters = max(len(unique_ids), 1)
    colors = []
    for rank in np.ravel(ranks):
        hue = int(rank) / n_clusters
        r, g, b = colorsys.hsv_to_rgb(hue, 0.8, 0.95)
        colors.append(f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})')
    return colors
def get_electrode_positions(prep: dict) -> np.ndarray:
    """Return 2-D electrode coordinates for plotting.

    Uses ``prep["electrode_pos"]`` when that key exists, returning it as-is.
    Otherwise it places the ``prep["raw"].shape[1]`` channels evenly on the
    unit circle and returns a (C, 2) array of (x, y).
    """
    if "electrode_pos" in prep:
        return prep["electrode_pos"]

    n_ch = prep["raw"].shape[1]
    theta = np.linspace(0, 2 * np.pi, n_ch, endpoint=False)
    return np.stack([np.cos(theta), np.sin(theta)], axis=1)
def make_network_figure_3d(
    W: np.ndarray,
    thr: float,
    electrode_pos_3d: np.ndarray,
    use_louvain: bool = True,
) -> go.Figure:
    """Build a rotatable (drag-to-rotate) 3-D network plot.

    Edges are ``threshold_edges(W, thr)``.  Each edge is coloured from blue
    (weakest kept edge) to red (strongest), with width 1-5.  Nodes sit at
    ``electrode_pos_3d[:, 0:3]``.  They are coloured by Louvain cluster when
    ``use_louvain`` is set and python-louvain is available, otherwise gold.
    """
    import colorsys

    n_nodes = W.shape[0]
    node_x, node_y, node_z = (electrode_pos_3d[:, axis] for axis in range(3))
    edge_list = threshold_edges(W, thr)

    # Weight range of the kept edges, used to normalize colour and width.
    if edge_list:
        edge_weights = [weight for _, _, weight in edge_list]
        w_lo, w_hi = min(edge_weights), max(edge_weights)
        span = w_hi - w_lo if w_hi > w_lo else 1.0
    else:
        w_lo, w_hi, span = 0, 1, 1.0

    def rainbow(level):
        # Hue 0.67 (~240 deg, blue) at level 0 down to 0 (red) at level 1.
        hue = (1.0 - level) * 0.67
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        return f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})'

    fig = go.Figure()
    for src, dst, weight in edge_list:
        level = (weight - w_lo) / span if span > 0 else 0.5
        fig.add_trace(go.Scatter3d(
            x=[node_x[src], node_x[dst], None],
            y=[node_y[src], node_y[dst], None],
            z=[node_z[src], node_z[dst], None],
            mode='lines',
            line=dict(color=rainbow(level), width=1 + 4 * level),
            hoverinfo='skip',
            showlegend=False,
        ))

    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        title_suffix = f" | Louvain clusters: {len(np.unique(clusters))}"
    else:
        clusters = np.zeros(n_nodes, dtype=int)
        node_colors = ['#FFD700'] * n_nodes
        title_suffix = ""

    labels = [str(k) for k in range(n_nodes)]
    hover = [f"channel {k}<br>cluster: {clusters[k]}" for k in range(n_nodes)]
    fig.add_trace(go.Scatter3d(
        x=node_x,
        y=node_y,
        z=node_z,
        mode='markers+text',
        text=labels,
        textposition='top center',
        textfont=dict(size=8),
        marker=dict(
            size=8,
            color=node_colors,
            line=dict(color='white', width=1),
        ),
        hoverinfo='text',
        hovertext=hover,
        showlegend=False,
    ))

    hidden_axis = dict(visible=False)
    fig.update_layout(
        title=f"3D Network (thr={thr:.3f}) edges={len(edge_list)}{title_suffix}",
        height=700,
        scene=dict(
            xaxis=hidden_axis,
            yaxis=hidden_axis,
            zaxis=hidden_axis,
            bgcolor='rgba(0,0,0,0.9)',
        ),
        paper_bgcolor='rgba(0,0,0,0.9)',
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig
def make_network_figure(
    W: np.ndarray,
    thr: float,
    use_louvain: bool = True,
    electrode_pos: np.ndarray = None,
) -> tuple[go.Figure, int]:
    """Build the 2-D network plot and return ``(figure, number_of_edges)``.

    Nodes use ``electrode_pos[:, :2]`` when it has one row per channel, and
    fall back to a unit-circle layout otherwise.  Edges are
    ``threshold_edges(W, thr)``, coloured blue -> red and 0.5-4 wide by
    normalized weight.  For an asymmetric ``W`` each directed edge is drawn as
    a slightly bent quadratic Bezier curve with a triangle arrowhead near its
    target.  Nodes are coloured by Louvain cluster when requested and
    available, otherwise gold.
    """
    import colorsys

    n_nodes = W.shape[0]
    have_layout = electrode_pos is not None and electrode_pos.shape[0] == n_nodes
    if have_layout:
        node_x = electrode_pos[:, 0]
        node_y = electrode_pos[:, 1]
    else:
        theta = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
        node_x, node_y = np.cos(theta), np.sin(theta)

    edge_list = threshold_edges(W, thr)

    # Weight range of the kept edges, used to scale colour and width.
    if edge_list:
        edge_weights = [weight for _, _, weight in edge_list]
        w_lo, w_hi = min(edge_weights), max(edge_weights)
        span = w_hi - w_lo if w_hi > w_lo else 1.0
    else:
        w_lo, w_hi, span = 0, 1, 1.0

    def rainbow(level):
        # Hue 0.67 (~240 deg, blue) at level 0 down to 0 (red) at level 1.
        hue = (1.0 - level) * 0.67
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'

    def edge_style(weight):
        level = (weight - w_lo) / span if span > 0 else 0.5
        return rainbow(level), 0.5 + 3.5 * level

    fig = go.Figure()
    directed = not np.allclose(W, W.T, atol=1e-12, rtol=0)
    if directed:
        bend = 0.1        # curve bulge relative to edge length
        shrink = 0.08     # pull endpoints back from node centres by this much
        n_samples = 18    # points per Bezier curve
        arrow_t = 0.90    # curve parameter where the arrowhead sits
        for src, dst, weight in edge_list:
            color, width = edge_style(weight)
            start = np.array([node_x[src], node_y[src]], dtype=float)
            end = np.array([node_x[dst], node_y[dst]], dtype=float)
            delta = end - start
            length = np.hypot(delta[0], delta[1])
            if length < 1e-9:
                continue  # coincident nodes: no direction to draw
            unit = delta / length
            a = start + unit * shrink
            b = end - unit * shrink
            # Always bend to the left of the travel direction, so i->j and j->i
            # land on opposite sides and stay distinguishable.
            normal = np.array([-unit[1], unit[0]])
            ctrl = 0.5 * (a + b) + bend * length * normal

            curve = _quad_bezier_points(a, b, ctrl, n=n_samples)
            fig.add_trace(go.Scatter(
                x=curve[:, 0],
                y=curve[:, 1],
                mode="lines",
                hoverinfo="text",
                hovertext=f"ch{src} → ch{dst}<br>weight: {weight:.4f}",
                line=dict(width=width, color=color),
                showlegend=False,
            ))

            # Orient the arrowhead along the curve tangent.  Fall back to the
            # chord direction if the tangent is degenerate.
            tip, tangent = _quad_bezier_point_and_tangent(a, b, ctrl, arrow_t)
            tx, ty = float(tangent[0]), float(tangent[1])
            if tx * tx + ty * ty < 1e-18:
                tx, ty = float(b[0] - a[0]), float(b[1] - a[1])
            heading = np.degrees(np.arctan2(ty, tx))  # CCW from +x
            rotation = (heading - 90.0) % 360  # triangle-up points along +y
            fig.add_trace(go.Scatter(
                x=[tip[0]],
                y=[tip[1]],
                mode="markers",
                hoverinfo="skip",
                marker=dict(
                    symbol="triangle-up",
                    size=10,
                    angle=-rotation,  # Plotly marker angles rotate clockwise
                    angleref="up",
                    color=color,
                    line=dict(width=0),
                ),
                showlegend=False,
            ))
    else:
        for src, dst, weight in edge_list:
            color, width = edge_style(weight)
            fig.add_trace(
                go.Scatter(
                    x=[node_x[src], node_x[dst]],
                    y=[node_y[src], node_y[dst]],
                    mode="lines",
                    hoverinfo="text",
                    hovertext=f"ch{src} - ch{dst}<br>weight: {weight:.4f}",
                    line=dict(width=width, color=color),
                    showlegend=False,
                )
            )

    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        title_suffix = f" | Louvain clusters: {len(np.unique(clusters))}"
    else:
        clusters = np.zeros(n_nodes, dtype=int)
        node_colors = ['#FFD700'] * n_nodes  # gold
        title_suffix = ""

    fig.add_trace(
        go.Scatter(
            x=node_x,
            y=node_y,
            mode="markers+text",
            text=[str(k) for k in range(n_nodes)],
            textposition="bottom center",
            textfont=dict(size=8),
            marker=dict(
                size=14,
                color=node_colors,
                line=dict(width=2, color='white')
            ),
            hoverinfo="text",
            hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(n_nodes)],
            showlegend=False,
        )
    )

    dark = 'rgba(0,0,0,0.9)'
    fig.update_layout(
        title=f"Estimated Network (thr={thr:.3f}) edges={len(edge_list)}{title_suffix}",
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=10, r=10, t=50, b=50),
        paper_bgcolor=dark,
        plot_bgcolor=dark,
    )
    # Equal x/y scaling keeps the layout and arrow angles undistorted.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Legend-like caption explaining the edge colour/width scale.
    if edge_list:
        fig.add_annotation(
            text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {w_lo:.3f} - {w_hi:.3f}",
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=10, color='white'),
            xanchor='center',
        )
    return fig, len(edge_list)
def make_edgecount_curve(W: np.ndarray) -> go.Figure:
    """Plot how many edges ``threshold_edges`` keeps as the threshold varies.

    The 120 thresholds run from the largest to the smallest finite candidate
    edge weight.  The candidate set mirrors ``threshold_edges``: the strict
    upper triangle when ``W`` is symmetric (undirected), and every off-diagonal
    entry otherwise (directed).  Previously only the upper triangle was used,
    so directed weights below the diagonal could fall outside the grid.
    NaN/inf entries are ignored.  With no candidates the grid is just [0.0].
    """
    n = W.shape[0]
    if np.allclose(W, W.T, atol=1e-12, rtol=0):
        vals = W[np.triu_indices(n, k=1)]
    else:
        vals = W[~np.eye(n, dtype=bool)]
    vals = vals[np.isfinite(vals)]
    if vals.size:
        thr_grid = np.linspace(float(vals.max()), float(vals.min()), 120)
    else:
        thr_grid = np.array([0.0])
    counts = [len(threshold_edges(W, float(thr))) for thr in thr_grid]

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thr_grid, y=counts, mode="lines"))
    fig.update_layout(
        title="Edge count vs threshold (lower thr => more edges)",
        xaxis_title="threshold",
        yaxis_title="edge count",
        height=300,
    )
    return fig
def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
    """Serialize a matrix as comma-separated UTF-8 text using ``np.savetxt`` format ``fmt``."""
    with io.StringIO() as buf:
        np.savetxt(buf, mat, delimiter=",", fmt=fmt)
        text = buf.getvalue()
    return text.encode("utf-8")
def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
    """Serialize an edge list as UTF-8 CSV with a ``source,target,weight`` header.

    Weights are written with 6 decimals; every row (header included) ends in
    a newline.
    """
    rows = ["source,target,weight"]
    rows.extend(f"{i},{j},{w:.6f}" for i, j, w in edges)
    return ("\n".join(rows) + "\n").encode("utf-8")
| # ============================================================ | |
| # Sidebar UI | |
| # ============================================================ | |
# --- Input format ---------------------------------------------------------
st.sidebar.header("Input format")
input_mode = st.sidebar.radio(
    "データ形式",
    ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"],
    index=0,
)

# --- Source band (also used for the viewer's "filtered" signal) ----------
st.sidebar.header("Preprocess (auto)")
f_low_src = st.sidebar.number_input(
    "Bandpass low (Hz)", min_value=0.0, value=4.0, step=1.0, key="low_src"
)
f_high_src = st.sidebar.number_input(
    "Bandpass high (Hz)", min_value=0.1, value=8.0, step=1.0, key="high_src"
)

# --- Target band, used for cross-frequency coupling (PAC) -----------------
st.sidebar.header("if you use CFC+PAC:")
f_low_tgt = st.sidebar.number_input(
    "Bandpass low (Hz)", min_value=0.0, value=25.0, step=1.0, key="low_tgt"
)
f_high_tgt = st.sidebar.number_input(
    "Bandpass high (Hz)", min_value=0.1, value=40.0, step=1.0, key="high_tgt"
)

# --- Time-series viewer options --------------------------------------------
st.sidebar.header("Viewer controls")
win_sec = st.sidebar.number_input(
    "Window length (sec)", min_value=0.1, value=5.0, step=0.1
)
decim = st.sidebar.selectbox(
    "Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1
)
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)
signal_view = st.sidebar.radio(
    "表示する信号",
    ["raw", "filtered", "amplitude", "phase"],
    index=1,
    help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相",
)

st.title("EEG timeseries viewer + network estimation")
# ============================================================
# Load + preprocess (EEGLAB / MAT)
# ============================================================
# Streamlit re-executes this whole script on every widget interaction.  The
# uploaded data and band settings are fingerprinted so that two things happen
# only when those inputs actually change:
#   * the expensive load + bandpass + Hilbert step (run twice, once per band);
#   * the reset of any estimated network in st.session_state["W"].


def _reset_loaded_data() -> None:
    """Forget the cached preprocessing results and any estimated network."""
    st.session_state.pop("prep", None)
    st.session_state.pop("prep_tgt", None)
    st.session_state.pop("_data_signature", None)
    st.session_state["W"] = None


def _data_signature(*parts) -> str:
    """Digest identifying the loaded data plus preprocessing settings.

    ``bytes`` parts are hashed by content and everything else by ``repr``.
    Each part is length-prefixed so adjacent parts cannot run together.
    """
    import hashlib

    digest = hashlib.blake2b(digest_size=16)
    for part in parts:
        data = part if isinstance(part, (bytes, bytearray)) else repr(part).encode("utf-8")
        digest.update(len(data).to_bytes(8, "little"))
        digest.update(data)
    return digest.hexdigest()


def _needs_preprocess(signature: str) -> bool:
    """True when nothing is loaded yet or the inputs changed since the last run."""
    return (
        "prep" not in st.session_state
        or st.session_state.get("_data_signature") != signature
    )


def _show_mat_debug_info(mat_bytes: bytes) -> None:
    """Best-effort listing of the variables inside an uploaded .mat file.

    The file is tried first as a classic MAT file (scipy.io.loadmat), then as
    HDF5 via h5py.  Failures are written to the page and never raised.
    """
    try:
        from scipy.io import loadmat

        mat_data = loadmat(io.BytesIO(mat_bytes))
        st.write("**MATファイルに含まれる変数:**")
        for k, v in mat_data.items():
            if k.startswith('__'):
                continue  # loadmat metadata entries (e.g. __header__)
            if isinstance(v, np.ndarray):
                st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
            else:
                st.write(f"- `{k}`: type={type(v).__name__}")
    except Exception as e:
        st.write(f"デバッグ情報の取得に失敗: {e}")

    # h5py reads from an in-memory file object, so no temporary file is
    # created (and none can be left behind when h5py.File raises).
    try:
        import h5py

        st.write("**HDF5形式として読み込み中...**")
        with h5py.File(io.BytesIO(mat_bytes), 'r') as f:
            def show_structure(name, obj):
                if isinstance(obj, h5py.Dataset):
                    st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")
            f.visititems(show_structure)
    except Exception as e2:
        st.write(f"HDF5としても読み込めませんでした: {e2}")


if input_mode.startswith("EEGLAB"):
    st.sidebar.header("Upload (.set + .fdt)")
    uploaded_files = st.sidebar.file_uploader(
        "Upload EEGLAB files",
        type=["set", "fdt"],
        accept_multiple_files=True,
    )
    if uploaded_files:
        set_file, fdt_file = pick_set_fdt(uploaded_files)
        if set_file is None or fdt_file is None:
            st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
        else:
            set_bytes = set_file.getvalue()
            fdt_bytes = fdt_file.getvalue()
            signature = _data_signature(
                "eeglab", set_file.name, fdt_file.name, set_bytes, fdt_bytes,
                float(f_low_src), float(f_high_src), float(f_low_tgt), float(f_high_tgt),
            )
            try:
                if _needs_preprocess(signature):
                    with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
                        prep_src = preprocess_all_eeglab(
                            set_bytes=set_bytes,
                            fdt_bytes=fdt_bytes,
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_src),
                            f_high=float(f_high_src),
                        )
                        prep_tgt = preprocess_all_eeglab(
                            set_bytes=set_bytes,
                            fdt_bytes=fdt_bytes,
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_tgt),
                            f_high=float(f_high_tgt),
                        )
                    st.session_state["prep"] = prep_src
                    st.session_state["prep_tgt"] = prep_tgt
                    st.session_state["W"] = None  # new data invalidates the network
                    st.session_state["_data_signature"] = signature
                prep_src = st.session_state["prep"]
                prep_tgt = st.session_state.get("prep_tgt")
                st.success(f"Loaded & preprocessed. (T,C)={prep_src['raw'].shape} fs={prep_src['fs']:.2f}Hz")
            except Exception as e:
                _reset_loaded_data()
                st.error(f"読み込み/前処理エラー: {e}")
else:
    st.sidebar.header("Upload (.mat)")
    mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])
    if mat_file is not None:
        mat_bytes = mat_file.getvalue()
        try:
            cands = load_mat_candidates_cached(mat_bytes)
            if not cands:
                st.error("数値の1D/2D配列が見つかりませんでした。")
                st.info("MATファイルの構造を確認しています...")
                _show_mat_debug_info(mat_bytes)
            else:
                key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
                fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)
                # Preprocess automatically once a variable is selected.
                if key:
                    x = cands[key]
                    st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
                    signature = _data_signature(
                        "mat", mat_bytes, key, float(fs_mat),
                        float(f_low_src), float(f_high_src), float(f_low_tgt), float(f_high_tgt),
                    )
                    try:
                        cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_src), f_high=float(f_high_src))
                        cfg_tgt = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_tgt), f_high=float(f_high_tgt))
                        if _needs_preprocess(signature):
                            with st.spinner("Preprocessing (bandpass + hilbert)..."):
                                prep = preprocess_tc(x, cfg)
                                prep_tgt = preprocess_tc(x, cfg_tgt)
                            st.session_state["prep"] = prep
                            st.session_state["prep_tgt"] = prep_tgt
                            st.session_state["W"] = None  # new data invalidates the network
                            st.session_state["_data_signature"] = signature
                        prep = st.session_state["prep"]
                        prep_tgt = st.session_state.get("prep_tgt")
                        st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
                    except Exception as e:
                        _reset_loaded_data()
                        st.error(f"前処理エラー: {e}")
                        import traceback
                        st.code(traceback.format_exc())
        except Exception as e:
            _reset_loaded_data()
            st.error(f".mat 読み込みエラー: {e}")
            import traceback
            st.code(traceback.format_exc())
# Nothing below can run until a recording has been loaded and preprocessed.
if "prep" not in st.session_state:
    st.info("左のサイドバーからデータをアップロードしてください。")
    st.stop()
# ============================================================
# Viewer
# ============================================================
prep = st.session_state["prep"]
fs = float(prep["fs"])
X_tc = prep[signal_view]  # signal_view is one of the keys produced by preprocess_tc
T, C = X_tc.shape
duration_sec = (T - 1) / fs if T > 1 else 0.0
max_start = max(0.0, float(duration_sec - win_sec))
if max_start > 0.0:
    start_sec = st.sidebar.slider(
        "Start time (sec)",
        min_value=0.0,
        max_value=float(max_start),
        value=0.0,
        step=float(max(0.01, win_sec / 200)),
    )
else:
    # The whole recording fits in one window, so there is nothing to scroll.
    # Skip the slider rather than build one with min_value == max_value.
    start_sec = 0.0
st.sidebar.header("Channels")

# Convenience buttons: select all / clear all.
col_ch1, col_ch2 = st.sidebar.columns(2)
with col_ch1:
    select_all = st.button("全選択")
with col_ch2:
    deselect_all = st.button("全解除")

# Select a contiguous channel range.  The ends are ordered, so a reversed
# range still selects the channels in between.
with st.sidebar.expander("📊 範囲で選択"):
    range_start = st.number_input("開始ch", min_value=0, max_value=C-1, value=0, step=1)
    range_end = st.number_input("終了ch", min_value=0, max_value=C-1, value=min(C-1, 7), step=1)
    if st.button("範囲を選択"):
        lo_ch, hi_ch = sorted((int(range_start), int(range_end)))
        st.session_state["selected_channels"] = list(range(lo_ch, hi_ch + 1))

# Fixed 16-channel presets, truncated to the available channel count.
with st.sidebar.expander("⚡ プリセット"):
    preset_col1, preset_col2 = st.columns(2)
    with preset_col1:
        if st.button("前頭部 (0-15)"):
            st.session_state["selected_channels"] = list(range(min(16, C)))
    with preset_col2:
        if st.button("頭頂部 (16-31)"):
            st.session_state["selected_channels"] = list(range(16, min(32, C)))
    preset_col3, preset_col4 = st.columns(2)
    with preset_col3:
        if st.button("側頭部 (32-47)"):
            st.session_state["selected_channels"] = list(range(32, min(48, C)))
    with preset_col4:
        if st.button("後頭部 (48-63)"):
            st.session_state["selected_channels"] = list(range(48, min(64, C)))

# Initialize the selection, or drop indices left over from a previously loaded
# recording with more channels.  Such indices would be invalid multiselect
# defaults and out-of-range columns in the plot.
if "selected_channels" not in st.session_state:
    st.session_state["selected_channels"] = list(range(min(C, 8)))
else:
    st.session_state["selected_channels"] = [
        ch for ch in st.session_state["selected_channels"] if 0 <= ch < C
    ]

# Apply the select-all / clear-all buttons.
if select_all:
    st.session_state["selected_channels"] = list(range(C))
if deselect_all:
    st.session_state["selected_channels"] = []

# Main selection UI: a multiselect for small montages, add/remove otherwise.
max_display = 20  # above this many channels the multiselect becomes unwieldy
if C <= max_display:
    selected_channels = st.sidebar.multiselect(
        f"表示するチャンネル(全{C}ch)",
        options=list(range(C)),
        default=st.session_state["selected_channels"],
        key="ch_select",
    )
else:
    # Many channels: show only the count and edit the selection one at a time.
    st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")

    add_ch = st.sidebar.number_input(
        "チャンネルを追加",
        min_value=0,
        max_value=C-1,
        value=0,
        step=1,
        key="add_ch_input"
    )

    col_add, col_remove = st.sidebar.columns(2)
    with col_add:
        if st.button("➕ 追加"):
            if add_ch not in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].append(int(add_ch))
                st.session_state["selected_channels"].sort()
    with col_remove:
        if st.button("➖ 削除"):
            if add_ch in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].remove(int(add_ch))

    # Show (up to 10 of) the current selection.
    if st.session_state["selected_channels"]:
        selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
        if len(st.session_state["selected_channels"]) > 10:
            selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
        st.sidebar.text(f"選択済み: {selected_str}")

    selected_channels = st.session_state["selected_channels"]

# Persist the multiselect result as the canonical selection.
if C <= max_display:
    st.session_state["selected_channels"] = selected_channels
# Time-series plot (left, 2/3 width) next to a data summary (right, 1/3 width).
col1, col2 = st.columns([2, 1])
with col1:
    fig_ts = make_timeseries_figure(
        X_tc=X_tc,
        selected_channels=selected_channels,
        fs=fs,
        start_sec=float(start_sec),
        win_sec=float(win_sec),
        decim=int(decim),
        offset_mode=bool(offset_mode),
        show_rangeslider=bool(show_rangeslider),
        signal_type=signal_view,
    )
    st.plotly_chart(fig_ts)
with col2:
    st.subheader("Data info")
    # Human-readable description of each signal view.
    view_descriptions = {
        "raw": "生信号(前処理なし)",
        "filtered": f"バンドパスフィルタ後 ({f_low_src}-{f_high_src} Hz)",
        "amplitude": "Hilbert振幅 (envelope)",
        "phase": "Hilbert位相 (-π ~ π)",
    }
    info_lines = (
        f"- view: **{signal_view}** ({view_descriptions.get(signal_view, '')})",
        f"- fs: **{fs:.2f} Hz**",
        f"- T: {T} samples",
        f"- C: {C} channels",
        f"- duration: {duration_sec:.2f} sec",
    )
    for info_line in info_lines:
        st.write(info_line)
    if signal_view == "phase":
        st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")
    st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")
st.divider()
# ============================================================
# Estimation
# ============================================================
st.subheader("Network estimation")
# Radio-button label for each estimation method; dict order is display order.
method_labels = {
    "envelope_corr": "Envelope Pearson correlation (振幅の相関)",
    "phase_PLV": "PLV(位相同期, PLV)",
    "phase_corr": "Circular correlation",
    "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
    "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
}
estimation_method = st.radio(
    "推定手法を選択",
    options=list(method_labels),
    format_func=method_labels.__getitem__,
    horizontal=True,
    help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_PLV: 位相のPhase Locking Value | phase_corr: 位相の相関係数 | pac_tort: Modulation index | pac_chatterjee: 位相から振幅へのChatterjee相関",
)
# Longer explanation shown below the selector; the PAC entries reuse their labels.
method_info = {
    "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
    "phase_PLV": "**PLV**: 位相間のPhase locking valueを計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "phase_corr": "**Circular correlation**: 位相間の相関係数を計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "pac_tort": method_labels["pac_tort"],
    "pac_chatterjee": method_labels["pac_chatterjee"],
}
st.info(method_info[estimation_method])
# Network estimation, cached in st.session_state["W"] across reruns.
# The cache is invalidated when the method changes OR when the input it was
# computed from changes (array shape, sampling rate, band-pass range). Keying
# on the method alone kept showing a network estimated from a previously
# loaded recording or a previous band.
# NOTE(review): the PAC target-band parameters are not visible in this chunk;
# only the target array shape is part of the key - consider adding them.
pac_methods = ("pac_tort", "pac_chatterjee")
# PAC pairs the low-band phase (prep) with the high-band amplitude held in
# st.session_state["prep_tgt"]; fail with a readable message if it is missing.
tgt_prep = st.session_state.get("prep_tgt") if estimation_method in pac_methods else None
if estimation_method in pac_methods and tgt_prep is None:
    st.error("PAC推定に必要なターゲット帯域の前処理結果 (prep_tgt) がありません")
    st.stop()
estimation_key = (
    estimation_method,
    np.shape(prep["phase"]),
    None if tgt_prep is None else np.shape(tgt_prep["amplitude"]),
    float(fs),
    f_low_src,
    f_high_src,
)
W = st.session_state.get("W")
need_estimation = (W is None) or (st.session_state.get("last_estimation_key") != estimation_key)
if need_estimation:
    progress = st.progress(0.0)
    with st.spinner(f"推定中... ({estimation_method})"):
        if estimation_method == "envelope_corr":
            W = estimate_network_envelope_corr(prep["amplitude"])
        elif estimation_method == "phase_PLV":
            W = estimate_network_phase_PLV(prep["phase"], progress)
        elif estimation_method == "phase_corr":
            W = estimate_network_phase_corr(prep["phase"])
        elif estimation_method == "pac_tort":
            W = estimate_network_pac_tort(prep["phase"], tgt_prep["amplitude"], progress)
        elif estimation_method == "pac_chatterjee":
            W = estimate_network_pac_chatterjee(prep["phase"], tgt_prep["amplitude"], progress)
        else:
            st.error("未知の推定手法です")
            st.stop()
    # Remove the progress bar; methods without progress reporting would
    # otherwise leave it stuck at 0%.
    progress.empty()
    st.session_state["W"] = W
    st.session_state["last_estimation_method"] = estimation_method
    st.session_state["last_estimation_key"] = estimation_key
    st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
else:
    st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")
# W exists at this point (freshly estimated or restored from session state).
# Slider upper bound = largest *finite* weight. Previously a single NaN/inf
# entry (e.g. from a flat channel) made np.max non-finite and discarded every
# other weight. A negative maximum is clamped to 0 so the default value
# (wmax / 2) never falls below min_value=0.0, which st.slider rejects.
w_finite = np.asarray(W, dtype=float)
w_finite = w_finite[np.isfinite(w_finite)]
wmax = float(w_finite.max()) if w_finite.size else 1.0
wmax = max(wmax, 0.0)
col_thr1, col_thr2 = st.columns([3, 1])
with col_thr1:
    thr = st.slider(
        "閾値 (threshold) ※下げるほどエッジが増えます",
        min_value=0.0,
        max_value=max(0.0001, wmax),
        value=wmax / 2,
        step=max(wmax / 200, 0.001),
    )
with col_thr2:
    use_louvain = st.checkbox(
        "Louvainクラスタ",
        value=True,
        disabled=not LOUVAIN_AVAILABLE,
        help="ノードの色をコミュニティ検出結果で塗り分けます",
    )
# 2D electrode positions for the network layout (None -> circular layout).
electrode_pos = prep.get("electrode_pos", None)
if electrode_pos is not None:
    electrode_pos = np.asarray(electrode_pos, dtype=np.float32)
    if (
        electrode_pos.ndim == 2
        and electrode_pos.shape[1] >= 2
        and electrode_pos.shape[0] == W.shape[0]
    ):
        # Rotate the (x, y) plane 90 degrees counter-clockwise so the front of
        # the head points up: (x, y) -> (-y, x).
        pos2 = electrode_pos[:, :2]
        electrode_pos = np.column_stack([-pos2[:, 1], pos2[:, 0]])
    else:
        # Unusable layout (wrong rank, fewer than 2 coordinates, or a channel
        # count that differs from the network size): fall back to the circular
        # layout, mirroring the shape check applied to the 3D coordinates.
        electrode_pos = None
electrode_pos_3d = prep.get("electrode_pos_3d", None)
if electrode_pos_3d is not None:
    # Convert before reading .shape so list-valued coordinates also work.
    electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)
if electrode_pos is not None:
    st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
else:
    st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")
# Report whether 3D coordinates are available.
if electrode_pos_3d is not None:
    st.success(f"✓ 3D電極座標を取得しました ({electrode_pos_3d.shape[0]} channels) - 下部に3Dビューアを表示します")
else:
    st.info("ℹ️ 3D電極座標が取得できませんでした - 2D表示のみです")
# Network plots (left) and edge statistics (right).
net_col1, net_col2 = st.columns([2, 1])
with net_col1:
    fig_net, edge_n = make_network_figure(
        W,
        float(thr),
        use_louvain=use_louvain,
        electrode_pos=electrode_pos,
    )
    st.plotly_chart(fig_net)
    # Optional 3D viewer: requires one (x, y, z) row per network node.
    if electrode_pos_3d is not None:
        electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)
        if electrode_pos_3d.shape != (W.shape[0], 3):
            st.warning(f"3D座標のshapeが不正です: {electrode_pos_3d.shape}(期待: (C,3), C={W.shape[0]})")
        else:
            st.subheader("3D Viewer")
            st.plotly_chart(
                make_network_figure_3d(
                    W=W,
                    thr=float(thr),
                    electrode_pos_3d=electrode_pos_3d,
                    use_louvain=use_louvain,
                ),
                width="stretch",
                config={"displayModeBar": True, "scrollZoom": True},
            )
with net_col2:
    st.metric("Edges", edge_n)
    st.plotly_chart(make_edgecount_curve(W))
# Hypothesis-testing section (placeholder). Uses the same separator and heading
# level as the other sections; st.write("# ...") rendered an oversized H1.
st.divider()
st.subheader("Hypothesis testing")
st.write("Coming soon ...")