| """ |
| Visualization utilities for ECFlow web app. |
| |
Generates matplotlib figures for mechanism classification probabilities,
parameter posteriors and summary tables, signal reconstruction overlays,
and surface concentration profiles.
| """ |
|
|
| import numpy as np |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| from matplotlib.gridspec import GridSpec |
|
|
|
|
# Shared UI palette (hex colour strings) used by every figure in this module.
COLORS = dict(
    primary="#2563EB",
    secondary="#7C3AED",
    accent="#059669",
    warm="#DC2626",
    neutral="#6B7280",
    bg="#F9FAFB",
)


# Bar/violin colour per electrochemistry (EC) mechanism key.
MECH_COLORS_EC = dict(
    Nernst="#3B82F6",
    BV="#8B5CF6",
    MHC="#EC4899",
    Ads="#F59E0B",
    EC="#10B981",
    LH="#EF4444",
)


# Bar/violin colour per TPD mechanism key (same six hues as the EC palette).
MECH_COLORS_TPD = dict(
    FirstOrder="#3B82F6",
    SecondOrder="#8B5CF6",
    LH_Surface="#EC4899",
    MvK="#F59E0B",
    FirstOrderCovDep="#10B981",
    DiffLimited="#EF4444",
)
|
|
|
|
# Human-readable long names shown next to each EC mechanism key.
MECH_FULL_NAMES_EC = dict(
    BV="Butler–Volmer",
    MHC="Marcus–Hush–Chidsey",
    Nernst="Nernstian (reversible)",
    Ads="Adsorption-coupled",
    EC="EC mechanism",
    LH="Langmuir–Hinshelwood",
)


# Human-readable long names shown next to each TPD mechanism key.
MECH_FULL_NAMES_TPD = dict(
    FirstOrder="1st-Order",
    SecondOrder="2nd-Order",
    LH_Surface="LH Surface",
    MvK="Mars–van Krevelen",
    FirstOrderCovDep="1st-Order Cov-Dep",
    DiffLimited="Diff-Limited",
)
|
|
|
|
def plot_mechanism_probs(probs_dict, domain="ec"):
    """
    Draw mechanism classifier probabilities as a horizontal bar chart.

    Mechanisms are ordered by ascending probability (``np.argsort``), so the
    most likely one ends up at the top of the chart. Each y-tick label pairs
    the short key with its long name, and bars whose probability exceeds 5%
    are annotated with a percentage just past the bar end.

    Args:
        probs_dict: {mechanism_name: probability}
        domain: 'ec' selects the EC palette and long names; any other value
            selects the TPD ones

    Returns:
        matplotlib Figure
    """
    if domain == "ec":
        palette, long_names = MECH_COLORS_EC, MECH_FULL_NAMES_EC
    else:
        palette, long_names = MECH_COLORS_TPD, MECH_FULL_NAMES_TPD

    keys = list(probs_dict.keys())
    values = [probs_dict[k] for k in keys]

    # Ascending sort: barh draws index 0 at the bottom, so the largest is on top.
    ranking = np.argsort(values)
    ordered_keys = [keys[j] for j in ranking]
    ordered_vals = [values[j] for j in ranking]

    rows = range(len(ordered_keys))
    fig_height = max(3, len(ordered_keys) * 0.7)
    fig, ax = plt.subplots(figsize=(9, fig_height))
    bars = ax.barh(
        rows,
        ordered_vals,
        color=[palette.get(k, COLORS["neutral"]) for k in ordered_keys],
        edgecolor="white",
        linewidth=0.5,
        height=0.7,
    )

    ax.set_yticks(rows)
    ax.set_yticklabels(
        [f"{k} ({long_names.get(k, k)})" for k in ordered_keys],
        fontsize=11,
        fontweight="medium",
    )
    ax.set_xlim(0, 1.05)
    ax.set_xlabel("Probability", fontsize=12)
    ax.set_title("Mechanism Classification", fontsize=14, fontweight="bold", pad=15)

    # Only label bars large enough for the text to be meaningful.
    for bar, value in zip(bars, ordered_vals):
        if value > 0.05:
            y_mid = bar.get_y() + bar.get_height() / 2
            ax.text(bar.get_width() + 0.02, y_mid, f"{value:.1%}",
                    va="center", fontsize=11, fontweight="bold")

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="x", alpha=0.3, linestyle="--")
    fig.tight_layout()
    return fig
|
|
|
|
def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
    """
    Violin plots of posterior distributions for each parameter.

    Each panel shows:
    - a violin of the samples, with the mean line in black and the median
      line in the warm colour;
    - dashed lines at the 5th and 95th percentiles;
    - the sample mean as text near the bottom.

    Non-finite samples (NaN/inf) are dropped per parameter before plotting.
    A parameter whose remaining samples all share one value is drawn as a
    single marker, because the kernel-density estimate behind
    ``violinplot`` cannot be built from zero-spread data. A parameter with
    no finite samples gets a short notice instead of a plot.

    Args:
        samples: [n_samples, D] array-like of posterior samples; column i
            belongs to param_names[i]. A 1-D array is treated as one column.
        param_names: list of parameter names
        mechanism_name: name of the mechanism (also selects the colour)
        domain: 'ec' or 'tpd'

    Returns:
        matplotlib Figure

    Raises:
        ValueError: if param_names is empty.
    """
    n_params = len(param_names)
    if n_params == 0:
        raise ValueError("param_names must contain at least one parameter")

    samples = np.asarray(samples, dtype=float)
    if samples.ndim == 1:
        samples = samples[:, np.newaxis]

    # squeeze=False gives a 2-D axes array regardless of n_params, matching
    # the other multi-panel helpers in this module.
    fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5),
                             squeeze=False)

    colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
    color = colors.get(mechanism_name, COLORS["primary"])

    for i, (ax, name) in enumerate(zip(axes[0], param_names)):
        column = samples[:, i]
        data = column[np.isfinite(column)]

        if data.size == 0:
            ax.text(0.5, 0.5, "no finite samples", transform=ax.transAxes,
                    ha="center", va="center", fontsize=9,
                    color=COLORS["neutral"])
        else:
            if np.ptp(data) > 0:
                parts = ax.violinplot(data, positions=[0], showmeans=True,
                                      showmedians=True, showextrema=False)
                for body in parts["bodies"]:
                    body.set_facecolor(color)
                    body.set_alpha(0.6)
                parts["cmeans"].set_color("black")
                parts["cmedians"].set_color(COLORS["warm"])
            else:
                # Zero spread: no density to draw, show the single value.
                ax.plot([0], [data[0]], marker="o", markersize=8, color=color)
                ax.set_xlim(-0.5, 0.5)

            # 90% credible band as dashed reference lines.
            q05, q95 = np.quantile(data, [0.05, 0.95])
            ax.axhline(q05, color=COLORS["neutral"], linestyle="--",
                       alpha=0.5, linewidth=0.8)
            ax.axhline(q95, color=COLORS["neutral"], linestyle="--",
                       alpha=0.5, linewidth=0.8)

            ax.text(0.5, 0.02, f"mean={data.mean():.3f}", transform=ax.transAxes,
                    ha="center", fontsize=9, color=COLORS["neutral"])

        ax.set_title(_format_param_name(name), fontsize=11, fontweight="medium")
        ax.set_xticks([])
        for side in ("top", "right", "bottom"):
            ax.spines[side].set_visible(False)

    fig.suptitle(f"Parameter Posteriors — {mechanism_name}",
                 fontsize=14, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
|
|
|
|
def _finite_metric(value):
    """Return True when *value* is a finite number; None counts as missing."""
    return value is not None and bool(np.isfinite(value))


def plot_reconstruction(observed_curves, recon_curves, domain="ec",
                        nrmses=None, r2s=None, scan_labels=None):
    """
    Overlay of observed vs reconstructed signals with optional metrics.

    Only the first four curves get their own panel. The averages appended to
    the figure title use every finite value in ``nrmses`` / ``r2s``, and are
    shown only when both lists contain at least one finite value.

    Args:
        observed_curves: list of dicts with 'x' and 'y' arrays
        recon_curves: list of dicts with 'x' and 'y' arrays (same length);
            a None entry draws only the observed curve for that panel
        domain: 'ec' or 'tpd'
        nrmses: optional list or array of NRMSE values per curve; None and
            non-finite entries are skipped
        r2s: optional list or array of R2 values per curve, same conventions
        scan_labels: optional list of label strings per curve

    Returns:
        matplotlib Figure

    Raises:
        ValueError: if observed_curves is empty.
    """
    n_curves = len(observed_curves)
    if n_curves == 0:
        raise ValueError("observed_curves must contain at least one curve")

    n_panels = min(n_curves, 4)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)

    xlabel = "Potential (\u03b8)" if domain == "ec" else "Temperature (K)"
    ylabel = "Flux" if domain == "ec" else "Rate"

    # n_panels <= n_curves, so every panel has a corresponding curve.
    for i, ax in enumerate(axes[0]):
        obs = observed_curves[i]
        rec = recon_curves[i]

        ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
                label="Observed", alpha=0.8)
        if rec is not None:
            ax.plot(rec["x"], rec["y"], color=COLORS["primary"], linewidth=1.5,
                    label="Reconstructed", linestyle="--")

        ax.set_xlabel(xlabel, fontsize=10)
        if i == 0:
            ax.set_ylabel(ylabel, fontsize=10)
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Explicit None/length checks instead of truthiness, so NumPy arrays
        # (whose truth value is ambiguous) are accepted as well as lists.
        if scan_labels is not None and i < len(scan_labels):
            title = scan_labels[i]
        elif domain == "ec":
            title = f"Scan rate {i + 1}"
        else:
            title = f"Heating rate {i + 1}"
        ax.set_title(title, fontsize=10)

        metrics_parts = []
        if nrmses is not None and i < len(nrmses) and _finite_metric(nrmses[i]):
            metrics_parts.append(f"NRMSE={nrmses[i]:.4f}")
        if r2s is not None and i < len(r2s) and _finite_metric(r2s[i]):
            metrics_parts.append(f"R\u00b2={r2s[i]:.4f}")
        if metrics_parts:
            ax.text(0.02, 0.98, "  ".join(metrics_parts),
                    transform=ax.transAxes, fontsize=8, va="top",
                    color=COLORS["accent"], fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
                              alpha=0.8, edgecolor=COLORS["accent"]))

    suptitle = "Signal Reconstruction"
    valid_nrmse = [v for v in nrmses if _finite_metric(v)] if nrmses is not None else []
    valid_r2 = [v for v in r2s if _finite_metric(v)] if r2s is not None else []
    if valid_nrmse and valid_r2:
        avg_nrmse = np.mean(valid_nrmse)
        avg_r2 = np.mean(valid_r2)
        suptitle += f"  (avg NRMSE={avg_nrmse:.4f}, avg R\u00b2={avg_r2:.4f})"

    fig.suptitle(suptitle, fontsize=12, fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
|
|
|
|
def _add_sweep_arrows(ax, pot, y_ox, y_red, mid, show_labels=False):
    """Annotate each sweep half of both traces with a direction arrow.

    The forward half (``[:mid]``) gets an arrow about 35% of the way along it,
    and the reverse half (``[mid:]``) one about 65% along. Both ``y_ox``
    (primary colour) and ``y_red`` (warm colour) are annotated. Halves
    shorter than 10 points are left unannotated.

    Args:
        ax: matplotlib Axes to draw on.
        pot: x values, sliceable in the same way as the traces.
        y_ox, y_red: traces aligned index-for-index with ``pot``.
        mid: index separating the forward and reverse halves.
        show_labels: accepted for interface compatibility; not used.
    """
    halves = ((slice(None, mid), 0.35), (slice(mid, None), 0.65))
    traces = ((y_ox, COLORS["primary"]), (y_red, COLORS["warm"]))

    for trace, arrow_color in traces:
        for window, position in halves:
            xs = pot[window]
            ys = trace[window]
            count = len(xs)
            if count >= 10:
                # Keep the anchor a couple of points from either end.
                anchor = max(2, min(int(count * position), count - 3))
                # Arrow length scales with the segment length (~1/30 of it).
                reach = max(1, count // 30)
                tail = max(0, anchor - reach)
                head = min(count - 1, anchor + reach)
                ax.annotate(
                    "",
                    xy=(xs[head], ys[head]),
                    xytext=(xs[tail], ys[tail]),
                    arrowprops=dict(arrowstyle="-|>", color=arrow_color,
                                    lw=1.8, mutation_scale=14),
                )
|
|
|
|
def plot_concentration_profiles(conc_curves, scan_labels=None):
    """
    Plot surface concentration profiles (C_A and C_B) vs potential.

    Only the first four curves get a panel. Each panel splits the curve at
    ``len(x) // 2`` and draws the two halves as separate lines in the same
    colours, then adds direction arrows via ``_add_sweep_arrows``. A None
    entry leaves its panel hidden. The y-axis label goes on the first
    visible panel.

    Args:
        conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
            or None for failed reconstructions
        scan_labels: optional list of label strings per curve

    Returns:
        matplotlib Figure, or None if none of the displayed curves (the
        first four) has data
    """
    # Decide on validity using only the curves that will actually be shown,
    # so we never return a figure whose panels are all hidden.
    shown = list(conc_curves[:4])
    if all(c is None for c in shown):
        return None

    n_panels = len(shown)
    fig, axes = plt.subplots(1, n_panels,
                             figsize=(max(5, 4 * n_panels), 5),
                             squeeze=False)

    y_labelled = False
    for i, (ax, c) in enumerate(zip(axes[0], shown)):
        if c is None:
            ax.set_visible(False)
            continue

        pot = np.asarray(c["x"])
        c_ox = np.asarray(c["c_ox"])
        c_red = np.asarray(c["c_red"])
        mid = len(pot) // 2

        # First half of the points (labelled for the legend).
        ax.plot(pot[:mid], c_ox[:mid], color=COLORS["primary"], linewidth=1.5,
                label="C$_A$ (ox)")
        ax.plot(pot[:mid], c_red[:mid], color=COLORS["warm"], linewidth=1.5,
                label="C$_B$ (red)")
        # Second half of the points, unlabelled to avoid duplicate legend entries.
        ax.plot(pot[mid:], c_ox[mid:], color=COLORS["primary"], linewidth=1.5)
        ax.plot(pot[mid:], c_red[mid:], color=COLORS["warm"], linewidth=1.5)

        _add_sweep_arrows(ax, pot, c_ox, c_red, mid)

        ax.set_xlabel("Potential (\u03b8)", fontsize=10)
        if not y_labelled:
            ax.set_ylabel("Surface concentration", fontsize=10)
            y_labelled = True
        ax.legend(fontsize=8, framealpha=0.8, loc="best")
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        # Explicit None/length check so array-like label containers work too.
        if scan_labels is not None and i < len(scan_labels):
            ax.set_title(scan_labels[i], fontsize=10)
        else:
            ax.set_title(f"Scan rate {i + 1}", fontsize=10)

    fig.suptitle("Surface Concentration Profiles", fontsize=12,
                 fontweight="bold")
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    return fig
|
|
|
|
def plot_parameter_table(param_stats, mechanism_name):
    """
    Render a parameter summary table as a standalone figure.

    One row per parameter: the display name (via ``_format_param_name``)
    followed by mean, standard deviation and 5th/95th percentiles, each
    formatted to four decimals. The header row is grey and bold; data rows
    alternate white and a light tint.

    Args:
        param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'
            (parallel sequences indexed by parameter position)
        mechanism_name: name of the mechanism, used in the title

    Returns:
        matplotlib Figure
    """
    labels = param_stats["names"]
    stat_columns = (
        param_stats["mean"],
        param_stats["std"],
        param_stats["q05"],
        param_stats["q95"],
    )

    n_rows = len(labels)
    fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * n_rows + 1)))
    ax.axis("off")

    header = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
    rows = [
        [_format_param_name(labels[r])] + [f"{col[r]:.4f}" for col in stat_columns]
        for r in range(n_rows)
    ]

    table = ax.table(cellText=rows, colLabels=header,
                     loc="center", cellLoc="center")
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1.0, 1.5)

    # Row 0 is the header; data rows start at 1.
    for (r, _col), cell in table.get_celld().items():
        if r == 0:
            cell.set_facecolor("#E5E7EB")
            cell.set_text_props(fontweight="bold")
        elif r % 2 == 0:
            cell.set_facecolor("#F9FAFB")
        else:
            cell.set_facecolor("white")

    ax.set_title(f"Parameter Estimates — {mechanism_name}",
                 fontsize=14, fontweight="bold", pad=20)
    fig.tight_layout()
    return fig
|
|
|
|
| def _format_param_name(name): |
| """Format parameter names for display.""" |
| replacements = { |
| "log10(K0)": "log₁₀(K₀)", |
| "log10(dB)": "log₁₀(d_B)", |
| "log10(dA)": "log₁₀(d_A)", |
| "log10(kc)": "log₁₀(k_c)", |
| "log10(reorg_e)": "log₁₀(λ)", |
| "log10(Gamma_sat)": "log₁₀(Γ_sat)", |
| "log10(KA_eq)": "log₁₀(K_A,eq)", |
| "log10(KB_eq)": "log₁₀(K_B,eq)", |
| "log10(nu)": "log₁₀(ν)", |
| "log10(nu_red)": "log₁₀(ν_red)", |
| "log10(D0)": "log₁₀(D₀)", |
| "E0_offset": "E₀ offset", |
| "alpha": "α", |
| "alpha_cov": "α_cov", |
| "Ed": "E_d (K)", |
| "Ed0": "E_d0 (K)", |
| "Ea": "E_a (K)", |
| "Ea_red": "E_a,red (K)", |
| "Ea_reox": "E_a,reox (K)", |
| "E_diff": "E_diff (K)", |
| "theta_0": "θ₀", |
| "theta_A0": "θ_A0", |
| "theta_B0": "θ_B0", |
| "theta_O0": "θ_O0", |
| } |
| return replacements.get(name, name) |
|
|