"""
Визуализация предсказаний SYNTAX:
- точки (SYNTAX GT vs предсказания модели) для нескольких датасетов;
- зоны риска (низкий / высокий риск);
- области ±σ и ±2σ вокруг диагонали;
- логистические тренды для каждого датасета.
Скрипт не зависит от PyTorch/Lightning и используется на этапе инференса.
Сохранение осуществляется в папку `visualizations/` (или `visualizations/backbone/`) относительно текущей рабочей директории.
"""
import os
import numpy as np
import plotly.graph_objects as go
from scipy.optimize import curve_fit # type: ignore
def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,
    gt_row,
    postfix=None,
    threshold=22.0,
    recall_values=None,
    backbone=False,
):
    """
    Render a combined SYNTAX plot: scatter points, risk zones and logistic trends.

    The figure is written as a standalone HTML file (MathJax loaded from CDN).

    Parameters
    ----------
    datasets : dict[str, tuple[list[float], list[float]]]
        Mapping ``{dataset_name: (syntax_true_list, syntax_pred_list)}``.
        Both sequences of one dataset must have the same length. Empty
        datasets are skipped but still consume a colour/marker slot.
    r2_values : dict[str, float] | None
        Per-dataset R^2 shown in the hover text. Missing keys (or ``None``)
        simply omit the value.
    gt_row : str
        Text inserted into the title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Appended to the title and used as the HTML file stem
        (``"syntax"`` when empty or None).
    threshold : float
        SYNTAX cut-off (typically 22.0) that separates the risk zones.
    recall_values : dict[str, float] | None
        Per-dataset recall shown in the hover text (optional).
    backbone : bool
        If True, save into ``visualizations/backbone``, otherwise into
        ``visualizations/``. Both paths are relative to the current working
        directory.

    Returns
    -------
    str
        Path of the written HTML file.

    Raises
    ------
    ValueError
        If a dataset's true and predicted sequences differ in length.
    """
    # ========== TUNABLE CONSTANTS ==========
    DATA_MIN = 0.0
    DATA_MAX = 60.0
    PADDING = 0.5
    SIGMA_SLOPE = 0.15
    SIGMA_BASE = 1.4
    PLOT_WIDTH = 980
    PLOT_HEIGHT = 980
    BASE_FONT_SIZE = 16
    TITLE_FONT_SIZE = 22
    AXIS_LABEL_FONT_SIZE = BASE_FONT_SIZE
    AXIS_TICK_FONT_SIZE = 15
    LEGEND_FONT_SIZE = 14
    MARKER_SIZE = 11
    MARKER_LINE_WIDTH = 1.1
    LINE_WIDTH = 2
    TREND_LINE_WIDTH = 3
    PLOT_BG_COLOR = "rgba(235,238,245,1)"
    PAPER_BG_COLOR = "white"
    LEGEND_BG_COLOR = "rgba(255,255,255,0.94)"
    GRID_COLOR = "rgba(100,116,139,0.18)"
    MARGIN_LEFT = 70
    MARGIN_RIGHT = 24
    MARGIN_TOP = 78
    MARGIN_BOTTOM = 70
    LEGEND_X = 0.04
    LEGEND_Y = 0.99
    COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
    SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]
    SIGMA_POINTS = 400
    TREND_POINTS = 500

    # ========== HELPERS ==========
    def _logistic_time(t, R0, Rmax, t50, k):
        """Logistic curve in SYNTAX-score space.

        The curve tends to R0 as t -> 0+ and to Rmax as t -> inf. It reaches
        the midpoint at t50; k controls the steepness.
        """
        t = np.asarray(t, dtype=float)
        # (t50 / t) ** k is undefined for t <= 0, so clamp to a tiny positive value.
        t_safe = np.where(t <= 0, 1e-3, t)
        return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)

    def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS):
        """
        Fit ``_logistic_time`` to (x, y) and sample it on a regular grid.

        Returns ``(X, Y)`` arrays of length ``n``. They span the observed x
        range clipped to ``domain``. Returns ``(None, None)`` when the input
        is unusable (fewer than 4 finite pairs, no positive x, mismatched
        lengths) or the optimiser fails.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.shape != y.shape:
            return None, None
        m = np.isfinite(x) & np.isfinite(y)
        # Four free parameters need at least four points.
        if m.sum() < 4:
            return None, None
        x_m, y_m = x[m], y[m]
        x_min = max(float(np.min(x_m)), float(domain[0]))
        x_max = min(float(np.max(x_m)), float(domain[1]))
        if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
            return None, None
        x_pos = x_m[x_m > 0]
        if x_pos.size == 0:
            return None, None

        # Parameter order: R0, Rmax, t50, k.
        lower = [-10.0, 0.0, 1e-3, 0.01]
        upper = [60.0, 80.0, 60.0, 10.0]
        p0 = [
            float(np.percentile(y_m, 10)),
            float(np.percentile(y_m, 90)),
            float(np.median(x_pos)),
            1.0,
        ]
        # curve_fit rejects a start point outside the bounds ("x0 is infeasible").
        # Clip it so out-of-range data still gets fitted instead of silently dropped.
        p0 = np.clip(p0, lower, upper)
        try:
            popt, _ = curve_fit(
                _logistic_time,
                x_m,
                y_m,
                p0=p0,
                bounds=(lower, upper),
                maxfev=20000,
            )
        except (RuntimeError, ValueError, TypeError, np.linalg.LinAlgError):
            # Non-convergence / degenerate data: the trend is optional decoration.
            return None, None
        X = np.linspace(x_min, x_max, n)
        Y = _logistic_time(X, *popt)
        return X, Y

    # ========== MAIN ==========
    fig = go.Figure()
    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)
    base_font = dict(
        family="Inter, Roboto, Helvetica Neue, Arial, sans-serif",
        size=BASE_FONT_SIZE,
    )

    # ---------- Thresholds and reference lines (legendrank=0) ----------
    # NOTE(review): "Low-risk zone" is filled red and "High-risk zone" green,
    # which may read as inverted — confirm the intended colour semantics.
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Пороги и линии",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    # Vertical and horizontal threshold lines in one trace; None breaks the path.
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=rf"$\mathrm{{SYNTAX}}={threshold}$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    # Bands around y = x whose half-width grows linearly: SIGMA_BASE + SIGMA_SLOPE * x.
    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals
    sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals
    two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals
    two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals
    band_x = np.concatenate([x_vals, x_vals[::-1]])
    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=np.concatenate([two_sigma_lower, two_sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,193,7,0.18)",
            line=dict(color="rgba(0,0,0,0)"),
            name=r"$\pm 2\sigma$",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=band_x,
            y=np.concatenate([sigma_lower, sigma_upper[::-1]]),
            fill="toself",
            fillcolor="rgba(255,152,0,0.30)",
            line=dict(color="rgba(0,0,0,0)"),
            name=r"$\pm \sigma$",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name=r"$y=x$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # ---------- Prepare datasets once ----------
    # The colour/marker index is the position in `datasets`, so skipped empty
    # datasets still consume a slot (keeps colours stable across runs).
    prepared = []
    for idx, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.asarray(syntax_true, dtype=float)
        y = np.asarray(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue
        if x.shape != y.shape:
            raise ValueError(
                f"Dataset '{label}': syntax_true and syntax_pred differ in "
                f"length ({x.size} vs {y.size})"
            )
        prepared.append((idx, label, x, y))

    r2_lookup = r2_values or {}
    recall_lookup = recall_values or {}

    # ---------- Datasets (legendrank=20) ----------
    first_dataset = True
    for idx, label, x, y in prepared:
        r2 = r2_lookup.get(label)
        recall = recall_lookup.get(label)
        hover_lines = [f"{label}"]
        if r2 is not None:
            hover_lines.append(f"R² = {r2:.3f}")
        if recall is not None:
            hover_lines.append(f"Recall = {recall:.3f}")
        # Plotly hover templates use <br> for line breaks. <extra></extra>
        # hides the secondary trace-name box because the label is already shown.
        hovertemplate = (
            "<br>".join(hover_lines)
            + "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>"
        )
        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                legendgrouptitle_text=("Датасеты" if first_dataset else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[idx % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[idx % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)"
                    ),
                ),
                hovertemplate=hovertemplate,
                legendrank=20,
            )
        )
        first_dataset = False

    # ---------- Logistic trends (legendrank=30) ----------
    # Added after all markers so trend lines are drawn on top of the points.
    first_trend = True
    for idx, label, x, y in prepared:
        Xc, Yc = _fit_logistic(x, y, domain=domain)
        if Xc is None:
            print(f"Logistic trend skipped for '{label}' (insufficient data or fit failed)")
            continue
        fig.add_trace(
            go.Scatter(
                x=Xc,
                y=Yc,
                mode="lines",
                name=label,  # full dataset name, no short aliases
                legendgroup="trends",
                legendgrouptitle_text=(
                    "Тренды (логистические)" if first_trend else None
                ),
                showlegend=True,
                line=dict(color=COLORS[idx % len(COLORS)], width=TREND_LINE_WIDTH),
                hoverinfo="skip",
                legendrank=30,
            )
        )
        first_trend = False

    # ---------- Layout ----------
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"
    fig.update_layout(
        title=dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=base_font["family"],
                color="rgba(15,23,42,1)",
            ),
        ),
        font=base_font,
        xaxis_title=r"$\mathrm{SYNTAX\ GT}$",
        yaxis_title=r"$\mathrm{SYNTAX\ predictions}$",
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,
            bordercolor="#CBD5E1",
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=base_font["family"]),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            constrain="domain",
        ),
        # scaleanchor/scaleratio keep the axes 1:1 so y = x is a true diagonal.
        yaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    # ---------- Save (relative to the current working directory) ----------
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)
    file_stem = str(postfix) if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{file_stem}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")
    return save_path_html