"""
SYNTAX prediction visualization:
- scatter points (SYNTAX GT vs model predictions) for several datasets;
- risk zones (low / high risk);
- ±σ and ±2σ bands around the diagonal;
- a logistic trend for each dataset.

The script does not depend on PyTorch/Lightning and is used at inference time.
Figures are written as HTML into ``visualizations/`` relative to the current
working directory.
"""

import os

import numpy as np
from scipy.optimize import curve_fit  # type: ignore

# ========== TUNABLE CONSTANTS ==========
# SYNTAX range shown on both axes (the axes add PADDING on each side).
DATA_MIN = 0.0
DATA_MAX = 60.0
PADDING = 0.5
# Half-width of the ±σ band around y = x: sigma(x) = SIGMA_BASE + SIGMA_SLOPE * x.
SIGMA_SLOPE = 0.15
SIGMA_BASE = 1.4

PLOT_WIDTH = 980
PLOT_HEIGHT = 980
FONT_FAMILY = "Inter, Roboto, Helvetica Neue, Arial, sans-serif"
BASE_FONT_SIZE = 16
TITLE_FONT_SIZE = 22
AXIS_LABEL_FONT_SIZE = BASE_FONT_SIZE
AXIS_TICK_FONT_SIZE = 15
LEGEND_FONT_SIZE = 14

MARKER_SIZE = 11
MARKER_LINE_WIDTH = 1.1
LINE_WIDTH = 2
TREND_LINE_WIDTH = 3

PLOT_BG_COLOR = "rgba(235,238,245,1)"
PAPER_BG_COLOR = "white"
LEGEND_BG_COLOR = "rgba(255,255,255,0.94)"
GRID_COLOR = "rgba(100,116,139,0.18)"

MARGIN_LEFT = 70
MARGIN_RIGHT = 24
MARGIN_TOP = 78
MARGIN_BOTTOM = 70
LEGEND_X = 0.04
LEGEND_Y = 0.99

# Per-dataset styles, cycled by the dataset's position in `datasets`.
COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]

# Number of samples for the σ bands and for each trend curve.
SIGMA_POINTS = 400
TREND_POINTS = 500

# Box constraints for the logistic parameters (R0, Rmax, t50, k).
_LOGISTIC_LOWER = (-10.0, 0.0, 1e-3, 0.01)
_LOGISTIC_UPPER = (60.0, 80.0, 60.0, 10.0)


# ========== HELPERS ==========
def _logistic_time(t, r0, r_max, t50, k):
    """Evaluate ``r0 + (r_max - r0) / (1 + (t50 / t) ** k)`` element-wise.

    Non-positive ``t`` is replaced by 1e-3 so that ``t50 / t`` stays finite.
    """
    t = np.asarray(t, dtype=float)
    t_safe = np.where(t <= 0, 1e-3, t)
    return r0 + (r_max - r0) / (1.0 + (t50 / t_safe) ** k)


def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS):
    """Fit ``_logistic_time`` to (x, y) and sample the curve on a regular grid.

    Parameters
    ----------
    x, y : array-like of float
        Paired observations; pairs where either value is non-finite are ignored.
    domain : tuple[float, float]
        The sampled curve covers the intersection of this interval with the
        observed x-range.
    n : int
        Number of grid points.

    Returns
    -------
    tuple[np.ndarray, np.ndarray] | tuple[None, None]
        Grid ``X`` and fitted values ``Y``, or ``(None, None)`` when there are
        fewer than 4 finite pairs, no positive x, an empty x-range, or the
        optimizer fails.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    finite = np.isfinite(x) & np.isfinite(y)
    if finite.sum() < 4:  # four free parameters need at least four points
        return None, None
    x_fit, y_fit = x[finite], y[finite]

    x_lo = max(float(np.min(x_fit)), float(domain[0]))
    x_hi = min(float(np.max(x_fit)), float(domain[1]))
    if not (np.isfinite(x_lo) and np.isfinite(x_hi)) or x_hi <= x_lo:
        return None, None

    x_pos = x_fit[x_fit > 0]
    if x_pos.size == 0:
        return None, None

    # Initial guess from data quantiles, clipped into the bounds: curve_fit
    # rejects an infeasible p0 with ValueError, so out-of-range data (e.g.
    # predictions above 80) would otherwise make the fit fail every time.
    p0 = np.clip(
        [
            float(np.percentile(y_fit, 10)),
            float(np.percentile(y_fit, 90)),
            float(np.median(x_pos)),
            1.0,
        ],
        _LOGISTIC_LOWER,
        _LOGISTIC_UPPER,
    )
    try:
        popt, _ = curve_fit(
            _logistic_time,
            x_fit,
            y_fit,
            p0=p0,
            bounds=(_LOGISTIC_LOWER, _LOGISTIC_UPPER),
            maxfev=20000,
        )
    except (RuntimeError, ValueError, np.linalg.LinAlgError):
        # Best effort: a failed fit just means no trend line for this dataset.
        return None, None

    grid = np.linspace(x_lo, x_hi, n)
    return grid, _logistic_time(grid, *popt)


def _prepare_datasets(datasets):
    """Convert every dataset to float arrays, remembering its position.

    Returns a list of ``(index, label, x, y)``; ``index`` is the dataset's
    position in ``datasets`` and selects its colour/marker. Datasets whose GT
    or prediction list is empty are skipped.

    Raises
    ------
    ValueError
        If a non-empty dataset has GT and prediction arrays of different shape.
    """
    prepared = []
    for index, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue
        if x.shape != y.shape:
            raise ValueError(
                f"Dataset {label!r}: {x.size} GT values vs {y.size} predictions"
            )
        prepared.append((index, label, x, y))
    return prepared


def _hover_template(label, r2=None, recall=None):
    """Build the Plotly hovertemplate for one dataset's scatter points."""
    lines = [f"{label}"]
    if r2 is not None:
        lines.append(f"R² = {r2:.3f}")
    if recall is not None:
        lines.append(f"Recall = {recall:.3f}")
    # Plotly hover labels break lines on <br>; <extra></extra> hides the
    # secondary trace-name box because the label is already the first line.
    return "<br>".join(lines) + "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>"


def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,
    gt_row,
    postfix=None,
    threshold=22.0,
    recall_values=None,
    backbone=False,
):
    """
    Draw SYNTAX predictions for several datasets and save them as HTML.

    The figure contains the two risk zones split at ``threshold``, ±σ/±2σ
    bands around ``y = x``, one scatter trace per dataset and, where the fit
    succeeds, one logistic trend per dataset.

    Parameters
    ----------
    datasets : dict[str, tuple[list[float], list[float]]]
        Mapping ``{dataset_name: (syntax_true_list, syntax_pred_list)}``.
        Datasets with an empty list are skipped.
    r2_values : dict[str, float] | None
        R^2 per dataset, shown in the hover text (may be None or incomplete).
    gt_row : str
        Text inserted into the title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Appended to the title and used as the HTML file name
        (``syntax.html`` when falsy).
    threshold : float
        SYNTAX cut-off (usually 22.0) separating the risk zones.
    recall_values : dict[str, float] | None
        Recall per dataset for the hover text (may be None).
    backbone : bool
        If True, save into ``visualizations/backbone``, else ``visualizations/``.

    Raises
    ------
    ValueError
        If a non-empty dataset has different numbers of GT values and predictions.
    """
    # Imported lazily: plotly is only needed for drawing, so the numeric
    # helpers of this module stay importable without it.
    import plotly.graph_objects as go

    prepared = _prepare_datasets(datasets)

    fig = go.Figure()
    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)
    base_font = dict(family=FONT_FAMILY, size=BASE_FONT_SIZE)

    # ---------- Thresholds and reference lines (legendrank=0) ----------
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Пороги и линии",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    # Vertical and horizontal threshold lines in one trace (None breaks the path).
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=rf"$\mathrm{{SYNTAX}}={threshold}$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    # Bands around y = x; the polygon runs along the lower edge and back
    # along the upper edge.
    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    band_x = np.concatenate([x_vals, x_vals[::-1]])
    sigma = SIGMA_BASE + SIGMA_SLOPE * x_vals
    for width, fillcolor, name in (
        (2.0, "rgba(255,193,7,0.18)", r"$\pm 2\sigma$"),
        (1.0, "rgba(255,152,0,0.30)", r"$\pm \sigma$"),
    ):
        lower = x_vals - width * sigma
        upper = x_vals + width * sigma
        fig.add_trace(
            go.Scatter(
                x=band_x,
                y=np.concatenate([lower, upper[::-1]]),
                fill="toself",
                fillcolor=fillcolor,
                line=dict(color="rgba(0,0,0,0)"),
                name=name,
                legendgroup="zones",
                showlegend=True,
                hoverinfo="skip",
                legendrank=0,
            )
        )

    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name=r"$y=x$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # ---------- Datasets (legendrank=20) ----------
    for position, (index, label, x, y) in enumerate(prepared):
        r2 = r2_values.get(label, None) if r2_values else None
        recall = recall_values.get(label, None) if recall_values else None
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                # Only the first trace of a group carries the group title.
                legendgrouptitle_text=("Датасеты" if position == 0 else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[index % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[index % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)"
                    ),
                ),
                hovertemplate=_hover_template(label, r2, recall),
                legendrank=20,
            )
        )

    # ---------- Logistic trends (legendrank=30) ----------
    first_trend = True
    for index, label, x, y in prepared:
        trend_x, trend_y = _fit_logistic(x, y, domain=domain)
        if trend_x is None:
            continue
        fig.add_trace(
            go.Scatter(
                x=trend_x,
                y=trend_y,
                mode="lines",
                name=label,  # full dataset name, no short alias
                legendgroup="trends",
                legendgrouptitle_text=(
                    "Тренды (логистические)" if first_trend else None
                ),
                showlegend=True,
                line=dict(color=COLORS[index % len(COLORS)], width=TREND_LINE_WIDTH),
                hoverinfo="skip",
                legendrank=30,
            )
        )
        first_trend = False

    # ---------- Layout ----------
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"

    fig.update_layout(
        title=dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=FONT_FAMILY,
                color="rgba(15,23,42,1)",
            ),
        ),
        font=base_font,
        xaxis_title=r"$\mathrm{SYNTAX\ GT}$",
        yaxis_title=r"$\mathrm{SYNTAX\ predictions}$",
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,
            bordercolor="#CBD5E1",
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=FONT_FAMILY),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            constrain="domain",
        ),
        # scaleanchor keeps both axes at the same scale so y = x is at 45°.
        yaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    # ---------- Save ----------
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)

    file_stem = f"{postfix}" if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{file_stem}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")