ecflow / plotting.py
Bing Yan
Improve UI: mechanism names, dynamic parameter table, cleaner plots
1b135ef
"""
Visualization utilities for ECFlow web app.
Generates matplotlib figures for mechanism classification, parameter
posteriors, and signal reconstruction overlays.
"""
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
# Base UI palette shared by every plot in this module.
COLORS = dict(
    primary="#2563EB",
    secondary="#7C3AED",
    accent="#059669",
    warm="#DC2626",
    neutral="#6B7280",
    bg="#F9FAFB",
)

# Per-mechanism colours for the 'ec' domain (bars and violins are keyed by
# the short mechanism name).
MECH_COLORS_EC = dict(
    Nernst="#3B82F6",
    BV="#8B5CF6",
    MHC="#EC4899",
    Ads="#F59E0B",
    EC="#10B981",
    LH="#EF4444",
)

# Per-mechanism colours for the 'tpd' domain.
MECH_COLORS_TPD = dict(
    FirstOrder="#3B82F6",
    SecondOrder="#8B5CF6",
    LH_Surface="#EC4899",
    MvK="#F59E0B",
    FirstOrderCovDep="#10B981",
    DiffLimited="#EF4444",
)

# Human-readable names appended to the short keys in the 'ec' domain's
# mechanism-probability chart.
MECH_FULL_NAMES_EC = dict(
    BV="Butler–Volmer",
    MHC="Marcus–Hush–Chidsey",
    Nernst="Nernstian (reversible)",
    Ads="Adsorption-coupled",
    EC="EC mechanism",
    LH="Langmuir–Hinshelwood",
)

# Human-readable names for the 'tpd' domain's mechanism-probability chart.
MECH_FULL_NAMES_TPD = dict(
    FirstOrder="1st-Order",
    SecondOrder="2nd-Order",
    LH_Surface="LH Surface",
    MvK="Mars–van Krevelen",
    FirstOrderCovDep="1st-Order Cov-Dep",
    DiffLimited="Diff-Limited",
)
def plot_mechanism_probs(probs_dict, domain="ec"):
    """
    Horizontal bar chart of mechanism classification probabilities.

    Bars are sorted by ascending probability, so the most probable mechanism
    is drawn at the top. Each tick label shows the short mechanism key,
    followed by its full name in parentheses when the domain's name table
    has one. Bars above 5% probability are annotated with their value.

    Args:
        probs_dict: {mechanism_name: probability}
        domain: 'ec' or 'tpd'. Any value other than 'ec' selects the TPD
            colour and name tables.

    Returns:
        matplotlib Figure
    """
    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    full_names = MECH_FULL_NAMES_EC if domain == "ec" else MECH_FULL_NAMES_TPD

    names = list(probs_dict)
    probs = [probs_dict[n] for n in names]
    # A stable sort keeps equal-probability mechanisms in their input order,
    # so the chart layout is deterministic.
    order = np.argsort(probs, kind="stable")
    names = [names[i] for i in order]
    probs = [probs[i] for i in order]

    bar_colors = [colors.get(n, COLORS["neutral"]) for n in names]
    # Unknown keys are shown on their own instead of a redundant "X (X)".
    display_names = [
        f"{n} ({full_names[n]})" if n in full_names else f"{n}" for n in names
    ]

    fig, ax = plt.subplots(figsize=(9, max(3, len(names) * 0.7)))
    positions = range(len(names))
    bars = ax.barh(positions, probs, color=bar_colors, edgecolor="white",
                   linewidth=0.5, height=0.7)
    ax.set_yticks(positions)
    ax.set_yticklabels(display_names, fontsize=11, fontweight="medium")
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability", fontsize=12)
    ax.set_title("Mechanism Classification", fontsize=14, fontweight="bold", pad=15)

    # Annotate only bars wide enough for a percentage label to be useful.
    for bar, prob in zip(bars, probs):
        if prob > 0.05:
            ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2,
                    f"{prob:.1%}", va="center", fontsize=11, fontweight="bold")

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig
def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
    """
    Violin plots of posterior distributions for each parameter.

    One panel is drawn per parameter. Each panel shows:
    - the violin, with the mean in black and the median in the "warm" colour;
    - dashed lines at the 5th and 95th percentiles;
    - the sample mean as text.

    Non-finite draws (NaN/inf) are dropped per parameter before plotting. A
    panel with no finite draws shows a short note instead of a violin.

    Args:
        samples: [n_samples, D] array-like of posterior samples; column i is
            plotted for param_names[i]. A 1-D array is accepted as
            [n_samples, 1] for the single-parameter case.
        param_names: list of parameter names
        mechanism_name: name of the mechanism (used in the title and to pick
            the violin colour)
        domain: 'ec' or 'tpd'

    Returns:
        matplotlib Figure
    """
    samples = np.asarray(samples, dtype=float)
    if samples.ndim == 1:
        samples = samples.reshape(-1, 1)

    n_params = len(param_names)
    # squeeze=False always yields a 2-D array of Axes, even for one panel.
    fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5),
                             squeeze=False)
    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    color = colors.get(mechanism_name, COLORS["primary"])

    for i, (ax, name) in enumerate(zip(axes[0], param_names)):
        column = samples[:, i]
        # NaN/inf draws would corrupt the KDE, the quantiles and the mean.
        data = column[np.isfinite(column)]

        ax.set_title(_format_param_name(name), fontsize=11, fontweight="medium")
        ax.set_xticks([])
        for side in ("top", "right", "bottom"):
            ax.spines[side].set_visible(False)

        if data.size == 0:
            # violinplot cannot handle an empty dataset.
            ax.text(0.5, 0.5, "no finite samples", transform=ax.transAxes,
                    ha="center", va="center", fontsize=9, color=COLORS["neutral"])
            continue

        parts = ax.violinplot(data, positions=[0], showmeans=True,
                              showmedians=True, showextrema=False)
        for pc in parts["bodies"]:
            pc.set_facecolor(color)
            pc.set_alpha(0.6)
        parts["cmeans"].set_color("black")
        parts["cmedians"].set_color(COLORS["warm"])

        # 90% central interval as dashed reference lines.
        q05, q95 = np.quantile(data, [0.05, 0.95])
        ax.axhline(q05, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
        ax.axhline(q95, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)

        mean_val = data.mean()
        ax.text(0.5, 0.02, f"mean={mean_val:.3f}", transform=ax.transAxes,
                ha="center", fontsize=9, color=COLORS["neutral"])

    fig.suptitle(f"Parameter Posteriors — {mechanism_name}",
                 fontsize=14, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
def _finite_or_none(value):
    """Return ``value`` as a float if it is a finite number, otherwise None."""
    try:
        value = float(value)
    except (TypeError, ValueError):
        return None
    return value if np.isfinite(value) else None


def _metric_at(values, i):
    """Return the finite metric at index ``i`` of ``values``, or None."""
    if values is None or i >= len(values):
        return None
    return _finite_or_none(values[i])


def plot_reconstruction(observed_curves, recon_curves, domain="ec",
                        nrmses=None, r2s=None, scan_labels=None):
    """
    Overlay of observed vs reconstructed signals with optional metrics.

    At most the first four curves are plotted, one panel each. Per-panel
    NRMSE/R² boxes are shown only for finite values. The figure title also
    gets the averages when both metric lists contain finite values.

    Args:
        observed_curves: list of dicts with 'x' and 'y' arrays
        recon_curves: list of dicts with 'x' and 'y' arrays (same length);
            a None (or missing) entry shows the observed signal alone
        domain: 'ec' or 'tpd'
        nrmses: optional list/array of NRMSE values per curve; None or
            non-finite entries are skipped
        r2s: optional list/array of R2 values per curve; None or non-finite
            entries are skipped
        scan_labels: optional list of label strings per curve

    Returns:
        matplotlib Figure
    """
    n_curves = len(observed_curves)
    n_panels = min(n_curves, 4)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)
    xlabel = "Potential (\u03b8)" if domain == "ec" else "Temperature (K)"
    ylabel = "Flux" if domain == "ec" else "Rate"

    for i, ax in enumerate(axes[0]):
        obs = observed_curves[i]
        has_recon = recon_curves is not None and i < len(recon_curves)
        rec = recon_curves[i] if has_recon else None

        ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
                label="Observed", alpha=0.8)
        if rec is not None:
            ax.plot(rec["x"], rec["y"], color=COLORS["primary"], linewidth=1.5,
                    label="Reconstructed", linestyle="--")

        ax.set_xlabel(xlabel, fontsize=10)
        if i == 0:
            ax.set_ylabel(ylabel, fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Explicit None checks: truth-testing a numpy array raises ValueError.
        if scan_labels is not None and i < len(scan_labels):
            title = scan_labels[i]
        elif domain == "ec":
            title = f"Scan rate {i + 1}"
        else:
            title = f"Heating rate {i + 1}"
        ax.set_title(title, fontsize=10)

        metrics_parts = []
        nrmse = _metric_at(nrmses, i)
        if nrmse is not None:
            metrics_parts.append(f"NRMSE={nrmse:.4f}")
        r2 = _metric_at(r2s, i)
        if r2 is not None:
            metrics_parts.append(f"R\u00b2={r2:.4f}")
        if metrics_parts:
            ax.text(0.02, 0.98, "  ".join(metrics_parts),
                    transform=ax.transAxes, fontsize=8, va="top",
                    color=COLORS["accent"], fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                              alpha=0.8, edgecolor=COLORS["accent"]))

    suptitle = "Signal Reconstruction"
    if nrmses is not None and r2s is not None:
        valid_nrmse = [v for v in map(_finite_or_none, nrmses) if v is not None]
        valid_r2 = [v for v in map(_finite_or_none, r2s) if v is not None]
        if valid_nrmse and valid_r2:
            avg_nrmse = np.mean(valid_nrmse)
            avg_r2 = np.mean(valid_r2)
            suptitle += f" (avg NRMSE={avg_nrmse:.4f}, avg R\u00b2={avg_r2:.4f})"
    fig.suptitle(suptitle, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
def _add_sweep_arrows(ax, pot, y_ox, y_red, mid, show_labels=False):
    """
    Draw one direction arrow on each sweep half of both species' curves.

    Indices before ``mid`` form the first sweep, where the arrow sits about
    35% of the way along. Indices from ``mid`` on form the second sweep,
    where it sits about 65% of the way along. Each arrow points from an
    earlier sample to a later one, so it follows the order of the data.
    Halves with fewer than 10 points get no arrow.

    ``show_labels`` is accepted but currently unused.
    """
    halves = ((slice(None, mid), 0.35), (slice(mid, None), 0.65))
    for series, shade in ((y_ox, COLORS["primary"]), (y_red, COLORS["warm"])):
        for part, position in halves:
            xs = pot[part]
            ys = series[part]
            count = len(xs)
            if count < 10:
                continue
            # Anchor index, kept at least two samples away from either end.
            centre = min(max(int(count * position), 2), count - 3)
            half_span = max(1, count // 30)
            start = max(0, centre - half_span)
            end = min(count - 1, centre + half_span)
            ax.annotate(
                "",
                xy=(xs[end], ys[end]),
                xytext=(xs[start], ys[start]),
                arrowprops=dict(arrowstyle="-|>", color=shade, lw=1.8,
                                mutation_scale=14),
            )
def plot_concentration_profiles(conc_curves, scan_labels=None):
    """
    Plot surface concentration profiles (C_A and C_B) vs potential.

    Each panel draws C_A ('c_ox') and C_B ('c_red') over the full sweep and
    adds direction arrows. For arrow placement, the first half of the samples
    is treated as the forward (reductive) sweep and the second half as the
    reverse (oxidative) sweep. At most the first four curves are shown, and
    panels whose entry is None are hidden.

    Args:
        conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
            or None for failed reconstructions
        scan_labels: optional list of label strings per curve

    Returns:
        matplotlib Figure, or None if no valid data
    """
    if all(c is None for c in conc_curves):
        return None

    n_panels = min(len(conc_curves), 4)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)
    for i, ax in enumerate(axes[0]):
        c = conc_curves[i]
        if c is None:
            ax.set_visible(False)
            continue

        pot = np.asarray(c["x"])
        c_ox = np.asarray(c["c_ox"])
        c_red = np.asarray(c["c_red"])

        # Draw each species as one continuous line. Plotting the two halves
        # separately left the segment joining the last forward sample to the
        # first reverse sample undrawn (a gap at the sweep reversal).
        ax.plot(pot, c_ox, color=COLORS["primary"], linewidth=1.5,
                label="C$_A$ (ox)")
        ax.plot(pot, c_red, color=COLORS["warm"], linewidth=1.5,
                label="C$_B$ (red)")
        _add_sweep_arrows(ax, pot, c_ox, c_red, len(pot) // 2)

        ax.set_xlabel("Potential (\u03b8)", fontsize=10)
        if i == 0:
            ax.set_ylabel("Surface concentration", fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Explicit None check: truth-testing a numpy array raises ValueError.
        if scan_labels is not None and i < len(scan_labels):
            ax.set_title(scan_labels[i], fontsize=10)
        else:
            ax.set_title(f"Scan rate {i + 1}", fontsize=10)

    fig.suptitle("Surface Concentration Profiles", fontsize=12,
                 fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
def plot_parameter_table(param_stats, mechanism_name):
    """
    Create a formatted parameter summary table as a figure.

    One row per entry of param_stats['names'] shows the display name and the
    mean, std, 5th and 95th percentiles, each to four decimals. When there
    are no parameters, a short note is drawn instead of a table, because
    matplotlib cannot build a table from zero rows.

    Args:
        param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'
        mechanism_name: name of the mechanism

    Returns:
        matplotlib Figure

    Raises:
        ValueError: if 'mean', 'std', 'q05' or 'q95' has fewer entries than
            'names'.
    """
    names = param_stats["names"]
    n = len(names)
    stat_keys = ("mean", "std", "q05", "q95")
    for key in stat_keys:
        if len(param_stats[key]) < n:
            raise ValueError(
                f"param_stats[{key!r}] has fewer entries than param_stats['names']"
            )

    cell_text = [
        [_format_param_name(name), f"{mean:.4f}", f"{std:.4f}",
         f"{q05:.4f}", f"{q95:.4f}"]
        for name, mean, std, q05, q95 in zip(
            names, *(param_stats[key] for key in stat_keys))
    ]

    fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * n + 1)))
    ax.axis("off")

    if not cell_text:
        ax.text(0.5, 0.5, "No parameters to display", transform=ax.transAxes,
                ha="center", va="center", fontsize=11, color=COLORS["neutral"])
    else:
        col_labels = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
        table = ax.table(cellText=cell_text, colLabels=col_labels,
                         loc="center", cellLoc="center")
        table.auto_set_font_size(False)
        table.set_fontsize(11)
        table.scale(1.0, 1.5)
        # Row 0 is the header row created from colLabels; body rows alternate.
        for (row, _col), cell in table.get_celld().items():
            if row == 0:
                cell.set_facecolor("#E5E7EB")
                cell.set_text_props(fontweight="bold")
            else:
                cell.set_facecolor(COLORS["bg"] if row % 2 == 0 else "white")

    ax.set_title(f"Parameter Estimates — {mechanism_name}",
                 fontsize=14, fontweight="bold", pad=20)
    fig.tight_layout()
    return fig
# Raw parameter keys -> display labels (Greek letters, subscripts, units).
# Built once at import time rather than on every lookup.
_PARAM_DISPLAY_NAMES = {
    "log10(K0)": "log₁₀(K₀)",
    "log10(dB)": "log₁₀(d_B)",
    "log10(dA)": "log₁₀(d_A)",
    "log10(kc)": "log₁₀(k_c)",
    "log10(reorg_e)": "log₁₀(λ)",
    "log10(Gamma_sat)": "log₁₀(Γ_sat)",
    "log10(KA_eq)": "log₁₀(K_A,eq)",
    "log10(KB_eq)": "log₁₀(K_B,eq)",
    "log10(nu)": "log₁₀(ν)",
    "log10(nu_red)": "log₁₀(ν_red)",
    "log10(D0)": "log₁₀(D₀)",
    "E0_offset": "E₀ offset",
    "alpha": "α",
    "alpha_cov": "α_cov",
    "Ed": "E_d (K)",
    "Ed0": "E_d0 (K)",
    "Ea": "E_a (K)",
    "Ea_red": "E_a,red (K)",
    "Ea_reox": "E_a,reox (K)",
    "E_diff": "E_diff (K)",
    "theta_0": "θ₀",
    "theta_A0": "θ_A0",
    "theta_B0": "θ_B0",
    "theta_O0": "θ_O0",
}


def _format_param_name(name):
    """Return the display label for a parameter key, or the key unchanged if unmapped."""
    return _PARAM_DISPLAY_NAMES.get(name, name)