| | """ |
| | Визуализация предсказаний SYNTAX: |
| | - точки (SYNTAX GT vs предсказания модели) для нескольких датасетов; |
| | - зоны риска (низкий / высокий риск); |
| | - области ±σ и ±2σ вокруг диагонали; |
| | - логистические тренды для каждого датасета. |
| | |
| | Скрипт не зависит от PyTorch/Lightning и используется на этапе инференса. |
| | Сохранение осуществляется в папку `visualizations/` относительно текущей рабочей директории (создаётся при необходимости). |
| | """ |
| |
|
| | import os |
| | import numpy as np |
| | import plotly.graph_objects as go |
| | from scipy.optimize import curve_fit |
| |
|
| |
|
def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,
    gt_row,
    postfix=None,
    threshold=22.0,
    recall_values=None,
    backbone=False,
):
    """
    Draw a unified SYNTAX figure: scatter points, risk zones and logistic trends.

    The figure is written as a standalone HTML file (MathJax loaded from CDN)
    to ``visualizations/<postfix or "syntax">.html``, or to
    ``visualizations/backbone/...`` when ``backbone`` is True. The directory
    is resolved relative to the current working directory and is created if
    missing. The saved path is printed.

    Parameters
    ----------
    datasets : dict[str, tuple[Sequence[float], Sequence[float]]]
        Mapping ``{dataset_name: (syntax_true, syntax_pred)}``. Both sequences
        of a dataset must have the same length. Datasets where either
        sequence is empty are skipped.
    r2_values : dict[str, float] | None
        Per-dataset R^2 shown in the hover text. Missing keys are omitted;
        ``None`` is treated as an empty mapping.
    gt_row : str
        Text placed in the title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Appended to the title and used as the output file stem.
    threshold : float
        SYNTAX cut-off (typically 22.0) separating the risk zones.
    recall_values : dict[str, float] | None
        Per-dataset recall shown in the hover text (optional).
    backbone : bool
        If True, save into ``visualizations/backbone`` instead of
        ``visualizations``.

    Raises
    ------
    ValueError
        If a dataset's true and predicted sequences have different shapes.

    Warns
    -----
    RuntimeWarning
        When the logistic trend of a dataset cannot be fitted; that trend is
        then omitted from the figure.
    """
    import warnings

    from scipy.optimize import OptimizeWarning

    # --- Data range ---------------------------------------------------------
    DATA_MIN = 0.0
    DATA_MAX = 60.0
    # Extra room around the data range for the axes and reference shapes.
    PADDING = 0.5

    # Band half-width around y = x grows linearly: sigma(x) = BASE + SLOPE * x.
    SIGMA_SLOPE = 0.15
    SIGMA_BASE = 1.4

    # --- Figure geometry and typography -------------------------------------
    PLOT_WIDTH = 980
    PLOT_HEIGHT = 980

    BASE_FONT_SIZE = 16
    TITLE_FONT_SIZE = 22
    AXIS_TICK_FONT_SIZE = 15
    LEGEND_FONT_SIZE = 14

    MARKER_SIZE = 11
    MARKER_LINE_WIDTH = 1.1
    LINE_WIDTH = 2
    TREND_LINE_WIDTH = 3

    PLOT_BG_COLOR = "rgba(235,238,245,1)"
    PAPER_BG_COLOR = "white"
    LEGEND_BG_COLOR = "rgba(255,255,255,0.94)"
    GRID_COLOR = "rgba(100,116,139,0.18)"

    MARGIN_LEFT = 70
    MARGIN_RIGHT = 24
    MARGIN_TOP = 78
    MARGIN_BOTTOM = 70

    LEGEND_X = 0.04
    LEGEND_Y = 0.99

    # Per-dataset styling, cycled by the dataset's position in `datasets`.
    COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
    SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]

    # Sampling density of the sigma bands and of each fitted trend curve.
    SIGMA_POINTS = 400
    TREND_POINTS = 500

    # --- Trend model --------------------------------------------------------

    def _logistic_time(t, R0, Rmax, t50, k):
        """
        Evaluate the Hill-type (log-logistic) curve.

        The curve is R0 + (Rmax - R0) / (1 + (t50 / t) ** k). Non-positive
        ``t`` is replaced by 1e-3 so that ``t50 / t`` stays finite.
        """
        t = np.asarray(t, dtype=float)
        t_safe = np.where(t <= 0, 1e-3, t)
        return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)

    def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS):
        """
        Fit ``_logistic_time`` to (x, y) and sample it on a regular grid.

        Returns ``(X, Y)`` arrays of length ``n`` spanning the observed
        x-range clipped to ``domain``. Returns ``(None, None)`` when any of
        these holds:

        - there are fewer than 4 finite points;
        - no x is positive;
        - the clipped range is empty;
        - the fit fails (a RuntimeWarning is emitted in that case).
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        finite = np.isfinite(x) & np.isfinite(y)
        if finite.sum() < 4:
            return None, None

        x_m, y_m = x[finite], y[finite]
        x_min = max(float(np.min(x_m)), float(domain[0]))
        x_max = min(float(np.max(x_m)), float(domain[1]))
        if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
            return None, None

        x_pos = x_m[x_m > 0]
        if x_pos.size == 0:
            return None, None

        # Box constraints for (R0, Rmax, t50, k).
        lower = np.array([-10.0, 0.0, 1e-3, 0.01])
        upper = np.array([60.0, 80.0, 60.0, 10.0])

        # A data-driven initial guess, clipped into the box. A bounded
        # curve_fit rejects an out-of-bounds p0 with ValueError, which
        # previously dropped the trend, e.g. whenever the median GT
        # exceeded 60.
        p0 = np.clip(
            [
                np.percentile(y_m, 10),  # R0: lower plateau
                np.percentile(y_m, 90),  # Rmax: upper plateau
                np.median(x_pos),  # t50: midpoint
                1.0,  # k: steepness
            ],
            lower,
            upper,
        )

        try:
            with warnings.catch_warnings():
                # pcov is discarded, so "covariance could not be estimated"
                # is irrelevant noise here.
                warnings.simplefilter("ignore", OptimizeWarning)
                popt, _ = curve_fit(
                    _logistic_time,
                    x_m,
                    y_m,
                    p0=p0,
                    bounds=(lower, upper),
                    maxfev=20000,
                )
        except (RuntimeError, ValueError) as exc:
            # RuntimeError: no convergence within maxfev.
            # ValueError: invalid inputs or bounds.
            # The trend is optional, so report the failure and skip it.
            warnings.warn(
                f"Logistic trend fit failed: {exc}", RuntimeWarning, stacklevel=2
            )
            return None, None

        X = np.linspace(x_min, x_max, n)
        return X, _logistic_time(X, *popt)

    # --- Input preparation --------------------------------------------------
    # Convert every dataset once; both the scatter pass and the trend pass
    # reuse these arrays. `index` is the position in `datasets` (skipped
    # entries included), so each dataset keeps the same colour/marker.
    prepared = []
    for index, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue
        if x.shape != y.shape:
            # Fail early with a clear message instead of an opaque
            # broadcasting/indexing error inside the trend fit.
            raise ValueError(
                f"Dataset {label!r}: syntax_true has shape {x.shape} "
                f"but syntax_pred has shape {y.shape}"
            )
        prepared.append((index, label, x, y))

    fig = go.Figure()

    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)

    base_font = dict(
        family="Inter, Roboto, Helvetica Neue, Arial, sans-serif",
        size=BASE_FONT_SIZE,
    )

    # --- Risk zones ---------------------------------------------------------
    # Each square is an agreement quadrant: GT and prediction lie on the
    # same side of the threshold.
    # NOTE(review): the low-risk square is filled red and the high-risk
    # square green, which reads inverted for a risk scale. Confirm this is
    # intended.
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Пороги и линии",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    # Vertical and horizontal threshold lines in one trace; the None entries
    # break the line between the two segments.
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=rf"$\mathrm{{SYNTAX}}={threshold}$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    # --- ±σ / ±2σ bands around y = x ----------------------------------------
    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    sigma = SIGMA_BASE + SIGMA_SLOPE * x_vals
    sigma_upper = x_vals + sigma
    sigma_lower = x_vals - sigma
    two_sigma_upper = x_vals + 2 * sigma
    two_sigma_lower = x_vals - 2 * sigma
    # Closed polygon: walk along the lower edge, then back along the upper.
    band_x = np.concatenate([x_vals, x_vals[::-1]])

    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,193,7,0.18)",
            line=dict(color="rgba(0,0,0,0)"),
            name=r"$\pm 2\sigma$",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=np.concatenate([sigma_lower, sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,152,0,0.30)",
            line=dict(color="rgba(0,0,0,0)"),
            name=r"$\pm \sigma$",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name=r"$y=x$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # --- Per-dataset scatter points -----------------------------------------
    r2_lookup = r2_values or {}
    first_dataset = True
    for index, label, x, y in prepared:
        r2 = r2_lookup.get(label)
        recall = recall_values.get(label) if recall_values else None
        hover_lines = [f"<b>{label}</b>"]
        if r2 is not None:
            hover_lines.append(f"R² = {r2:.3f}")
        if recall is not None:
            hover_lines.append(f"Recall = {recall:.3f}")
        hovertemplate = (
            "<br>".join(hover_lines)
            + "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>"
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                # Only the first trace of the group carries the group title.
                legendgrouptitle_text=("Датасеты" if first_dataset else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[index % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[index % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)"
                    ),
                ),
                hovertemplate=hovertemplate,
                legendrank=20,
            )
        )
        first_dataset = False

    # --- Per-dataset logistic trends ----------------------------------------
    # Added after all scatter traces so the curves are drawn on top.
    first_trend = True
    for index, label, x, y in prepared:
        Xc, Yc = _fit_logistic(x, y, domain=domain)
        if Xc is None:
            continue
        fig.add_trace(
            go.Scatter(
                x=Xc,
                y=Yc,
                mode="lines",
                name=label,
                legendgroup="trends",
                legendgrouptitle_text=(
                    "Тренды (логистические)" if first_trend else None
                ),
                showlegend=True,
                line=dict(color=COLORS[index % len(COLORS)], width=TREND_LINE_WIDTH),
                hoverinfo="skip",
                legendrank=30,
            )
        )
        first_trend = False

    # --- Layout -------------------------------------------------------------
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"

    fig.update_layout(
        title=dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=base_font["family"],
                color="rgba(15,23,42,1)",
            ),
        ),
        font=base_font,
        xaxis_title=r"$\mathrm{SYNTAX\ GT}$",
        yaxis_title=r"$\mathrm{SYNTAX\ predictions}$",
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,
            bordercolor="#CBD5E1",
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=base_font["family"]),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            constrain="domain",
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            # Lock a 1:1 aspect ratio so y = x is a true diagonal.
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    # --- Save ---------------------------------------------------------------
    # Relative path: resolved against the current working directory.
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)

    postfix_html = f"{postfix}" if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{postfix_html}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")
| |
|