# EEG-Network-Viewer / src/streamlit_app.py
# Commit e648c90 (stardust-coder): "[add] directed graph"
import io
from dataclasses import dataclass
from typing import List, Tuple
import numpy as np
import streamlit as st
import plotly.graph_objects as go
import mne
from scipy.signal import hilbert
# Older Streamlit releases require st.set_page_config() to be the very first
# Streamlit command of a script run, so it is issued before the
# optional-dependency warning below (st.warning is itself a Streamlit command).
st.set_page_config(page_title="EEG Viewer + Network Estimation", layout="wide")

# Louvain community detection is optional; the UI disables the checkbox when
# these packages are missing.
try:
    import community as community_louvain
    import networkx as nx
    LOUVAIN_AVAILABLE = True
except ImportError:
    LOUVAIN_AVAILABLE = False
    st.warning("⚠️ Louvainクラスタリングを使用するには `pip install python-louvain networkx` を実行してください。")

from loader import (
    pick_set_fdt,
    load_eeglab_tc_from_bytes,
    load_mat_candidates,
)
import metrics
# ============================================================
# Preprocess config
# ============================================================
@dataclass(frozen=True)
class PreprocessConfig:
    """Immutable (hashable) bundle of preprocessing parameters.

    Attributes are plain floats; the band edges are handed to MNE's filter as
    ``l_freq`` / ``h_freq`` by ``bandpass_tc``.
    """

    fs: float      # sampling rate in Hz
    f_low: float   # lower band-pass edge in Hz
    f_high: float  # upper band-pass edge in Hz
# ============================================================
# Helpers
# ============================================================
def ensure_tc(x: np.ndarray) -> np.ndarray:
    """Return *x* as a 2-D (time, channel) array.

    - 1-D input becomes a single-channel column ``(T, 1)``.
    - 2-D input is transposed when it looks channel-major, i.e. it has at most
      256 rows and more columns than rows; otherwise it is returned unchanged.

    Raises:
        ValueError: if *x* is not 1-D or 2-D.
    """
    arr = np.asarray(x)
    if arr.ndim == 1:
        return arr[:, None]
    if arr.ndim != 2:
        raise ValueError(f"2次元配列のみ対応です: shape={arr.shape}")
    n_rows, n_cols = arr.shape
    # Heuristic: few rows but many columns => most likely (C, T), so flip it.
    looks_channel_major = n_rows <= 256 and n_cols > n_rows
    return arr.T if looks_channel_major else arr
def _quad_bezier_points(p0, p1, c, n=20):
"""2次Bezierを点列にして返す (n点)"""
ts = np.linspace(0, 1, n)
pts = (1-ts)[:,None]**2 * p0 + 2*(1-ts)[:,None]*ts[:,None]*c + ts[:,None]**2 * p1
return pts # shape (n,2)
def _quad_bezier_point_and_tangent(p0, p1, c, t):
"""2次Bezierの点と接線ベクトル(微分)を返す"""
# B(t) = (1-t)^2 p0 + 2(1-t)t c + t^2 p1
pt = (1-t)**2 * p0 + 2*(1-t)*t * c + t**2 * p1
# B'(t) = 2(1-t)(c-p0) + 2t(p1-c)
tan = 2*(1-t)*(c-p0) + 2*t*(p1-c)
return pt, tan
# ============================================================
# Signal processing
# ============================================================
def bandpass_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> np.ndarray:
    """Band-pass filter every channel with MNE.

    Args:
        x_tc: signal laid out as (T, C).
        cfg: sampling rate plus pass-band; ``cfg.f_low`` / ``cfg.f_high`` go to
            ``Raw.filter`` as ``l_freq`` / ``h_freq``.

    Returns:
        Filtered signal, (T, C), float32.
    """
    n_channels = x_tc.shape[1]
    info = mne.create_info(
        ch_names=[f"ch{k}" for k in range(n_channels)],
        sfreq=float(cfg.fs),
        ch_types="eeg",
    )
    # MNE keeps data channel-major, i.e. (C, T).
    raw = mne.io.RawArray(x_tc.T, info, verbose=False)
    # Raw.filter modifies its instance in place, so it runs on a copy.
    filtered = raw.copy().filter(l_freq=cfg.f_low, h_freq=cfg.f_high, verbose=False)
    data_ct = filtered.get_data()
    return data_ct.T.astype(np.float32)
def hilbert_envelope_tc(x_tc: np.ndarray) -> np.ndarray:
    """Instantaneous amplitude, |analytic signal|, of each channel.

    ``scipy.signal.hilbert`` runs along axis 0, so input and output are both
    (T, C); the result is float32.
    """
    return np.abs(hilbert(x_tc, axis=0)).astype(np.float32)
def hilbert_phase_tc(x_tc: np.ndarray) -> np.ndarray:
    """Instantaneous phase (radians, in [-pi, pi]) of each channel.

    Computed as the angle of ``scipy.signal.hilbert`` along axis 0; input and
    output are (T, C), output is float32.
    """
    analytic_tc = hilbert(x_tc, axis=0)
    phase_tc = np.angle(analytic_tc)
    return phase_tc.astype(np.float32)
def preprocess_tc(x_tc: np.ndarray, cfg: PreprocessConfig) -> dict:
    """Run the preprocessing chain on a raw recording.

    Steps: ``ensure_tc`` (coerce to (T, C)) -> float32 -> ``bandpass_tc`` ->
    Hilbert transform of the filtered signal.

    Returns:
        dict with keys
        - "fs": sampling rate (float, Hz),
        - "raw": (T, C) float32 input,
        - "filtered": (T, C) float32 band-passed signal,
        - "envelope" / "amplitude": (T, C) float32 Hilbert amplitude (same array),
        - "phase": (T, C) float32 Hilbert phase in radians.
    """
    x_tc = ensure_tc(x_tc).astype(np.float32)
    x_filt = bandpass_tc(x_tc, cfg)
    # One analytic signal feeds both envelope and phase: hilbert() is an FFT over
    # the whole recording per channel, so computing it twice doubles the cost.
    analytic = hilbert(x_filt, axis=0)
    env = np.abs(analytic).astype(np.float32)
    phase = np.angle(analytic).astype(np.float32)
    return {
        "fs": float(cfg.fs),
        "raw": x_tc,
        "filtered": x_filt,
        "envelope": env,
        "amplitude": env,  # alias of "envelope" (the same array object)
        "phase": phase,
    }
@st.cache_data(show_spinner=False)
def preprocess_all_eeglab(
    set_bytes: bytes,
    fdt_bytes: bytes,
    set_name: str,
    fdt_name: str,
    f_low: float,
    f_high: float,
) -> dict:
    """Load an EEGLAB .set/.fdt pair from bytes and preprocess it (Streamlit-cached).

    The sampling rate is taken from the loaded file; only the pass-band is
    supplied by the caller. Returns ``preprocess_tc``'s dict, extended with
    "electrode_pos" (2-D) and "electrode_pos_3d" when the loader provides them.
    """
    x_tc, fs, pos_2d, pos_3d = load_eeglab_tc_from_bytes(
        set_bytes=set_bytes,
        set_name=set_name,
        fdt_bytes=fdt_bytes,
        fdt_name=fdt_name,
    )
    cfg = PreprocessConfig(fs=float(fs), f_low=float(f_low), f_high=float(f_high))
    result = preprocess_tc(x_tc, cfg)
    # Attach electrode layouts only when the loader actually returned them.
    optional_layouts = (("electrode_pos", pos_2d), ("electrode_pos_3d", pos_3d))
    for field, layout in optional_layouts:
        if layout is not None:
            result[field] = layout
    return result
@st.cache_data(show_spinner=False)
def load_mat_candidates_cached(mat_bytes: bytes) -> dict:
    """Streamlit-cached wrapper around ``load_mat_candidates``.

    Streamlit reruns the script on every widget interaction; caching on the raw
    bytes keeps the .mat file from being re-parsed each time.
    """
    candidates = load_mat_candidates(mat_bytes)
    return candidates
# ============================================================
# Viewer
# ============================================================
def window_slice(X_tc: np.ndarray, start_idx: int, end_idx: int, decim: int) -> np.ndarray:
    """Return rows ``[start_idx, end_idx)`` of *X_tc*, keeping every ``decim``-th one.

    ``start_idx`` is clamped into ``[0, T-1]``; ``end_idx`` is capped at ``T`` and
    forced to be at least one past the clamped start, so a non-empty input always
    yields at least one row. ``decim`` below 1 is treated as 1.
    """
    n_samples = X_tc.shape[0]
    lo = min(start_idx, n_samples - 1)
    if lo < 0:
        lo = 0
    hi = min(end_idx, n_samples)
    if hi < lo + 1:
        hi = lo + 1
    stride = int(decim)
    if stride < 1:
        stride = 1
    return X_tc[lo:hi:stride, :]
def make_timeseries_figure(
    X_tc: np.ndarray,
    selected_channels: List[int],
    fs: float,
    start_sec: float,
    win_sec: float,
    decim: int,
    offset_mode: bool,
    show_rangeslider: bool,
    signal_type: str = "filtered",
) -> go.Figure:
    """Plot the selected channels of *X_tc* over one time window.

    Args:
        X_tc: signal, (T, C).
        selected_channels: column indices to draw, one line trace each.
        fs: sampling rate in Hz (samples -> seconds).
        start_sec, win_sec: window start and length in seconds.
        decim: keep every ``decim``-th sample (see ``window_slice``).
        offset_mode: stack traces vertically, 5x the median per-channel standard
            deviation of the window apart (only with 2+ channels, never for phase).
        show_rangeslider: add a Plotly range slider below the x axis.
        signal_type: used in the title; ``"phase"`` also pins the y range to
            about [-pi, pi].
    """
    first_idx = int(round(start_sec * fs))
    last_idx = int(round((start_sec + win_sec) * fs))
    window = window_slice(X_tc, first_idx, last_idx, decim)
    times = (np.arange(window.shape[0]) * decim + first_idx) / fs

    fig = go.Figure()

    if not selected_channels:
        fig.update_layout(
            title="Timeseries (no channel selected)",
            height=450,
            xaxis_title="time (s)",
            yaxis_title="amplitude",
        )
        return fig

    is_phase = signal_type == "phase"
    stack_traces = offset_mode and len(selected_channels) > 1 and not is_phase

    if stack_traces:
        med_std = np.median(np.std(window[:, selected_channels], axis=0))
        # Fall back to unit spacing when the typical spread is zero or non-finite.
        scale = float(med_std) if np.isfinite(med_std) and med_std > 0 else 1.0
        spacing = 5.0 * scale
        for rank, ch in enumerate(selected_channels):
            fig.add_trace(go.Scatter(
                x=times,
                y=window[:, ch] + rank * spacing,
                mode="lines",
                name=f"ch{ch}",
                line=dict(width=1),
            ))
        y_label = "amplitude (offset)"
    else:
        for ch in selected_channels:
            fig.add_trace(go.Scatter(
                x=times,
                y=window[:, ch],
                mode="lines",
                name=f"ch{ch}",
                line=dict(width=1),
            ))
        y_label = "phase (rad)" if is_phase else "amplitude"

    # Extra height and bottom margin leave room for the range slider.
    fig.update_layout(
        title=f"Timeseries: {signal_type} (window={win_sec:.2f}s, start={start_sec:.2f}s, decim={decim})",
        height=550 if show_rangeslider else 450,
        xaxis_title="time (s)",
        yaxis_title=y_label,
        legend=dict(orientation="h"),
        margin=dict(l=60, r=20, t=80, b=150 if show_rangeslider else 80),
    )

    if is_phase:
        fig.update_yaxes(range=[-np.pi - 0.5, np.pi + 0.5])

    fig.update_xaxes(
        rangeslider=dict(visible=True, thickness=0.05) if show_rangeslider else dict(visible=False)
    )
    return fig
# ============================================================
# Network (multiple methods) + export
# ============================================================
def estimate_network_envelope_corr(X_tc: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation between the columns of an envelope signal.

    Input: X_tc (T, C) - envelope (Hilbert amplitude) data.
    Output: W (C, C) float32 - |corr|, zero diagonal; NaN/inf (e.g. from a
    constant channel) become 0. A single channel yields a (1, 1) zero matrix.
    """
    X = X_tc - X_tc.mean(axis=0, keepdims=True)
    # np.corrcoef collapses a single variable to a 0-d value, which would make
    # fill_diagonal raise; keep the result (C, C) for every C.
    corr = np.atleast_2d(np.corrcoef(X, rowvar=False))
    W = np.abs(corr)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_phase_corr(X_tc: np.ndarray) -> np.ndarray:
    """Pairwise circular correlation between channel phases.

    Input: X_tc (T, C) - phase data in radians.
    Output: W (C, C) float32 - symmetric with zero diagonal; W[i, j] is
    ``metrics.circular_correlation`` of channels i and j (the original comment
    names the Jammalamadaka–Sengupta coefficient). NaN/inf become 0.

    NOTE(review): the Jammalamadaka–Sengupta coefficient is signed, while the
    UI describes a 0–1 range; confirm what ``metrics.circular_correlation``
    actually returns.
    """
    _, n_ch = X_tc.shape
    W = np.zeros((n_ch, n_ch), dtype=np.float32)
    for i in range(n_ch - 1):
        phase_i = X_tc[:, i]
        for j in range(i + 1, n_ch):
            # Symmetric measure: compute each unordered pair once and mirror it.
            W[i, j] = W[j, i] = metrics.circular_correlation(phase_i, X_tc[:, j])
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_phase_PLV(X_tc: np.ndarray, progress, chunk_size: int = 65536) -> np.ndarray:
    """Phase-locking value (PLV) between every pair of channels.

    PLV_ij = |<exp(i * (theta_i - theta_j))>_t|, ranging from 0 (no locking)
    to 1 (constant phase difference).

    Uses one complex matrix product per time chunk instead of a Python loop
    over channel pairs: with z = exp(i * theta), sum_t z_i * conj(z_j) is entry
    (i, j) of ``z.T @ conj(z)``. Chunking bounds the temporary complex array to
    ``chunk_size x C`` values.

    Args:
        X_tc: (T, C) phase data in radians.
        progress: object with a ``progress(float)`` method (e.g. ``st.progress``);
            called once per chunk with the fraction done, ending at 1.0.
        chunk_size: samples per chunk (values < 1 are treated as 1).

    Returns:
        (C, C) float32, symmetric, zero diagonal; NaN/inf become 0.
    """
    X_tc = np.asarray(X_tc)
    T, C = X_tc.shape
    chunk_size = max(1, int(chunk_size))
    if T == 0 or C == 0:
        # No samples to average over: report no coupling.
        progress.progress(1.0)
        return np.zeros((C, C), dtype=np.float32)

    acc = np.zeros((C, C), dtype=np.complex128)
    n_chunks = -(-T // chunk_size)  # ceiling division
    for k, start in enumerate(range(0, T, chunk_size), start=1):
        z = np.exp(1j * X_tc[start:start + chunk_size].astype(np.float64))
        acc += z.T @ z.conj()
        progress.progress(k / n_chunks)

    W = (np.abs(acc) / T).astype(np.float32)
    np.fill_diagonal(W, 0.0)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_pac_tort(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling via Tort et al. (2010) Modulation Index.

    W[i, j] = metrics.modulation_index(phase of channel i, amplitude of channel j),
    for every ordered pair i != j; the diagonal stays 0.

    Args:
        X_tc1: (T, C) phase data in radians (low band).
        X_tc2: (T, C) envelope data (high band); must match X_tc1's shape.
        progress: object with a ``progress(float)`` method (e.g. ``st.progress``);
            updated once per source channel, ending at 1.0.

    Returns:
        (C, C) float32 Modulation Index matrix (generally asymmetric); NaN/inf -> 0.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if X_tc1.shape != X_tc2.shape:
        raise ValueError(
            f"phase and amplitude arrays must have the same shape: {X_tc1.shape} vs {X_tc2.shape}"
        )
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)
    for i in range(C):
        phase_i = X_tc1[:, i]
        for j in range(C):
            if i == j:
                continue
            W[i, j] = metrics.modulation_index(phase_i, X_tc2[:, j])
        # One UI update per source channel (previously one per pair, and the
        # C*C denominator meant the bar never reached 100%).
        progress.progress((i + 1) / C)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_pac_chatterjee(X_tc1, X_tc2, progress):
    """Directed phase-amplitude coupling via Chatterjee's correlation.

    W[i, j] = metrics.chatterjee_phase_to_amp(phase of channel i, amplitude of
    channel j), for every ordered pair i != j; the diagonal stays 0.

    Args:
        X_tc1: (T, C) phase data in radians (low band).
        X_tc2: (T, C) envelope data (high band); must match X_tc1's shape.
        progress: object with a ``progress(float)`` method (e.g. ``st.progress``);
            updated once per source channel, ending at 1.0.

    Returns:
        (C, C) float32 matrix of phase->amplitude Chatterjee correlations
        (generally asymmetric); NaN/inf -> 0.

    Raises:
        ValueError: if the two inputs have different shapes.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if X_tc1.shape != X_tc2.shape:
        raise ValueError(
            f"phase and amplitude arrays must have the same shape: {X_tc1.shape} vs {X_tc2.shape}"
        )
    _, C = X_tc1.shape
    W = np.zeros((C, C), dtype=np.float32)
    for i in range(C):
        phase_i = X_tc1[:, i]
        for j in range(C):
            if i == j:
                continue
            W[i, j] = metrics.chatterjee_phase_to_amp(phase_i, X_tc2[:, j])
        # One UI update per source channel (previously one per pair, and the
        # C*C denominator meant the bar never reached 100%).
        progress.progress((i + 1) / C)
    return np.nan_to_num(W, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
def estimate_network_dummy(X_tc: np.ndarray) -> np.ndarray:
    """Legacy placeholder estimator, kept for backward compatibility.

    It always computed exactly what ``estimate_network_envelope_corr`` computes
    (|Pearson correlation| between columns, zero diagonal, NaN/inf -> 0,
    float32), so it delegates to that function instead of duplicating it.
    """
    return estimate_network_envelope_corr(X_tc)
def threshold_edges(
    W: np.ndarray,
    thr: float,
) -> List[Tuple[int, int, float]]:
    """Extract the edges of W whose weight is >= thr.

    - Symmetric W (within 1e-12): undirected graph, only pairs with i < j.
    - Asymmetric W: directed graph, every ordered pair i -> j with i != j.

    Weights are compared in float64, so a float32 W behaves exactly like
    ``float(W[i, j]) >= thr``.

    Returns:
        List of (i, j, w) with Python int/float values, sorted by weight
        descending; ties keep row-major (i, j) order.
    """
    W = np.asarray(W)
    C = W.shape[0]
    if np.allclose(W, W.T, atol=1e-12, rtol=0):
        candidates = np.triu(np.ones((C, C), dtype=bool), k=1)  # undirected
    else:
        candidates = ~np.eye(C, dtype=bool)                      # directed
    # Compare in float64: comparing a float32 array with a Python float would
    # round thr to float32 and could admit edges just below the threshold.
    W64 = W.astype(np.float64, copy=False)
    rows, cols = np.nonzero(candidates & (W64 >= float(thr)))
    edges: List[Tuple[int, int, float]] = [
        (int(i), int(j), float(W64[i, j])) for i, j in zip(rows, cols)
    ]
    # Stable sort: heaviest first, ties stay in row-major order.
    edges.sort(key=lambda e: e[2], reverse=True)
    return edges
def adjacency_at_threshold(W: np.ndarray, thr: float, weighted: bool) -> np.ndarray:
    """Adjacency matrix of W at threshold ``thr``, with the diagonal cleared.

    weighted=True: a copy of W (same dtype) with entries below thr set to 0.
    weighted=False: an int matrix with 1 where W >= thr and 0 elsewhere.
    W itself is never modified.
    """
    if not weighted:
        binary = (W >= thr).astype(int)
        np.fill_diagonal(binary, 0)
        return binary
    kept = W.copy()
    kept[kept < thr] = 0.0
    np.fill_diagonal(kept, 0.0)
    return kept
def compute_louvain_clusters(W: np.ndarray, thr: float, random_state=0) -> np.ndarray:
    """Cluster channels with the Louvain method on the thresholded network.

    Builds an undirected networkx graph with one node per channel and an edge
    {i, j} whenever W[i, j] >= thr or W[j, i] >= thr, weighted by
    max(W[i, j], W[j, i]), then runs ``community_louvain.best_partition``.

    Args:
        W: (C, C) weight matrix; may be asymmetric (directed estimators).
        thr: pairs below thr in both directions get no edge.
        random_state: seed forwarded to ``best_partition``. The app redraws on
            every widget interaction, so an unseeded run can recolour the nodes
            between otherwise identical renders; pass None for unseeded runs.

    Returns:
        (C,) int array of cluster ids. All zeros when python-louvain/networkx
        are unavailable; one cluster per node when the thresholded graph has
        no positive total edge weight.
    """
    C = W.shape[0]
    if not LOUVAIN_AVAILABLE:
        # Louvain unavailable: every node in the same cluster.
        return np.zeros(C, dtype=int)

    G = nx.Graph()
    G.add_nodes_from(range(C))
    # Each unordered pair once; self-loops are skipped (all estimators leave a
    # zero diagonal, which with thr <= 0 used to add weight-0 self-loops).
    for i in range(C):
        for j in range(i + 1, C):
            if W[i, j] >= thr or W[j, i] >= thr:
                G.add_edge(i, j, weight=float(max(W[i, j], W[j, i])))

    # Modularity divides by the total edge weight, so an edgeless or
    # all-zero-weight graph (e.g. thr=0 on an all-zero W) gets singleton clusters.
    if G.size(weight="weight") <= 0:
        return np.arange(C, dtype=int)

    partition = community_louvain.best_partition(G, weight="weight", random_state=random_state)
    return np.array([partition[i] for i in range(C)], dtype=int)
def get_cluster_colors(clusters: np.ndarray) -> List[str]:
    """Map cluster ids to evenly spaced HSV hues.

    Distinct ids are rank-encoded (sorted order) before choosing a hue, so the
    palette stays evenly spread even when ids are not 0..K-1. For contiguous
    ids 0..K-1 (as Louvain returns) the rank equals the id.

    Args:
        clusters: (C,) cluster ids.

    Returns:
        One ``'rgb(r, g, b)'`` string per entry of *clusters*.
    """
    import colorsys

    labels = np.asarray(clusters)
    unique_ids, ranks = np.unique(labels, return_inverse=True)
    n_clusters = max(len(unique_ids), 1)
    colors = []
    for rank in np.ravel(ranks):
        r, g, b = colorsys.hsv_to_rgb(rank / n_clusters, 0.8, 0.95)
        colors.append(f'rgb({int(255*r)}, {int(255*g)}, {int(255*b)})')
    return colors
def get_electrode_positions(prep: dict) -> np.ndarray:
    """Return 2-D node positions for a preprocessed recording.

    Uses ``prep["electrode_pos"]`` when that key is present (returned as-is).
    Otherwise places the ``prep["raw"].shape[1]`` channels evenly on the unit
    circle, starting at (1, 0) and going counter-clockwise.

    Returns:
        (C, 2) array of (x, y) coordinates.
    """
    try:
        return prep["electrode_pos"]
    except KeyError:
        pass
    n_channels = prep["raw"].shape[1]
    theta = np.linspace(0, 2 * np.pi, n_channels, endpoint=False)
    return np.column_stack([np.cos(theta), np.sin(theta)])
def make_network_figure_3d(
    W: np.ndarray,
    thr: float,
    electrode_pos_3d: np.ndarray,
    use_louvain: bool = True,
) -> go.Figure:
    """Render the thresholded network on 3-D electrode coordinates (drag to rotate).

    Args:
        W: (C, C) connectivity matrix; edges come from ``threshold_edges(W, thr)``.
        thr: minimum edge weight to draw.
        electrode_pos_3d: (C, 3) node coordinates (x, y, z).
        use_louvain: colour nodes by Louvain community when the package is available.

    Edges are straight segments without direction markers (even for an
    asymmetric W), coloured blue -> red and 1 -> 5 px wide by weight relative
    to the kept edges.
    """
    import colorsys

    n_nodes = W.shape[0]
    px = electrode_pos_3d[:, 0]
    py = electrode_pos_3d[:, 1]
    pz = electrode_pos_3d[:, 2]
    kept = threshold_edges(W, thr)
    fig = go.Figure()

    # Weight range of the kept edges; colour and width are normalised within it.
    weights = [weight for _, _, weight in kept]
    w_lo = min(weights) if weights else 0
    w_hi = max(weights) if weights else 1
    w_span = (w_hi - w_lo) if weights and w_hi > w_lo else 1.0

    def rainbow(frac):
        """Normalised weight in [0, 1] -> colour from blue (0) to red (1)."""
        red, green, blue = colorsys.hsv_to_rgb((1.0 - frac) * 0.67, 0.9, 0.95)
        return f'rgb({int(255*red)}, {int(255*green)}, {int(255*blue)})'

    for src, dst, weight in kept:
        frac = (weight - w_lo) / w_span if w_span > 0 else 0.5
        fig.add_trace(go.Scatter3d(
            x=[px[src], px[dst], None],
            y=[py[src], py[dst], None],
            z=[pz[src], pz[dst], None],
            mode='lines',
            line=dict(color=rainbow(frac), width=1 + 4 * frac),
            hoverinfo='skip',
            showlegend=False,
        ))

    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        title_suffix = f" | Louvain clusters: {len(np.unique(clusters))}"
    else:
        node_colors = ['#FFD700'] * n_nodes
        clusters = np.zeros(n_nodes, dtype=int)
        title_suffix = ""

    # Nodes go last so they are drawn on top of the edges.
    fig.add_trace(go.Scatter3d(
        x=px,
        y=py,
        z=pz,
        mode='markers+text',
        text=[str(k) for k in range(n_nodes)],
        textposition='top center',
        textfont=dict(size=8),
        marker=dict(
            size=8,
            color=node_colors,
            line=dict(color='white', width=1),
        ),
        hoverinfo='text',
        hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(n_nodes)],
        showlegend=False,
    ))

    fig.update_layout(
        title=f"3D Network (thr={thr:.3f}) edges={len(kept)}{title_suffix}",
        height=700,
        scene=dict(
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            bgcolor='rgba(0,0,0,0.9)',
        ),
        paper_bgcolor='rgba(0,0,0,0.9)',
        margin=dict(l=0, r=0, t=50, b=0),
    )
    return fig
def make_network_figure(
    W: np.ndarray,
    thr: float,
    use_louvain: bool = True,
    electrode_pos: np.ndarray = None,
) -> tuple[go.Figure, int]:
    """Draw the thresholded network on a 2-D layout.

    Edges come from ``threshold_edges(W, thr)``. A symmetric W is drawn with
    straight undirected lines; an asymmetric W with curved (quadratic Bézier)
    edges and a triangle marker near the target end showing direction. Edge
    colour (blue -> red) and width (0.5 -> 4 px) scale with weight relative to
    the kept edges.

    Args:
        W: (C, C) connectivity matrix.
        thr: edges with weight >= thr are kept.
        use_louvain: colour nodes by Louvain community when the package is available.
        electrode_pos: optional (C, >=2) node coordinates; when missing or when
            its row count differs from C, nodes are placed on the unit circle.

    Returns:
        (figure, number of edges that passed the threshold).
    """
    import colorsys

    C = W.shape[0]

    # Node layout: electrode coordinates when they match the channel count,
    # otherwise evenly spaced points on the unit circle.
    if electrode_pos is None or electrode_pos.shape[0] != C:
        angles = np.linspace(0, 2 * np.pi, C, endpoint=False)
        xs = np.cos(angles)
        ys = np.sin(angles)
    else:
        xs = electrode_pos[:, 0]
        ys = electrode_pos[:, 1]

    edges = threshold_edges(W, thr)
    fig = go.Figure()

    # Weight range of the kept edges; colour and width are normalised within it.
    if edges:
        weights = [w for _, _, w in edges]
        min_w = min(weights)
        max_w = max(weights)
        weight_range = max_w - min_w if max_w > min_w else 1.0
    else:
        min_w, max_w, weight_range = 0, 1, 1.0

    def get_rainbow_color(norm_val):
        """Normalised weight in [0, 1] -> semi-transparent colour, blue (0) to red (1)."""
        hue = (1.0 - norm_val) * 0.67  # 0.67 ≈ 240/360, i.e. blue in HSV
        r, g, b = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
        return f'rgba({int(255*r)}, {int(255*g)}, {int(255*b)}, 0.7)'

    is_symmetric = np.allclose(W, W.T, atol=1e-12, rtol=0)
    if not is_symmetric:
        # ---- Directed graph: curved edges with an arrowhead near the target ----
        curve_strength = 0.1   # control-point offset as a fraction of edge length
        node_radius = 0.08     # max distance (data units) to pull endpoints off node centres
        bezier_n = 18          # points per curve (more = smoother)
        t_arrow = 0.90         # Bézier parameter at which the arrowhead sits
        ANGLE_OFFSET = -90.0   # "triangle-up" points along +y; aligns it with the tangent
        for (i, j, w) in edges:
            norm_w = (w - min_w) / weight_range
            color = get_rainbow_color(norm_w)
            line_width = 0.5 + 3.5 * norm_w

            p0 = np.array([xs[i], ys[i]], dtype=float)
            p1 = np.array([xs[j], ys[j]], dtype=float)
            v = p1 - p0
            dist = np.hypot(v[0], v[1])
            if dist < 1e-9:
                continue  # coincident nodes: no direction to draw
            u = v / dist

            # Trim both ends so the edge does not run under the node markers, but
            # by at most a quarter of the edge length. A fixed trim larger than
            # dist / 2 (nodes closer than 2 * node_radius, e.g. neighbours on a
            # dense circle layout) would push the endpoints past each other,
            # reversing the curve so the arrow pointed back at the source.
            trim = min(node_radius, 0.25 * dist)
            p0s = p0 + u * trim
            p1s = p1 - u * trim

            # Bend towards the left-hand normal of the travel direction, so i->j
            # and j->i bow to opposite sides and stay distinguishable.
            normal = np.array([-u[1], u[0]])
            ctrl = 0.5 * (p0s + p1s) + curve_strength * dist * normal

            pts = _quad_bezier_points(p0s, p1s, ctrl, n=bezier_n)
            fig.add_trace(go.Scatter(
                x=pts[:, 0],
                y=pts[:, 1],
                mode="lines",
                hoverinfo="text",
                hovertext=f"ch{i} → ch{j}<br>weight: {w:.4f}",
                line=dict(width=line_width, color=color),
                showlegend=False,
            ))

            # Arrowhead oriented along the curve's tangent.
            pt, tan = _quad_bezier_point_and_tangent(p0s, p1s, ctrl, t_arrow)
            tx, ty = float(tan[0]), float(tan[1])
            if tx * tx + ty * ty < 1e-18:  # degenerate tangent: fall back to the chord
                tx, ty = float(p1s[0] - p0s[0]), float(p1s[1] - p0s[1])
            theta = np.degrees(np.arctan2(ty, tx))  # tangent angle, CCW from +x
            ang = (theta + ANGLE_OFFSET) % 360
            fig.add_trace(go.Scatter(
                x=[pt[0]],
                y=[pt[1]],
                mode="markers",
                hoverinfo="skip",
                marker=dict(
                    symbol="triangle-up",
                    size=10,
                    # Negated to turn the CCW math angle into Plotly's marker
                    # rotation (NOTE(review): assumes clockwise-positive angles).
                    angle=-ang,
                    angleref="up",
                    color=color,
                    line=dict(width=0),
                ),
                showlegend=False,
            ))
    else:
        # ---- Undirected graph: straight segments ----
        for (i, j, w) in edges:
            norm_w = (w - min_w) / weight_range
            color = get_rainbow_color(norm_w)
            line_width = 0.5 + 3.5 * norm_w  # 0.5 .. 4 px
            fig.add_trace(
                go.Scatter(
                    x=[xs[i], xs[j]],
                    y=[ys[i], ys[j]],
                    mode="lines",
                    hoverinfo="text",
                    hovertext=f"ch{i} - ch{j}<br>weight: {w:.4f}",
                    line=dict(width=line_width, color=color),
                    showlegend=False,
                )
            )

    # Node colours: Louvain communities, or a uniform gold.
    if use_louvain and LOUVAIN_AVAILABLE:
        clusters = compute_louvain_clusters(W, thr)
        node_colors = get_cluster_colors(clusters)
        n_clusters = len(np.unique(clusters))
        title_suffix = f" | Louvain clusters: {n_clusters}"
    else:
        node_colors = ['#FFD700'] * C
        clusters = np.zeros(C, dtype=int)
        title_suffix = ""

    # Nodes go last so they are drawn on top of the edges.
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=ys,
            mode="markers+text",
            text=[f"{k}" for k in range(C)],
            textposition="bottom center",
            textfont=dict(size=8),
            marker=dict(
                size=14,
                color=node_colors,
                line=dict(width=2, color='white')
            ),
            hoverinfo="text",
            hovertext=[f"channel {k}<br>cluster: {clusters[k]}" for k in range(C)],
            showlegend=False,
        )
    )

    fig.update_layout(
        title=f"Estimated Network (thr={thr:.3f}) edges={len(edges)}{title_suffix}",
        height=600,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=10, r=10, t=50, b=50),
        paper_bgcolor='rgba(0,0,0,0.9)',
        plot_bgcolor='rgba(0,0,0,0.9)',
    )
    # Equal axis scaling so the layout (and arrow angles) are not distorted.
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    # Legend-like explanation of the edge colour/width encoding.
    if edges:
        fig.add_annotation(
            text=f"Edge color/width: weak (blue/thin) → medium (green/yellow) → strong (red/thick)<br>Weight range: {min_w:.3f} - {max_w:.3f}",
            xref="paper", yref="paper",
            x=0.5, y=-0.05,
            showarrow=False,
            font=dict(size=10, color='white'),
            xanchor='center',
        )

    return fig, len(edges)
def make_edgecount_curve(W: np.ndarray) -> go.Figure:
    """Plot the number of surviving edges as the threshold sweeps W's weights.

    The threshold grid has 120 points from the largest to the smallest
    off-diagonal weight (just [0.0] when there are none). Counts come from
    ``threshold_edges``, so they follow its undirected/directed rules.
    """
    C = W.shape[0]
    # Use every off-diagonal weight: for a directed W the lower triangle can
    # hold the largest weights, which an upper-triangle-only grid would miss.
    # For a symmetric W the min/max (and so the grid) are unchanged.
    off_diag = W[~np.eye(C, dtype=bool)]
    if off_diag.size:
        thr_grid = np.linspace(float(off_diag.max()), float(off_diag.min()), 120)
    else:
        thr_grid = np.array([0.0])
    counts = [len(threshold_edges(W, float(thr))) for thr in thr_grid]
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=thr_grid, y=counts, mode="lines"))
    fig.update_layout(
        title="Edge count vs threshold (lower thr => more edges)",
        xaxis_title="threshold",
        yaxis_title="edge count",
        height=300,
    )
    return fig
def to_csv_bytes_matrix(mat: np.ndarray, fmt: str) -> bytes:
    """Serialise a 2-D array as UTF-8 CSV (comma-separated, one row per line).

    ``fmt`` is the ``np.savetxt`` per-value format, e.g. ``"%.6f"`` or ``"%d"``.
    """
    with io.StringIO() as sink:
        np.savetxt(sink, mat, delimiter=",", fmt=fmt)
        text = sink.getvalue()
    return text.encode("utf-8")
def to_csv_bytes_edges(edges: List[Tuple[int, int, float]]) -> bytes:
    """Serialise (source, target, weight) triples as UTF-8 CSV with a header row.

    Weights are written with six decimals; every line ends with a newline.
    """
    lines = ["source,target,weight"]
    lines.extend(f"{i},{j},{w:.6f}" for i, j, w in edges)
    return ("\n".join(lines) + "\n").encode("utf-8")
# ============================================================
# Sidebar UI
# ============================================================
st.sidebar.header("Input format")
input_mode = st.sidebar.radio("データ形式", ["EEGLAB (.set + .fdt)", "MATLAB (.mat)"], index=0)

# Main band: drives the viewer and every estimator (its phase is the PAC phase).
st.sidebar.header("Preprocess (auto)")
f_low_src = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=4.0, step=1.0, key="low_src")
f_high_src = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=8.0, step=1.0, key="high_src")

# Second band: only its Hilbert amplitude is used, as the amplitude side of PAC.
st.sidebar.header("if you use CFC+PAC:")
f_low_tgt = st.sidebar.number_input("Bandpass low (Hz)", min_value=0.0, value=25.0, step=1.0, key="low_tgt")
f_high_tgt = st.sidebar.number_input("Bandpass high (Hz)", min_value=0.1, value=40.0, step=1.0, key="high_tgt")

st.sidebar.header("Viewer controls")
win_sec = st.sidebar.number_input("Window length (sec)", min_value=0.1, value=5.0, step=0.1)
decim = st.sidebar.selectbox("Decimation (間引き)", options=[1, 2, 5, 10, 20, 50], index=1)
offset_mode = st.sidebar.checkbox("重ね描画のオフセット表示", value=True)
show_rangeslider = st.sidebar.checkbox("Plotly rangesliderを表示", value=False)
signal_view = st.sidebar.radio(
    "表示する信号",
    ["raw", "filtered", "amplitude", "phase"],
    index=1,
    help="raw: 生信号, filtered: バンドパス後, amplitude: Hilbert振幅(envelope), phase: Hilbert位相"
)

st.title("EEG timeseries viewer + network estimation")

# MNE treats l_freq > h_freq as a band-STOP filter, so an inverted band would
# silently produce the wrong signal instead of failing; reject empty or
# inverted bands before any preprocessing runs.
invalid_bands = [
    label
    for label, lo, hi in (
        ("Preprocess", f_low_src, f_high_src),
        ("CFC+PAC", f_low_tgt, f_high_tgt),
    )
    if hi <= lo
]
if invalid_bands:
    st.error(f"バンドパス設定が不正です(low < high にしてください): {', '.join(invalid_bands)}")
    st.stop()
# ============================================================
# Load + preprocess (EEGLAB / MAT)
# ============================================================
if input_mode.startswith("EEGLAB"):
    st.sidebar.header("Upload (.set + .fdt)")
    uploaded_files = st.sidebar.file_uploader(
        "Upload EEGLAB files",
        type=["set", "fdt"],
        accept_multiple_files=True,
    )
    if uploaded_files:
        set_file, fdt_file = pick_set_fdt(uploaded_files)
        if set_file is None or fdt_file is None:
            st.warning("`.set` と `.fdt` の両方をアップロードしてください。")
        else:
            # Streamlit reruns this whole script on every widget interaction.
            # This key identifies the uploaded files plus the band settings; data
            # is (re)loaded and the estimated network W invalidated only when it
            # changes. Resetting W on every rerun forced a full re-estimation
            # after each slider move.
            data_key = (
                "eeglab",
                set_file.name, getattr(set_file, "size", None), getattr(set_file, "file_id", None),
                fdt_file.name, getattr(fdt_file, "size", None), getattr(fdt_file, "file_id", None),
                float(f_low_src), float(f_high_src), float(f_low_tgt), float(f_high_tgt),
            )
            if st.session_state.get("data_key") != data_key:
                try:
                    with st.spinner("Loading EEGLAB + preprocessing (bandpass + hilbert)..."):
                        prep_src = preprocess_all_eeglab(
                            set_bytes=set_file.getvalue(),
                            fdt_bytes=fdt_file.getvalue(),
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_src),
                            f_high=float(f_high_src),
                        )
                        prep_tgt = preprocess_all_eeglab(
                            set_bytes=set_file.getvalue(),
                            fdt_bytes=fdt_file.getvalue(),
                            set_name=set_file.name,
                            fdt_name=fdt_file.name,
                            f_low=float(f_low_tgt),
                            f_high=float(f_high_tgt),
                        )
                    st.session_state["prep"] = prep_src
                    st.session_state["prep_tgt"] = prep_tgt
                    st.session_state["data_key"] = data_key
                    st.session_state["W"] = None
                except Exception as e:
                    for stale in ("prep", "prep_tgt", "data_key"):
                        st.session_state.pop(stale, None)
                    st.session_state["W"] = None
                    st.error(f"読み込み/前処理エラー: {e}")
            if st.session_state.get("data_key") == data_key:
                prep_src = st.session_state["prep"]
                st.success(f"Loaded & preprocessed. (T,C)={prep_src['raw'].shape} fs={prep_src['fs']:.2f}Hz")
else:
    st.sidebar.header("Upload (.mat)")
    mat_file = st.sidebar.file_uploader("Upload .mat", type=["mat"])
    if mat_file is not None:
        mat_bytes = mat_file.getvalue()
        try:
            cands = load_mat_candidates_cached(mat_bytes)
            if not cands:
                st.error("数値の1D/2D配列が見つかりませんでした。")
                st.info("MATファイルの構造を確認しています...")
                # Debugging aid: list what the file actually contains.
                try:
                    from scipy.io import loadmat
                    mat_data = loadmat(io.BytesIO(mat_bytes))
                    st.write("**MATファイルに含まれる変数:**")
                    for k, v in mat_data.items():
                        if not k.startswith('__'):
                            if isinstance(v, np.ndarray):
                                st.write(f"- `{k}`: shape={v.shape}, dtype={v.dtype}, ndim={v.ndim}")
                            else:
                                st.write(f"- `{k}`: type={type(v).__name__}")
                except Exception as e:
                    st.write(f"デバッグ情報の取得に失敗: {e}")
                # Also try the file as HDF5 (e.g. MATLAB v7.3 files). h5py opens
                # file-like objects directly, so no temporary file is needed; the
                # temp file used before leaked whenever opening it failed.
                try:
                    import h5py
                    st.write("**HDF5形式として読み込み中...**")
                    with h5py.File(io.BytesIO(mat_bytes), 'r') as f:
                        def show_structure(name, obj):
                            if isinstance(obj, h5py.Dataset):
                                st.write(f"- `{name}`: shape={obj.shape}, dtype={obj.dtype}")
                        f.visititems(show_structure)
                except Exception as e2:
                    st.write(f"HDF5としても読み込めませんでした: {e2}")
            else:
                key = st.sidebar.selectbox("EEG配列(変数)を選択", options=list(cands.keys()))
                fs_mat = st.sidebar.number_input("Sampling rate (Hz)", min_value=0.1, value=256.0, step=0.1)
                # Preprocess automatically once a variable is selected.
                if key:
                    x = cands[key]
                    st.sidebar.write(f"選択した配列: shape={x.shape}, dtype={x.dtype}")
                    # Same idea as the EEGLAB branch: redo bandpass + Hilbert and
                    # discard W only when the file, variable, fs or bands change,
                    # not on every rerun.
                    data_key = (
                        "mat",
                        mat_file.name, getattr(mat_file, "size", None), getattr(mat_file, "file_id", None),
                        key, float(fs_mat),
                        float(f_low_src), float(f_high_src), float(f_low_tgt), float(f_high_tgt),
                    )
                    if st.session_state.get("data_key") != data_key:
                        try:
                            with st.spinner("Preprocessing (bandpass + hilbert)..."):
                                cfg = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_src), f_high=float(f_high_src))
                                prep = preprocess_tc(x, cfg)
                                cfg_tgt = PreprocessConfig(fs=float(fs_mat), f_low=float(f_low_tgt), f_high=float(f_high_tgt))
                                prep_tgt = preprocess_tc(x, cfg_tgt)
                            st.session_state["prep"] = prep
                            st.session_state["prep_tgt"] = prep_tgt
                            st.session_state["data_key"] = data_key
                            st.session_state["W"] = None
                        except Exception as e:
                            for stale in ("prep", "prep_tgt", "data_key"):
                                st.session_state.pop(stale, None)
                            st.session_state["W"] = None
                            st.error(f"前処理エラー: {e}")
                            import traceback
                            st.code(traceback.format_exc())
                    if st.session_state.get("data_key") == data_key:
                        prep = st.session_state["prep"]
                        st.success(f"Loaded MAT '{key}'. (T,C)={prep['raw'].shape} fs={prep['fs']:.2f}Hz")
        except Exception as e:
            for stale in ("prep", "prep_tgt", "data_key"):
                st.session_state.pop(stale, None)
            st.session_state["W"] = None
            st.error(f".mat 読み込みエラー: {e}")
            import traceback
            st.code(traceback.format_exc())

# Nothing loaded yet (or loading failed): stop this run here.
if "prep" not in st.session_state:
    st.info("左のサイドバーからデータをアップロードしてください。")
    st.stop()
# ============================================================
# Viewer
# ============================================================
prep = st.session_state["prep"]
fs = float(prep["fs"])
X_tc = prep[signal_view]
T, C = X_tc.shape
duration_sec = (T - 1) / fs if T > 1 else 0.0
max_start = max(0.0, float(duration_sec - win_sec))
if max_start > 0.0:
    start_sec = st.sidebar.slider(
        "Start time (sec)",
        min_value=0.0,
        max_value=float(max_start),
        value=0.0,
        step=float(max(0.01, win_sec / 200)),
    )
else:
    # The whole recording fits in one window: nothing to scroll, and a slider
    # with min_value == max_value is degenerate.
    start_sec = 0.0

st.sidebar.header("Channels")

# Quick select / deselect buttons.
col_ch1, col_ch2 = st.sidebar.columns(2)
with col_ch1:
    select_all = st.button("全選択")
with col_ch2:
    deselect_all = st.button("全解除")

# Select a contiguous range of channels.
with st.sidebar.expander("📊 範囲で選択"):
    range_start = st.number_input("開始ch", min_value=0, max_value=C-1, value=0, step=1)
    range_end = st.number_input("終了ch", min_value=0, max_value=C-1, value=min(C-1, 7), step=1)
    if st.button("範囲を選択"):
        # Accept the bounds in either order instead of silently selecting nothing.
        lo, hi = sorted((int(range_start), int(range_end)))
        st.session_state["selected_channels"] = list(range(lo, hi + 1))

# Fixed 16-channel index presets (clipped to the available channels).
with st.sidebar.expander("⚡ プリセット"):
    preset_col1, preset_col2 = st.columns(2)
    with preset_col1:
        if st.button("前頭部 (0-15)"):
            st.session_state["selected_channels"] = list(range(min(16, C)))
    with preset_col2:
        if st.button("頭頂部 (16-31)"):
            st.session_state["selected_channels"] = list(range(16, min(32, C)))
    preset_col3, preset_col4 = st.columns(2)
    with preset_col3:
        if st.button("側頭部 (32-47)"):
            st.session_state["selected_channels"] = list(range(32, min(48, C)))
    with preset_col4:
        if st.button("後頭部 (48-63)"):
            st.session_state["selected_channels"] = list(range(48, min(64, C)))

# Initialise the selection on first use.
if "selected_channels" not in st.session_state:
    st.session_state["selected_channels"] = list(range(min(C, 8)))
# The selection outlives uploads; drop indices that do not exist in the current
# recording (e.g. after switching to a file with fewer channels), which would
# otherwise raise an IndexError when plotting.
st.session_state["selected_channels"] = [
    int(ch) for ch in st.session_state["selected_channels"] if 0 <= int(ch) < C
]

# Apply the select/deselect-all buttons.
if select_all:
    st.session_state["selected_channels"] = list(range(C))
if deselect_all:
    st.session_state["selected_channels"] = []

# Main selection UI: a multiselect for small montages, add/remove controls otherwise.
max_display = 20
if C <= max_display:
    selected_channels = st.sidebar.multiselect(
        f"表示するチャンネル(全{C}ch)",
        options=list(range(C)),
        default=st.session_state["selected_channels"],
        key="ch_select",
    )
else:
    st.sidebar.caption(f"選択中: {len(st.session_state['selected_channels'])} / {C} channels")

    # Add or remove a single channel.
    add_ch = st.sidebar.number_input(
        "チャンネルを追加",
        min_value=0,
        max_value=C-1,
        value=0,
        step=1,
        key="add_ch_input"
    )
    col_add, col_remove = st.sidebar.columns(2)
    with col_add:
        if st.button("➕ 追加"):
            if add_ch not in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].append(int(add_ch))
                st.session_state["selected_channels"].sort()
    with col_remove:
        if st.button("➖ 削除"):
            if add_ch in st.session_state["selected_channels"]:
                st.session_state["selected_channels"].remove(int(add_ch))

    # Show (up to ten of) the current selection.
    if st.session_state["selected_channels"]:
        selected_str = ", ".join(map(str, st.session_state["selected_channels"][:10]))
        if len(st.session_state["selected_channels"]) > 10:
            selected_str += f", ... (+{len(st.session_state['selected_channels']) - 10})"
        st.sidebar.text(f"選択済み: {selected_str}")

    selected_channels = st.session_state["selected_channels"]

# Persist the multiselect choice.
if C <= max_display:
    st.session_state["selected_channels"] = selected_channels
col1, col2 = st.columns([2, 1])

# Left: the time-series plot.
with col1:
    fig_ts = make_timeseries_figure(
        X_tc=X_tc,
        selected_channels=selected_channels,
        fs=fs,
        start_sec=float(start_sec),
        win_sec=float(win_sec),
        decim=int(decim),
        offset_mode=bool(offset_mode),
        show_rangeslider=bool(show_rangeslider),
        signal_type=signal_view,
    )
    st.plotly_chart(fig_ts)

# Right: a short summary of what is being shown.
with col2:
    st.subheader("Data info")
    signal_desc = {
        "raw": "生信号(前処理なし)",
        "filtered": f"バンドパスフィルタ後 ({f_low_src}-{f_high_src} Hz)",
        "amplitude": "Hilbert振幅 (envelope)",
        "phase": "Hilbert位相 (-π ~ π)"
    }
    info_lines = [
        f"- view: **{signal_view}** ({signal_desc.get(signal_view, '')})",
        f"- fs: **{fs:.2f} Hz**",
        f"- T: {T} samples",
        f"- C: {C} channels",
        f"- duration: {duration_sec:.2f} sec",
    ]
    for info_line in info_lines:
        st.write(info_line)
    if signal_view == "phase":
        st.caption("※ 位相は -π (rad) から π (rad) の範囲で表示されます")
    st.caption("※ 大規模データは window + decimation 推奨。rangesliderは重い場合OFF。")

st.divider()
# ============================================================
# Estimation
# ============================================================
st.subheader("Network estimation")

# Method id -> radio label; the dict order is the option order.
METHOD_LABELS = {
    "envelope_corr": "Envelope Pearson correlation (振幅の相関)",
    "phase_PLV": "PLV(位相同期, PLV)",
    "phase_corr": "Circular correlation",
    "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
    "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
}
estimation_method = st.radio(
    "推定手法を選択",
    options=list(METHOD_LABELS),
    format_func=lambda x: METHOD_LABELS[x],
    horizontal=True,
    help="envelope_corr: 振幅包絡線のPearson相関係数 | phase_PLV: 位相のPhase Locking Value | phase_corr: 位相の相関係数 | pac_tort: Modulation index | pac_chatterjee: 位相から振幅へのChatterjee相関",
)

# Short description of the chosen method.
method_info = {
    "envelope_corr": "**Envelope correlation**: 振幅包絡線(Hilbert amplitude)間のPearson相関係数を計算します。振幅が同期して変動するチャンネル間の結合を検出します。",
    "phase_PLV": "**PLV**: 位相間のPhase locking valueを計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "phase_corr": "**Circular correlation**: 位相間の相関係数を計算します。位相同期を検出します。0(非同期)〜1(完全同期)の値を取ります。",
    "pac_tort": "Modulation Index(位相と振幅のPAC指標)",
    "pac_chatterjee": "Chatterjee correlation(位相→振幅の相関)",
}
st.info(method_info[estimation_method])

# W is cached in session_state; it is recomputed on first use, when the method
# changes, or when the loading step has reset it to None.
last_method = st.session_state.get("last_estimation_method")
W = st.session_state.get("W")
need_estimation = (W is None) or (last_method != estimation_method)

if need_estimation:
    progress = st.progress(0.0)
    with st.spinner(f"推定中... ({estimation_method})"):
        if estimation_method == "envelope_corr":
            W = estimate_network_envelope_corr(prep["amplitude"])
        elif estimation_method == "phase_PLV":
            W = estimate_network_phase_PLV(prep["phase"], progress)
        elif estimation_method == "phase_corr":
            W = estimate_network_phase_corr(prep["phase"])
        elif estimation_method in ("pac_tort", "pac_chatterjee"):
            # Phase from the main band (`prep`), amplitude from the second band (`prep_tgt`).
            pac_estimator = (
                estimate_network_pac_tort
                if estimation_method == "pac_tort"
                else estimate_network_pac_chatterjee
            )
            W = pac_estimator(prep["phase"], st.session_state["prep_tgt"]["amplitude"], progress)
        else:
            st.error("未知の推定手法です")
            st.stop()
    # Some estimators never report progress; remove the bar so it does not
    # linger on screen at 0%.
    progress.empty()

    st.session_state["W"] = W
    st.session_state["last_estimation_method"] = estimation_method
    st.success(f"✅ 推定完了: {estimation_method} (ネットワークサイズ: {W.shape[0]} nodes)")
else:
    st.success(f"✓ 推定済み: **{estimation_method}** (ネットワークサイズ: {W.shape[0]} nodes)")
# W always exists at this point (freshly estimated or reused from session_state).
# Threshold slider and network figures follow.
wmax = float(np.max(W)) if np.isfinite(np.max(W)) else 1.0
# The slider's minimum is 0, so clamp: an all-negative W (signed measures such as
# a circular correlation) would otherwise put the default value wmax/2 below min_value.
wmax = max(wmax, 0.0)
thr_max = max(0.0001, wmax)
col_thr1, col_thr2 = st.columns([3, 1])
with col_thr1:
    thr = st.slider(
        "閾値 (threshold) ※下げるほどエッジが増えます",
        min_value=0.0,
        max_value=thr_max,
        value=wmax / 2,
        # Never step by more than the whole slider span (tiny or all-zero W).
        step=min(max(wmax / 200, 0.001), thr_max),
    )
with col_thr2:
    use_louvain = st.checkbox(
        "Louvainクラスタ",
        value=True,
        disabled=not LOUVAIN_AVAILABLE,
        help="ノードの色をコミュニティ検出結果で塗り分けます"
    )

# 2-D electrode layout: used only when it is a (C, >=2) array matching W, and
# rotated 90° counter-clockwise so the front of the head points up. Anything
# else falls back to the circular layout, and the message below says so.
electrode_pos = prep.get("electrode_pos", None)
if electrode_pos is not None:
    electrode_pos = np.asarray(electrode_pos, dtype=np.float32)
    if (
        electrode_pos.ndim == 2
        and electrode_pos.shape[1] >= 2
        and electrode_pos.shape[0] == W.shape[0]
    ):
        pos2 = electrode_pos[:, :2]
        electrode_pos = np.column_stack([-pos2[:, 1], pos2[:, 0]])
    else:
        electrode_pos = None
electrode_pos_3d = prep.get("electrode_pos_3d", None)

if electrode_pos is not None:
    st.info(f"✓ 電極位置を使用してネットワークを配置 ({electrode_pos.shape[0]} channels)")
else:
    st.info("ℹ️ 電極位置が取得できなかったため、円形配置を使用します")

# Report whether 3-D coordinates are available.
if electrode_pos_3d is not None:
    st.success(f"✓ 3D電極座標を取得しました ({electrode_pos_3d.shape[0]} channels) - 下部に3Dビューアを表示します")
else:
    st.info("ℹ️ 3D電極座標が取得できませんでした - 2D表示のみです")

net_col1, net_col2 = st.columns([2, 1])
with net_col1:
    fig_net, edge_n = make_network_figure(
        W,
        float(thr),
        use_louvain=use_louvain,
        electrode_pos=electrode_pos,
    )
    st.plotly_chart(fig_net)

    # 3-D network, only when valid (C, 3) coordinates are available.
    if electrode_pos_3d is not None:
        electrode_pos_3d = np.asarray(electrode_pos_3d, dtype=np.float32)
        if electrode_pos_3d.ndim == 2 and electrode_pos_3d.shape[0] == W.shape[0] and electrode_pos_3d.shape[1] == 3:
            st.subheader("3D Viewer")
            fig_3d = make_network_figure_3d(
                W=W,
                thr=float(thr),
                electrode_pos_3d=electrode_pos_3d,
                use_louvain=use_louvain,
            )
            st.plotly_chart(
                fig_3d,
                width="stretch",
                config={"displayModeBar": True, "scrollZoom": True},
            )
        else:
            st.warning(f"3D座標のshapeが不正です: {electrode_pos_3d.shape}(期待: (C,3), C={W.shape[0]})")

with net_col2:
    st.metric("Edges", edge_n)
    st.plotly_chart(make_edgecount_curve(W))

st.write("# Hypothesis testing")
st.write("Coming soon ...")