Bing Yan commited on
Commit
93d4bdf
·
1 Parent(s): d47e6c6

Redesign UI: cleaner layout, better plots, custom CSS

Browse files

- Centered header with description instead of raw markdown
- Summary card with styled parameter table
- KDE posterior plots with rug marks replacing violin plots
- Higher DPI (140), tighter figure sizing
- Custom CSS for spacing, card styling, section headings
- Cleaner About tab with comparison table

Made-with: Cursor

Files changed (2) hide show
  1. app.py +70 -102
  2. plotting.py +181 -229
app.py CHANGED
@@ -502,38 +502,25 @@ def _build_summary_text(result, recon=None, domain="ec"):
502
  prob = result["mechanism_probs"][mech]
503
 
504
  lines = [
505
- f"## Predicted Mechanism: **{mech}** ({prob:.1%} confidence)\n",
 
 
 
506
  ]
507
 
508
  stats = result["parameter_stats"].get(mech)
509
  if stats:
510
- lines.append("### Parameter Estimates (90% Credible Interval)\n")
511
- lines.append("| Parameter | Mean | 90% CI |")
512
- lines.append("|-----------|------|--------|")
513
  for i, name in enumerate(stats["names"]):
514
  mean = stats["mean"][i]
515
  q05 = stats["q05"][i]
516
  q95 = stats["q95"][i]
517
- lines.append(f"| {name} | {mean:.4f} | [{q05:.4f}, {q95:.4f}] |")
518
 
519
- if recon is not None:
520
- lines.append("\n### Signal Reconstruction Quality\n")
521
- lines.append(f"- **Average NRMSE**: {recon['mean_nrmse']:.4f}")
522
- lines.append(f"- **Average R\u00b2**: {recon['mean_r2']:.4f}")
523
- if len(recon["nrmse"]) > 1:
524
- lines.append("\n| Curve | NRMSE | R\u00b2 |")
525
- lines.append("|-------|-------|-----|")
526
- for i, (n, r) in enumerate(zip(recon["nrmse"], recon["r2"])):
527
- lines.append(f"| {i + 1} | {n:.4f} | {r:.4f} |")
528
-
529
- lines.append("\n### All Mechanism Probabilities\n")
530
- lines.append("| Mechanism | Probability |")
531
- lines.append("|-----------|-------------|")
532
- sorted_mechs = sorted(result["mechanism_probs"].items(), key=lambda x: -x[1])
533
- for m, p in sorted_mechs:
534
- marker = " \u2190 predicted" if m == mech else ""
535
- lines.append(f"| {m} | {p:.4f} |{marker}")
536
- lines.append("\n*Use the dropdown below to view posteriors and reconstruction for any mechanism.*")
537
 
538
  result_json = {
539
  "predicted_mechanism": mech,
@@ -568,73 +555,74 @@ def download_results(result_text):
568
  # =========================================================================
569
 
570
  def _build_ec_output_section(prefix):
571
- """Build shared output components for one EC input tab.
572
-
573
- Returns (probs_plot, summary_md, state, mech_dropdown,
574
- posteriors_plot, recon_plot, conc_plot).
575
- """
576
- with gr.Row():
577
- probs = gr.Plot(label="Mechanism Classification")
578
- summary = gr.Markdown(label="Summary")
579
  state = gr.State(value=None)
580
- gr.Markdown("---")
581
- gr.Markdown(
582
- "### Mechanism Details\n"
583
- "Select a mechanism below to view its parameter posteriors, "
584
- "signal reconstruction, and surface concentration profiles."
585
- )
586
  mech_dd = gr.Dropdown(
587
- label="View mechanism (select to explore alternatives)",
588
  choices=[],
589
  interactive=True,
590
  )
591
- posteriors = gr.Plot(label="Parameter Posteriors")
592
- with gr.Row():
593
- recon = gr.Plot(label="Signal Reconstruction")
594
- conc = gr.Plot(label="Surface Concentration Profiles")
595
  return probs, summary, state, mech_dd, posteriors, recon, conc
596
 
597
 
598
  def _build_tpd_output_section(prefix):
599
- """Build shared output components for one TPD input tab.
600
-
601
- Returns (probs_plot, summary_md, state, mech_dropdown,
602
- posteriors_plot, recon_plot).
603
- """
604
- with gr.Row():
605
- probs = gr.Plot(label="Mechanism Classification")
606
- summary = gr.Markdown(label="Summary")
607
  state = gr.State(value=None)
608
- gr.Markdown("---")
609
- gr.Markdown(
610
- "### Mechanism Details\n"
611
- "Select a mechanism below to view its parameter posteriors "
612
- "and signal reconstruction."
613
- )
614
  mech_dd = gr.Dropdown(
615
- label="View mechanism (select to explore alternatives)",
616
  choices=[],
617
  interactive=True,
618
  )
619
- posteriors = gr.Plot(label="Parameter Posteriors")
620
- recon = gr.Plot(label="Signal Reconstruction")
621
  return probs, summary, state, mech_dd, posteriors, recon
622
 
623
 
 
 
 
 
 
 
 
 
 
 
 
 
624
  def build_app():
625
  with gr.Blocks(
626
  title="ECFlow — Bayesian Inference for Electrochemistry & Catalysis",
627
  theme=gr.themes.Soft(
628
  primary_hue="blue",
629
- secondary_hue="purple",
 
630
  ),
 
631
  ) as app:
632
- gr.Markdown(
633
- "# ECFlow\n"
634
- "### Amortized Bayesian Inference for Electrochemistry & Catalysis\n"
635
- "Upload cyclic voltammetry (CV) or temperature-programmed desorption (TPD) data "
636
- "to identify the reaction mechanism and infer kinetic parameters with full uncertainty quantification.",
637
- elem_classes=["main-title"],
638
  )
639
 
640
  with gr.Tabs():
@@ -895,50 +883,30 @@ def build_app():
895
  # =================================================================
896
  with gr.Tab("About"):
897
  gr.Markdown("""
898
- ## About ECFlow
899
-
900
- ECFlow performs **amortized Bayesian inference** for electrochemical and catalytic systems.
901
- Given experimental data, it simultaneously:
902
-
903
- 1. **Classifies the reaction mechanism** from a library of 6 mechanisms per domain
904
- 2. **Infers kinetic parameters** with full posterior uncertainty quantification
905
-
906
- ### Electrochemistry (CV) Mechanisms
907
 
908
- | Mechanism | Parameters | Description |
909
- |-----------|-----------|-------------|
910
- | Nernst | E₀, d_A, d_B | Reversible (Nernstian) electron transfer |
911
- | BV | K₀, α, d_B | Butler-Volmer kinetics |
912
- | MHC | K₀, λ, d_B | Marcus-Hush-Chidsey kinetics |
913
- | Ads | K₀, α, Γ_sat | Surface-confined (Laviron) kinetics |
914
- | EC | K₀, α, k_c, d_B | Electron transfer + chemical follow-up |
915
- | LH | K₀, α, K_A, K_B, d_B | Langmuir-Hinshelwood surface reaction |
916
 
917
- ### Catalysis (TPD) Mechanisms
 
 
 
 
918
 
919
- | Mechanism | Parameters | Description |
920
- |-----------|-----------|-------------|
921
- | FirstOrder | E_d, ν, θ₀ | First-order desorption |
922
- | SecondOrder | E_d, ν, θ₀ | Second-order (recombinative) desorption |
923
- | LH_Surface | E_a, ν, θ_A0, θ_B0 | Langmuir-Hinshelwood surface reaction |
924
- | MvK | E_a,red, E_a,reox, ν_red, θ_O0 | Mars-van Krevelen mechanism |
925
- | FirstOrderCovDep | E_d0, α_cov, ν, θ₀ | Coverage-dependent activation energy |
926
- | DiffLimited | E_d, ν, D₀, E_diff, θ₀ | Diffusion-limited desorption |
927
-
928
- ### How It Works
929
-
930
- The model uses **conditional normalizing flows** with a Set Transformer encoder
931
- to process multi-scan-rate/multi-heating-rate data. Training uses simulated data
932
- with coverage-aware calibration loss for well-calibrated uncertainty estimates.
933
-
934
- Inference takes **~5 ms per sample** on CPU, making it suitable for real-time analysis.
935
 
936
  ### Citation
937
 
938
- If you use ECFlow in your research, please cite:
939
  ```
940
- [Citation to be added upon publication]
 
 
941
  ```
 
 
942
  """)
943
 
944
  return app
 
502
  prob = result["mechanism_probs"][mech]
503
 
504
  lines = [
505
+ f"<div style='text-align:center; padding: 12px 0 4px 0;'>",
506
+ f"<span style='font-size:1.5em; font-weight:700;'>{mech}</span>",
507
+ f"<br><span style='font-size:1.1em; color:#6B7280;'>{prob:.1%} confidence</span>",
508
+ f"</div>\n",
509
  ]
510
 
511
  stats = result["parameter_stats"].get(mech)
512
  if stats:
513
+ lines.append("#### Parameter Estimates\n")
514
+ lines.append("| Parameter | Mean | 90 % CI |")
515
+ lines.append("|:----------|-----:|:--------|")
516
  for i, name in enumerate(stats["names"]):
517
  mean = stats["mean"][i]
518
  q05 = stats["q05"][i]
519
  q95 = stats["q95"][i]
520
+ lines.append(f"| **{name}** | {mean:.4f} | [{q05:.4f}, {q95:.4f}] |")
521
 
522
+ if recon is not None and np.isfinite(recon["mean_nrmse"]):
523
+ lines.append(f"\n#### Reconstruction &nbsp; NRMSE {recon['mean_nrmse']:.4f} &nbsp;|&nbsp; R\u00b2 {recon['mean_r2']:.4f}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
524
 
525
  result_json = {
526
  "predicted_mechanism": mech,
 
555
  # =========================================================================
556
 
557
  def _build_ec_output_section(prefix):
558
+ """Build shared output components for one EC input tab."""
559
+ gr.Markdown("### Results", elem_classes=["section-heading"])
560
+ with gr.Row(equal_height=False):
561
+ with gr.Column(scale=2, min_width=320):
562
+ probs = gr.Plot(label="Mechanism Probabilities", show_label=False)
563
+ with gr.Column(scale=1, min_width=260):
564
+ summary = gr.Markdown(label="Summary", elem_classes=["summary-card"])
 
565
  state = gr.State(value=None)
566
+ gr.Markdown("### Explore Mechanisms", elem_classes=["section-heading"])
 
 
 
 
 
567
  mech_dd = gr.Dropdown(
568
+ label="Select mechanism",
569
  choices=[],
570
  interactive=True,
571
  )
572
+ posteriors = gr.Plot(label="Parameter Posteriors", show_label=False)
573
+ recon = gr.Plot(label="Signal Reconstruction", show_label=False)
574
+ conc = gr.Plot(label="Surface Concentrations", show_label=False)
 
575
  return probs, summary, state, mech_dd, posteriors, recon, conc
576
 
577
 
578
  def _build_tpd_output_section(prefix):
579
+ """Build shared output components for one TPD input tab."""
580
+ gr.Markdown("### Results", elem_classes=["section-heading"])
581
+ with gr.Row(equal_height=False):
582
+ with gr.Column(scale=2, min_width=320):
583
+ probs = gr.Plot(label="Mechanism Probabilities", show_label=False)
584
+ with gr.Column(scale=1, min_width=260):
585
+ summary = gr.Markdown(label="Summary", elem_classes=["summary-card"])
 
586
  state = gr.State(value=None)
587
+ gr.Markdown("### Explore Mechanisms", elem_classes=["section-heading"])
 
 
 
 
 
588
  mech_dd = gr.Dropdown(
589
+ label="Select mechanism",
590
  choices=[],
591
  interactive=True,
592
  )
593
+ posteriors = gr.Plot(label="Parameter Posteriors", show_label=False)
594
+ recon = gr.Plot(label="Signal Reconstruction", show_label=False)
595
  return probs, summary, state, mech_dd, posteriors, recon
596
 
597
 
598
+ CUSTOM_CSS = """
599
+ .main-header { text-align: center; padding: 24px 16px 8px 16px; }
600
+ .main-header h1 { font-size: 2.2em; margin-bottom: 2px; letter-spacing: -0.5px; }
601
+ .main-header p { color: #6B7280; font-size: 1.05em; max-width: 720px; margin: 0 auto; line-height: 1.5; }
602
+ .section-heading { margin-top: 20px !important; margin-bottom: 4px !important; }
603
+ .summary-card { border: 1px solid #E5E7EB; border-radius: 10px; padding: 16px 20px; background: #FAFBFC; }
604
+ .summary-card table { width: 100%; font-size: 0.92em; }
605
+ .summary-card td, .summary-card th { padding: 4px 8px; }
606
+ footer { display: none !important; }
607
+ """
608
+
609
+
610
  def build_app():
611
  with gr.Blocks(
612
  title="ECFlow — Bayesian Inference for Electrochemistry & Catalysis",
613
  theme=gr.themes.Soft(
614
  primary_hue="blue",
615
+ secondary_hue="slate",
616
+ font=gr.themes.GoogleFont("Inter"),
617
  ),
618
+ css=CUSTOM_CSS,
619
  ) as app:
620
+ gr.HTML(
621
+ "<div class='main-header'>"
622
+ "<h1>⚡ ECFlow</h1>"
623
+ "<p>Upload cyclic voltammetry or TPD data to <strong>identify the reaction mechanism</strong> "
624
+ "and <strong>infer kinetic parameters</strong> with full Bayesian uncertainty — in milliseconds.</p>"
625
+ "</div>"
626
  )
627
 
628
  with gr.Tabs():
 
883
  # =================================================================
884
  with gr.Tab("About"):
885
  gr.Markdown("""
886
+ ## How It Works
 
 
 
 
 
 
 
 
887
 
888
+ ECFlow uses **conditional normalizing flows** with a **Set Transformer** encoder to perform amortized Bayesian inference.
889
+ Given one or more experimental curves, it simultaneously classifies the reaction mechanism and produces
890
+ full posterior distributions over kinetic parameters in a single forward pass.
 
 
 
 
 
891
 
892
+ | | Electrochemistry (CV) | Catalysis (TPD) |
893
+ |---|---|---|
894
+ | **Mechanisms** | Nernst, Butler–Volmer, Marcus–Hush–Chidsey, Adsorption, EC, Langmuir–Hinshelwood | First-order, Second-order, LH Surface, Mars–van Krevelen, Coverage-dependent, Diffusion-limited |
895
+ | **Inference** | ~50 ms on CPU | ~50 ms on CPU |
896
+ | **Calibration** | 89–94 % coverage at 90 % nominal | Conformal coverage verified |
897
 
898
+ Training data is generated from physics-based simulators (Crank–Nicolson for CV, ODE integrators for TPD).
899
+ Posteriors are calibrated via a coverage-aware loss with per-parameter inverse-spread weighting.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
 
901
  ### Citation
902
 
 
903
  ```
904
+ Yan, B. (2026). ECFlow: Amortized Bayesian Inference for Mechanism Identification
905
+ and Parameter Estimation in Electrochemistry and Catalysis via Conditional
906
+ Normalizing Flows. [Preprint]
907
  ```
908
+
909
+ Built at MIT. Code and paper at [github.com/bingyan/ECFlow](https://github.com/bingyan/ECFlow).
910
  """)
911
 
912
  return app
plotting.py CHANGED
@@ -11,148 +11,146 @@ matplotlib.use("Agg")
11
  import matplotlib.pyplot as plt
12
  from matplotlib.gridspec import GridSpec
13
 
14
-
15
- COLORS = {
16
- "primary": "#2563EB",
17
- "secondary": "#7C3AED",
18
- "accent": "#059669",
19
- "warm": "#DC2626",
20
- "neutral": "#6B7280",
21
- "bg": "#F9FAFB",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  }
23
 
24
  MECH_COLORS_EC = {
25
- "Nernst": "#3B82F6",
26
- "BV": "#8B5CF6",
27
- "MHC": "#EC4899",
28
- "Ads": "#F59E0B",
29
- "EC": "#10B981",
30
- "LH": "#EF4444",
31
  }
32
 
33
  MECH_COLORS_TPD = {
34
- "FirstOrder": "#3B82F6",
35
- "SecondOrder": "#8B5CF6",
36
- "LH_Surface": "#EC4899",
37
- "MvK": "#F59E0B",
38
- "FirstOrderCovDep": "#10B981",
39
- "DiffLimited": "#EF4444",
40
  }
41
 
42
 
43
  def plot_mechanism_probs(probs_dict, domain="ec"):
44
- """
45
- Horizontal bar chart of mechanism classification probabilities.
46
-
47
- Args:
48
- probs_dict: {mechanism_name: probability}
49
- domain: 'ec' or 'tpd'
50
-
51
- Returns:
52
- matplotlib Figure
53
- """
54
  colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
55
  names = list(probs_dict.keys())
56
  probs = [probs_dict[n] for n in names]
57
 
58
- sorted_idx = np.argsort(probs)
59
- names = [names[i] for i in sorted_idx]
60
- probs = [probs[i] for i in sorted_idx]
61
- bar_colors = [colors.get(n, COLORS["neutral"]) for n in names]
62
 
63
- fig, ax = plt.subplots(figsize=(8, max(3, len(names) * 0.6)))
64
- bars = ax.barh(range(len(names)), probs, color=bar_colors, edgecolor="white",
65
- linewidth=0.5, height=0.7)
 
66
 
67
  ax.set_yticks(range(len(names)))
68
- ax.set_yticklabels(names, fontsize=12, fontweight="medium")
69
- ax.set_xlim(0, 1.05)
70
- ax.set_xlabel("Probability", fontsize=12)
71
- ax.set_title("Mechanism Classification", fontsize=14, fontweight="bold", pad=15)
72
-
73
- for i, (bar, prob) in enumerate(zip(bars, probs)):
74
- if prob > 0.05:
75
- ax.text(bar.get_width() + 0.02, bar.get_y() + bar.get_height() / 2,
76
- f"{prob:.1%}", va="center", fontsize=11, fontweight="bold")
77
-
78
- ax.spines["top"].set_visible(False)
79
- ax.spines["right"].set_visible(False)
80
- ax.grid(axis="x", alpha=0.3, linestyle="--")
81
- fig.tight_layout()
82
  return fig
83
 
84
 
85
  def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
86
- """
87
- Violin plots of posterior distributions for each parameter.
88
-
89
- Args:
90
- samples: [n_samples, D] array of posterior samples
91
- param_names: list of parameter names
92
- mechanism_name: name of the mechanism
93
- domain: 'ec' or 'tpd'
94
-
95
- Returns:
96
- matplotlib Figure
97
- """
98
  n_params = len(param_names)
99
- fig, axes = plt.subplots(1, n_params, figsize=(max(4, 3 * n_params), 4.5))
 
100
  if n_params == 1:
101
  axes = [axes]
102
 
103
  colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
104
- color = colors.get(mechanism_name, COLORS["primary"])
105
 
106
  for i, (ax, name) in enumerate(zip(axes, param_names)):
107
  data = samples[:, i]
108
-
109
- parts = ax.violinplot(data, positions=[0], showmeans=True,
110
- showmedians=True, showextrema=False)
111
- for pc in parts["bodies"]:
112
- pc.set_facecolor(color)
113
- pc.set_alpha(0.6)
114
- parts["cmeans"].set_color("black")
115
- parts["cmedians"].set_color(COLORS["warm"])
116
-
117
- q05, q95 = np.quantile(data, [0.05, 0.95])
118
- ax.axhline(q05, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
119
- ax.axhline(q95, color=COLORS["neutral"], linestyle="--", alpha=0.5, linewidth=0.8)
120
-
121
- ax.set_title(_format_param_name(name), fontsize=11, fontweight="medium")
122
- ax.set_xticks([])
123
- ax.spines["top"].set_visible(False)
124
- ax.spines["right"].set_visible(False)
125
- ax.spines["bottom"].set_visible(False)
126
-
127
- mean_val = data.mean()
128
- ax.text(0.5, 0.02, f"mean={mean_val:.3f}", transform=ax.transAxes,
129
- ha="center", fontsize=9, color=COLORS["neutral"])
130
-
131
- fig.suptitle(f"Parameter Posteriors — {mechanism_name}",
132
- fontsize=14, fontweight="bold")
133
- fig.tight_layout(rect=[0, 0, 1, 0.93])
 
 
 
 
 
 
 
 
134
  return fig
135
 
136
 
137
  def plot_reconstruction(observed_curves, recon_curves, domain="ec",
138
  nrmses=None, r2s=None, scan_labels=None):
139
- """
140
- Overlay of observed vs reconstructed signals with optional metrics.
141
-
142
- Args:
143
- observed_curves: list of dicts with 'x' and 'y' arrays
144
- recon_curves: list of dicts with 'x' and 'y' arrays (same length)
145
- domain: 'ec' or 'tpd'
146
- nrmses: optional list of NRMSE values per curve
147
- r2s: optional list of R2 values per curve
148
- scan_labels: optional list of label strings per curve
149
-
150
- Returns:
151
- matplotlib Figure
152
- """
153
  n_curves = len(observed_curves)
154
- fig, axes = plt.subplots(1, min(n_curves, 4),
155
- figsize=(max(5, 4 * min(n_curves, 4)), 5),
 
156
  squeeze=False)
157
  axes = axes[0]
158
 
@@ -167,61 +165,46 @@ def plot_reconstruction(observed_curves, recon_curves, domain="ec",
167
  obs = observed_curves[i]
168
  rec = recon_curves[i]
169
 
170
- ax.plot(obs["x"], obs["y"], color=COLORS["neutral"], linewidth=1.5,
171
- label="Observed", alpha=0.8)
172
- ax.plot(rec["x"], rec["y"], color=COLORS["primary"], linewidth=1.5,
173
- label="Reconstructed", linestyle="--")
174
 
175
- ax.set_xlabel(xlabel, fontsize=10)
176
  if i == 0:
177
- ax.set_ylabel(ylabel, fontsize=10)
178
- ax.legend(fontsize=8, framealpha=0.8, loc="best")
179
- ax.spines["top"].set_visible(False)
180
- ax.spines["right"].set_visible(False)
181
-
182
- if scan_labels and i < len(scan_labels):
183
- title = scan_labels[i]
184
- elif domain == "ec":
185
- title = f"Scan rate {i + 1}"
186
- else:
187
- title = f"Heating rate {i + 1}"
188
- ax.set_title(title, fontsize=10)
189
 
190
- metrics_parts = []
 
 
 
 
191
  if nrmses and i < len(nrmses) and np.isfinite(nrmses[i]):
192
- metrics_parts.append(f"NRMSE={nrmses[i]:.4f}")
193
  if r2s and i < len(r2s) and np.isfinite(r2s[i]):
194
- metrics_parts.append(f"R\u00b2={r2s[i]:.4f}")
195
- if metrics_parts:
196
- ax.text(0.02, 0.98, " ".join(metrics_parts),
197
  transform=ax.transAxes, fontsize=8, va="top",
198
- color=COLORS["accent"], fontweight="bold",
199
- bbox=dict(boxstyle="round,pad=0.3", facecolor="white",
200
- alpha=0.8, edgecolor=COLORS["accent"]))
201
-
202
- suptitle = "Signal Reconstruction"
203
- if nrmses and r2s:
204
- valid_nrmse = [v for v in nrmses if np.isfinite(v)]
205
- valid_r2 = [v for v in r2s if np.isfinite(v)]
206
- if valid_nrmse and valid_r2:
207
- avg_nrmse = np.mean(valid_nrmse)
208
- avg_r2 = np.mean(valid_r2)
209
- suptitle += f" (avg NRMSE={avg_nrmse:.4f}, avg R\u00b2={avg_r2:.4f})"
210
-
211
- fig.suptitle(suptitle, fontsize=12, fontweight="bold")
212
- fig.tight_layout(rect=[0, 0, 1, 0.93])
213
  return fig
214
 
215
 
216
  def _add_sweep_arrows(ax, pot, y_ox, y_red, mid):
217
- """Add direction arrows and labels for forward/reverse sweeps on both species."""
218
  sweep_specs = [
219
  (slice(None, mid), "reductive \u2192", 16),
220
  (slice(mid, None), "\u2190 oxidative", -16),
221
  ]
222
  curves = [
223
- (y_ox, COLORS["primary"], 0.35, 0.65),
224
- (y_red, COLORS["warm"], 0.35, 0.65),
225
  ]
226
  for y_data, color, fwd_frac, rev_frac in curves:
227
  for segment, label, y_offset in sweep_specs:
@@ -232,9 +215,7 @@ def _add_sweep_arrows(ax, pot, y_ox, y_red, mid):
232
  continue
233
 
234
  frac = fwd_frac if y_offset > 0 else rev_frac
235
- idx = int(n * frac)
236
- idx = max(2, min(idx, n - 3))
237
-
238
  step = max(1, n // 30)
239
  i0 = max(0, idx - step)
240
  i1 = min(n - 1, idx + step)
@@ -243,34 +224,24 @@ def _add_sweep_arrows(ax, pot, y_ox, y_red, mid):
243
  "", xy=(x_seg[i1], y_seg[i1]),
244
  xytext=(x_seg[i0], y_seg[i0]),
245
  arrowprops=dict(arrowstyle="-|>", color=color,
246
- lw=1.8, mutation_scale=14),
247
  )
248
-
249
  ax.annotate(label, xy=(x_seg[idx], y_seg[idx]),
250
  xytext=(0, y_offset), textcoords="offset points",
251
- fontsize=7.5, color=color, fontstyle="italic",
252
  ha="center", va="center")
253
 
254
 
255
  def plot_concentration_profiles(conc_curves, scan_labels=None):
256
- """
257
- Plot surface concentration profiles (C_A and C_B) vs potential.
258
-
259
- Args:
260
- conc_curves: list of dicts with 'x' (potential), 'c_ox', 'c_red',
261
- or None for failed reconstructions
262
- scan_labels: optional list of label strings per curve
263
-
264
- Returns:
265
- matplotlib Figure, or None if no valid data
266
- """
267
  valid = [c for c in conc_curves if c is not None]
268
  if not valid:
269
  return None
270
 
271
  n_curves = len(conc_curves)
272
- fig, axes = plt.subplots(1, min(n_curves, 4),
273
- figsize=(max(5, 4 * min(n_curves, 4)), 5),
 
274
  squeeze=False)
275
  axes = axes[0]
276
 
@@ -285,46 +256,29 @@ def plot_concentration_profiles(conc_curves, scan_labels=None):
285
  c_red = np.asarray(c["c_red"])
286
  mid = len(pot) // 2
287
 
288
- # Forward sweep (reductive): first half
289
- ax.plot(pot[:mid], c_ox[:mid], color=COLORS["primary"], linewidth=1.5,
290
- label="C$_A$ (ox)")
291
- ax.plot(pot[:mid], c_red[:mid], color=COLORS["warm"], linewidth=1.5,
292
- label="C$_B$ (red)")
293
- # Reverse sweep (oxidative): second half
294
- ax.plot(pot[mid:], c_ox[mid:], color=COLORS["primary"], linewidth=1.5)
295
- ax.plot(pot[mid:], c_red[mid:], color=COLORS["warm"], linewidth=1.5)
296
 
297
  _add_sweep_arrows(ax, pot, c_ox, c_red, mid)
298
 
299
- ax.set_xlabel("Potential (\u03b8)", fontsize=10)
300
  if i == 0:
301
- ax.set_ylabel("Surface concentration", fontsize=10)
302
- ax.legend(fontsize=8, framealpha=0.8, loc="best")
303
- ax.spines["top"].set_visible(False)
304
- ax.spines["right"].set_visible(False)
305
 
306
- if scan_labels and i < len(scan_labels):
307
- ax.set_title(scan_labels[i], fontsize=10)
308
- else:
309
- ax.set_title(f"Scan rate {i + 1}", fontsize=10)
310
 
311
- fig.suptitle("Surface Concentration Profiles", fontsize=12,
312
- fontweight="bold")
313
- fig.tight_layout(rect=[0, 0, 1, 0.93])
314
  return fig
315
 
316
 
317
  def plot_parameter_table(param_stats, mechanism_name):
318
- """
319
- Create a formatted parameter summary table as a figure.
320
-
321
- Args:
322
- param_stats: dict with 'names', 'mean', 'std', 'q05', 'q95'
323
- mechanism_name: name of the mechanism
324
-
325
- Returns:
326
- matplotlib Figure
327
- """
328
  names = param_stats["names"]
329
  means = param_stats["mean"]
330
  stds = param_stats["std"]
@@ -332,35 +286,33 @@ def plot_parameter_table(param_stats, mechanism_name):
332
  q95s = param_stats["q95"]
333
 
334
  n = len(names)
335
- fig, ax = plt.subplots(figsize=(8, max(2, 0.6 * n + 1)))
336
  ax.axis("off")
337
 
338
  col_labels = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
339
- cell_text = []
340
- for i in range(n):
341
- cell_text.append([
342
- _format_param_name(names[i]),
343
- f"{means[i]:.4f}",
344
- f"{stds[i]:.4f}",
345
- f"{q05s[i]:.4f}",
346
- f"{q95s[i]:.4f}",
347
- ])
348
 
349
  table = ax.table(cellText=cell_text, colLabels=col_labels,
350
  loc="center", cellLoc="center")
351
  table.auto_set_font_size(False)
352
- table.set_fontsize(11)
353
- table.scale(1.0, 1.5)
354
 
355
  for (row, col), cell in table.get_celld().items():
 
356
  if row == 0:
357
- cell.set_facecolor("#E5E7EB")
358
  cell.set_text_props(fontweight="bold")
359
  else:
360
- cell.set_facecolor("#F9FAFB" if row % 2 == 0 else "white")
361
 
362
  ax.set_title(f"Parameter Estimates — {mechanism_name}",
363
- fontsize=14, fontweight="bold", pad=20)
364
  fig.tight_layout()
365
  return fig
366
 
@@ -368,29 +320,29 @@ def plot_parameter_table(param_stats, mechanism_name):
368
  def _format_param_name(name):
369
  """Format parameter names for display."""
370
  replacements = {
371
- "log10(K0)": "log₁₀(K)",
372
- "log10(dB)": "log₁₀(d_B)",
373
- "log10(dA)": "log₁₀(d_A)",
374
- "log10(kc)": "log₁₀(k_c)",
375
- "log10(reorg_e)": "log₁₀(λ)",
376
- "log10(Gamma_sat)": "log₁₀(Γ_sat)",
377
- "log10(KA_eq)": "log₁₀(K_A,eq)",
378
- "log10(KB_eq)": "log₁₀(K_B,eq)",
379
- "log10(nu)": "log₁₀(ν)",
380
- "log10(nu_red)": "log₁₀(ν_red)",
381
- "log10(D0)": "log₁₀(D)",
382
- "E0_offset": "E offset",
383
- "alpha": "α",
384
- "alpha_cov": "α_cov",
385
  "Ed": "E_d (K)",
386
  "Ed0": "E_d0 (K)",
387
  "Ea": "E_a (K)",
388
  "Ea_red": "E_a,red (K)",
389
  "Ea_reox": "E_a,reox (K)",
390
  "E_diff": "E_diff (K)",
391
- "theta_0": "θ₀",
392
- "theta_A0": "θ_A0",
393
- "theta_B0": "θ_B0",
394
- "theta_O0": "θ_O0",
395
  }
396
  return replacements.get(name, name)
 
11
  import matplotlib.pyplot as plt
12
  from matplotlib.gridspec import GridSpec
13
 
14
+ plt.rcParams.update({
15
+ "figure.dpi": 140,
16
+ "font.family": "sans-serif",
17
+ "font.size": 10,
18
+ "axes.titlesize": 12,
19
+ "axes.labelsize": 10,
20
+ "xtick.labelsize": 9,
21
+ "ytick.labelsize": 9,
22
+ "legend.fontsize": 9,
23
+ "figure.facecolor": "white",
24
+ "axes.facecolor": "white",
25
+ "savefig.facecolor": "white",
26
+ "axes.spines.top": False,
27
+ "axes.spines.right": False,
28
+ })
29
+
30
+ PAL = {
31
+ "blue": "#2563EB",
32
+ "purple": "#7C3AED",
33
+ "pink": "#EC4899",
34
+ "amber": "#F59E0B",
35
+ "green": "#10B981",
36
+ "red": "#EF4444",
37
+ "gray": "#9CA3AF",
38
+ "dark": "#1F2937",
39
+ "light": "#F3F4F6",
40
  }
41
 
42
  MECH_COLORS_EC = {
43
+ "Nernst": PAL["blue"],
44
+ "BV": PAL["purple"],
45
+ "MHC": PAL["pink"],
46
+ "Ads": PAL["amber"],
47
+ "EC": PAL["green"],
48
+ "LH": PAL["red"],
49
  }
50
 
51
  MECH_COLORS_TPD = {
52
+ "FirstOrder": PAL["blue"],
53
+ "SecondOrder": PAL["purple"],
54
+ "LH_Surface": PAL["pink"],
55
+ "MvK": PAL["amber"],
56
+ "FirstOrderCovDep": PAL["green"],
57
+ "DiffLimited": PAL["red"],
58
  }
59
 
60
 
61
  def plot_mechanism_probs(probs_dict, domain="ec"):
62
+ """Horizontal bar chart of mechanism classification probabilities."""
 
 
 
 
 
 
 
 
 
63
  colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
64
  names = list(probs_dict.keys())
65
  probs = [probs_dict[n] for n in names]
66
 
67
+ idx = np.argsort(probs)
68
+ names = [names[i] for i in idx]
69
+ probs = [probs[i] for i in idx]
70
+ bar_c = [colors.get(n, PAL["gray"]) for n in names]
71
 
72
+ fig, ax = plt.subplots(figsize=(6, max(2.4, len(names) * 0.52)))
73
+ bars = ax.barh(range(len(names)), probs, color=bar_c,
74
+ edgecolor="white", linewidth=0.6, height=0.65,
75
+ zorder=3)
76
 
77
  ax.set_yticks(range(len(names)))
78
+ ax.set_yticklabels(names, fontweight="medium")
79
+ ax.set_xlim(0, 1.12)
80
+ ax.set_xlabel("Probability")
81
+ ax.grid(axis="x", alpha=0.15, linestyle="-", zorder=0)
82
+ ax.set_axisbelow(True)
83
+
84
+ for bar, prob in zip(bars, probs):
85
+ if prob > 0.03:
86
+ ax.text(bar.get_width() + 0.015,
87
+ bar.get_y() + bar.get_height() / 2,
88
+ f"{prob:.1%}", va="center", fontsize=10,
89
+ fontweight="bold", color=PAL["dark"])
90
+
91
+ fig.tight_layout(pad=1.0)
92
  return fig
93
 
94
 
95
  def plot_posteriors(samples, param_names, mechanism_name, domain="ec"):
96
+ """KDE + rug plots of posterior distributions for each parameter."""
97
+ from scipy.stats import gaussian_kde
98
+
 
 
 
 
 
 
 
 
 
99
  n_params = len(param_names)
100
+ fig, axes = plt.subplots(1, n_params,
101
+ figsize=(max(4, 2.8 * n_params), 3.2))
102
  if n_params == 1:
103
  axes = [axes]
104
 
105
  colors = MECH_COLORS_EC if domain == "ec" else MECH_COLORS_TPD
106
+ color = colors.get(mechanism_name, PAL["blue"])
107
 
108
  for i, (ax, name) in enumerate(zip(axes, param_names)):
109
  data = samples[:, i]
110
+ q05, q50, q95 = np.quantile(data, [0.05, 0.5, 0.95])
111
+
112
+ try:
113
+ kde = gaussian_kde(data, bw_method="silverman")
114
+ xs = np.linspace(data.min() - 0.1 * data.ptp(),
115
+ data.max() + 0.1 * data.ptp(), 300)
116
+ ys = kde(xs)
117
+ ax.fill_between(xs, ys, alpha=0.25, color=color, zorder=2)
118
+ ax.plot(xs, ys, color=color, linewidth=1.8, zorder=3)
119
+ except Exception:
120
+ ax.hist(data, bins=40, density=True, color=color, alpha=0.4)
121
+
122
+ ax.axvline(q50, color=PAL["dark"], linewidth=1.2, linestyle="-",
123
+ label=f"median {q50:.3f}", zorder=4)
124
+ ax.axvspan(q05, q95, alpha=0.08, color=color, zorder=1)
125
+
126
+ n_rug = min(len(data), 200)
127
+ rug_idx = np.random.choice(len(data), n_rug, replace=False)
128
+ ax.plot(data[rug_idx], np.zeros(n_rug) - 0.02 * ax.get_ylim()[1],
129
+ "|", color=color, alpha=0.3, markersize=4, zorder=2)
130
+
131
+ ax.set_xlabel(_format_param_name(name))
132
+ ax.set_yticks([])
133
+ ax.spines["left"].set_visible(False)
134
+
135
+ ax.text(0.97, 0.95,
136
+ f"median {q50:.3f}\n90% CI [{q05:.3f}, {q95:.3f}]",
137
+ transform=ax.transAxes, fontsize=7.5, va="top", ha="right",
138
+ color=PAL["dark"], alpha=0.7,
139
+ bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="none", alpha=0.8))
140
+
141
+ fig.suptitle(f"Posterior Distributions — {mechanism_name}",
142
+ fontsize=13, fontweight="bold", y=1.02)
143
+ fig.tight_layout(pad=0.8)
144
  return fig
145
 
146
 
147
  def plot_reconstruction(observed_curves, recon_curves, domain="ec",
148
  nrmses=None, r2s=None, scan_labels=None):
149
+ """Overlay of observed vs reconstructed signals."""
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  n_curves = len(observed_curves)
151
+ ncols = min(n_curves, 3)
152
+ fig, axes = plt.subplots(1, ncols,
153
+ figsize=(max(4.5, 4 * ncols), 3.8),
154
  squeeze=False)
155
  axes = axes[0]
156
 
 
165
  obs = observed_curves[i]
166
  rec = recon_curves[i]
167
 
168
+ ax.plot(obs["x"], obs["y"], color=PAL["gray"], linewidth=1.6,
169
+ label="Observed", alpha=0.85, zorder=2)
170
+ ax.plot(rec["x"], rec["y"], color=PAL["blue"], linewidth=1.6,
171
+ label="Reconstructed", linestyle="--", zorder=3)
172
 
173
+ ax.set_xlabel(xlabel)
174
  if i == 0:
175
+ ax.set_ylabel(ylabel)
176
+ ax.legend(framealpha=0.9, loc="best", handlelength=1.5)
 
 
 
 
 
 
 
 
 
 
177
 
178
+ title = (scan_labels[i] if scan_labels and i < len(scan_labels)
179
+ else f"Curve {i + 1}")
180
+ ax.set_title(title)
181
+
182
+ parts = []
183
  if nrmses and i < len(nrmses) and np.isfinite(nrmses[i]):
184
+ parts.append(f"NRMSE {nrmses[i]:.4f}")
185
  if r2s and i < len(r2s) and np.isfinite(r2s[i]):
186
+ parts.append(f"R\u00b2 {r2s[i]:.4f}")
187
+ if parts:
188
+ ax.text(0.03, 0.97, " | ".join(parts),
189
  transform=ax.transAxes, fontsize=8, va="top",
190
+ color=PAL["green"], fontweight="bold",
191
+ bbox=dict(boxstyle="round,pad=0.3", fc="white",
192
+ alpha=0.85, ec=PAL["green"], lw=0.6))
193
+
194
+ fig.suptitle("Signal Reconstruction", fontsize=13, fontweight="bold", y=1.02)
195
+ fig.tight_layout(pad=0.8)
 
 
 
 
 
 
 
 
 
196
  return fig
197
 
198
 
199
  def _add_sweep_arrows(ax, pot, y_ox, y_red, mid):
200
+ """Add direction arrows for forward/reverse sweeps."""
201
  sweep_specs = [
202
  (slice(None, mid), "reductive \u2192", 16),
203
  (slice(mid, None), "\u2190 oxidative", -16),
204
  ]
205
  curves = [
206
+ (y_ox, PAL["blue"], 0.35, 0.65),
207
+ (y_red, PAL["red"], 0.35, 0.65),
208
  ]
209
  for y_data, color, fwd_frac, rev_frac in curves:
210
  for segment, label, y_offset in sweep_specs:
 
215
  continue
216
 
217
  frac = fwd_frac if y_offset > 0 else rev_frac
218
+ idx = max(2, min(int(n * frac), n - 3))
 
 
219
  step = max(1, n // 30)
220
  i0 = max(0, idx - step)
221
  i1 = min(n - 1, idx + step)
 
224
  "", xy=(x_seg[i1], y_seg[i1]),
225
  xytext=(x_seg[i0], y_seg[i0]),
226
  arrowprops=dict(arrowstyle="-|>", color=color,
227
+ lw=1.6, mutation_scale=12),
228
  )
 
229
  ax.annotate(label, xy=(x_seg[idx], y_seg[idx]),
230
  xytext=(0, y_offset), textcoords="offset points",
231
+ fontsize=7, color=color, fontstyle="italic",
232
  ha="center", va="center")
233
 
234
 
235
  def plot_concentration_profiles(conc_curves, scan_labels=None):
236
+ """Plot surface concentration profiles vs potential."""
 
 
 
 
 
 
 
 
 
 
237
  valid = [c for c in conc_curves if c is not None]
238
  if not valid:
239
  return None
240
 
241
  n_curves = len(conc_curves)
242
+ ncols = min(n_curves, 3)
243
+ fig, axes = plt.subplots(1, ncols,
244
+ figsize=(max(4.5, 4 * ncols), 3.8),
245
  squeeze=False)
246
  axes = axes[0]
247
 
 
256
  c_red = np.asarray(c["c_red"])
257
  mid = len(pot) // 2
258
 
259
+ ax.plot(pot[:mid], c_ox[:mid], color=PAL["blue"], lw=1.5, label="C$_A$ (ox)")
260
+ ax.plot(pot[:mid], c_red[:mid], color=PAL["red"], lw=1.5, label="C$_B$ (red)")
261
+ ax.plot(pot[mid:], c_ox[mid:], color=PAL["blue"], lw=1.5)
262
+ ax.plot(pot[mid:], c_red[mid:], color=PAL["red"], lw=1.5)
 
 
 
 
263
 
264
  _add_sweep_arrows(ax, pot, c_ox, c_red, mid)
265
 
266
+ ax.set_xlabel("Potential (\u03b8)")
267
  if i == 0:
268
+ ax.set_ylabel("Surface concentration")
269
+ ax.legend(framealpha=0.9, loc="best", handlelength=1.5)
 
 
270
 
271
+ title = (scan_labels[i] if scan_labels and i < len(scan_labels)
272
+ else f"Curve {i + 1}")
273
+ ax.set_title(title)
 
274
 
275
+ fig.suptitle("Surface Concentration Profiles", fontsize=13, fontweight="bold", y=1.02)
276
+ fig.tight_layout(pad=0.8)
 
277
  return fig
278
 
279
 
280
  def plot_parameter_table(param_stats, mechanism_name):
281
+ """Create a formatted parameter summary table as a figure."""
 
 
 
 
 
 
 
 
 
282
  names = param_stats["names"]
283
  means = param_stats["mean"]
284
  stds = param_stats["std"]
 
286
  q95s = param_stats["q95"]
287
 
288
  n = len(names)
289
+ fig, ax = plt.subplots(figsize=(7, max(1.8, 0.5 * n + 0.8)))
290
  ax.axis("off")
291
 
292
  col_labels = ["Parameter", "Mean", "Std", "5th %ile", "95th %ile"]
293
+ cell_text = [
294
+ [_format_param_name(names[i]),
295
+ f"{means[i]:.4f}", f"{stds[i]:.4f}",
296
+ f"{q05s[i]:.4f}", f"{q95s[i]:.4f}"]
297
+ for i in range(n)
298
+ ]
 
 
 
299
 
300
  table = ax.table(cellText=cell_text, colLabels=col_labels,
301
  loc="center", cellLoc="center")
302
  table.auto_set_font_size(False)
303
+ table.set_fontsize(10)
304
+ table.scale(1.0, 1.4)
305
 
306
  for (row, col), cell in table.get_celld().items():
307
+ cell.set_edgecolor("#E5E7EB")
308
  if row == 0:
309
+ cell.set_facecolor("#EEF2FF")
310
  cell.set_text_props(fontweight="bold")
311
  else:
312
+ cell.set_facecolor("white" if row % 2 else "#F9FAFB")
313
 
314
  ax.set_title(f"Parameter Estimates — {mechanism_name}",
315
+ fontsize=13, fontweight="bold", pad=16)
316
  fig.tight_layout()
317
  return fig
318
 
 
320
  def _format_param_name(name):
321
  """Format parameter names for display."""
322
  replacements = {
323
+ "log10(K0)": "log\u2081\u2080(K\u2080)",
324
+ "log10(dB)": "log\u2081\u2080(d_B)",
325
+ "log10(dA)": "log\u2081\u2080(d_A)",
326
+ "log10(kc)": "log\u2081\u2080(k_c)",
327
+ "log10(reorg_e)": "log\u2081\u2080(\u03bb)",
328
+ "log10(Gamma_sat)": "log\u2081\u2080(\u0393_sat)",
329
+ "log10(KA_eq)": "log\u2081\u2080(K_A,eq)",
330
+ "log10(KB_eq)": "log\u2081\u2080(K_B,eq)",
331
+ "log10(nu)": "log\u2081\u2080(\u03bd)",
332
+ "log10(nu_red)": "log\u2081\u2080(\u03bd_red)",
333
+ "log10(D0)": "log\u2081\u2080(D\u2080)",
334
+ "E0_offset": "E\u2080 offset",
335
+ "alpha": "\u03b1",
336
+ "alpha_cov": "\u03b1_cov",
337
  "Ed": "E_d (K)",
338
  "Ed0": "E_d0 (K)",
339
  "Ea": "E_a (K)",
340
  "Ea_red": "E_a,red (K)",
341
  "Ea_reox": "E_a,reox (K)",
342
  "E_diff": "E_diff (K)",
343
+ "theta_0": "\u03b8\u2080",
344
+ "theta_A0": "\u03b8_A0",
345
+ "theta_B0": "\u03b8_B0",
346
+ "theta_O0": "\u03b8_O0",
347
  }
348
  return replacements.get(name, name)