"""
Train a GCN on an L‑RMC core subgraph and compare it to a full‑graph baseline.

Modes:
- core_mode=forward : Train on the core subgraph, then run the trained model forward on the full graph.
- core_mode=appnp   : Train on the core subgraph, then seed logits on the core nodes and APPNP‑propagate
                      them over the full graph.

Extras:
- --expand_core_with_train : Ensure all training labels lie inside the core
                             (C' = C ∪ train_idx) for a fair train‑time comparison.
- --warm_ft_epochs N       : Optional short finetune on the full graph starting
                             from the core model's weights (measures time‑to‑target).
- --runs R                 : Repeat with seeds 0..R-1 and report mean ± std.
- -o / --output_json PATH  : Save all settings and metrics to a JSON file.

It prints:
- Dataset stats
- Core size and coverage of train/val/test inside the core
- Train/Val/Test accuracy for the baseline and core models
- Wall‑clock training times
"""
|
|
| import argparse |
| import json |
| import time |
| import random |
| from statistics import mean, stdev |
| from pathlib import Path |
| from typing import Dict |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import nn, Tensor |
| from torch_geometric.datasets import Planetoid |
| from torch_geometric.nn import GCNConv, APPNP |
| from torch_geometric.utils import subgraph |
|
|
| |
| |
| |
| from rich.console import Console |
| from rich.table import Table |
| from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn |
|
|
| |
# Module-wide rich console used for the status messages and result tables
# printed by this script.
console = Console()
|
|
| |
| |
| |
def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor:
    """Return a boolean node mask for the single best cluster in an L-RMC seeds file.

    Expected JSON layout::

        {"clusters": [{"seed_nodes": [...], "score": float, ...}, ...]}

    The chosen cluster maximises ``(score, number of seed nodes)``. A missing
    or ``null`` score counts as 0.0, and missing or ``null`` ``seed_nodes``
    count as an empty list.

    Node ids in the file are always treated as 1-indexed and converted to
    0-indexed. Ids that fall outside ``[0, n_nodes)`` after conversion are
    silently dropped.

    Args:
        seeds_json: path to the seeds JSON file (read as UTF-8).
        n_nodes: number of nodes in the graph; length of the returned mask.

    Returns:
        ``torch.bool`` tensor of shape ``(n_nodes,)``. It is all ``False``
        when the file lists no clusters.
    """
    obj = json.loads(Path(seeds_json).read_text(encoding="utf-8"))
    clusters = obj.get("clusters") or []
    mask = torch.zeros(n_nodes, dtype=torch.bool)
    if not clusters:
        return mask

    def _seed_nodes(cluster) -> list:
        # A JSON null here would otherwise crash len()/iteration.
        return cluster.get("seed_nodes") or []

    def _rank(cluster):
        score = cluster.get("score")
        # A JSON null score would otherwise crash float(None).
        return (float(score) if score is not None else 0.0, len(_seed_nodes(cluster)))

    best = max(clusters, key=_rank)
    zero_based = {int(x) - 1 for x in _seed_nodes(best)}
    ids = sorted(i for i in zero_based if 0 <= i < n_nodes)
    if ids:
        mask[torch.tensor(ids, dtype=torch.long)] = True
    return mask
|
|
|
|
def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor,
                    val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]:
    """Count the core's size and how many train/val/test nodes fall inside it.

    All four arguments are boolean node masks of equal length. Each count is
    returned as a plain Python int.
    """
    def _count(selection: torch.Tensor) -> int:
        return int(selection.sum().item())

    return {
        "core_size": _count(core_mask),
        "train_in_core": _count(core_mask & train_mask),
        "val_in_core": _count(core_mask & val_mask),
        "test_in_core": _count(core_mask & test_mask),
    }
|
|
|
|
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Fraction of nodes selected by *mask* whose argmax class equals the label.

    Args:
        logits: per-node class scores, with classes along dim 1.
        y: labels aligned with the rows of *logits*.
        mask: selector used to index both *logits* and *y*.

    Returns:
        Accuracy as a Python float. If *mask* selects no nodes, the mean of an
        empty tensor is taken, which yields NaN.
    """
    chosen = logits[mask].argmax(dim=1)
    correct = chosen.eq(y[mask]).float()
    return correct.mean().item()
|
|
|
|
def set_seed(seed: int):
    """Seed Python, NumPy (if installed) and PyTorch RNGs for reproducible runs.

    Also forces cuDNN into deterministic mode and disables its autotuner.

    Args:
        seed: seed value passed to every RNG.

    Raises:
        Whatever ``numpy.random.seed`` / ``torch.manual_seed`` raise for an
        invalid seed. Only a missing NumPy installation is tolerated.
    """
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        # NumPy is optional for this script; skip seeding it when absent.
        # Any other failure (e.g. an invalid seed) must not be swallowed.
        pass
    else:
        np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
|
|
| |
| |
| |
class GCN2(nn.Module):
    """Two-layer GCN: GCNConv -> ReLU -> dropout -> GCNConv.

    The output of the second convolution is returned as-is (no softmax),
    i.e. one row of raw class scores per node.
    """

    def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        # `c1` / `c2` determine the state_dict keys (used for warm starts);
        # keep these attribute names stable.
        self.c1 = GCNConv(in_dim, hid)
        self.c2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, ei):
        """Return per-node scores for features *x* over edge index *ei*."""
        hidden = F.relu(self.c1(x, ei))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.c2(hidden, ei)
|
|
|
|
| |
| |
| |
@torch.no_grad()
def eval_all(model: nn.Module, data) -> Dict[str, float]:
    """Evaluate *model* on the train, val and test splits of *data*.

    Puts the model in eval mode, does one gradient-free forward pass over the
    full graph, and scores that single set of logits against each split mask.

    Returns:
        ``{"train": acc, "val": acc, "test": acc}``.
    """
    model.eval()
    scores = model(data.x, data.edge_index)
    split_masks = {
        "train": data.train_mask,
        "val": data.val_mask,
        "test": data.test_mask,
    }
    return {name: accuracy(scores, data.y, sel) for name, sel in split_masks.items()}
|
|
|
|
def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100):
    """Train *model* full-batch with Adam and validation-based early stopping.

    Each epoch does the following:
    - Take one optimizer step on the cross-entropy over ``data.train_mask``.
    - Measure accuracy on ``data.val_mask`` with the model in eval mode, so
      dropout is off.

    The weights from the best validation epoch are restored at the end, and
    the model is left in eval mode.

    If ``data.val_mask`` selects no nodes (e.g. an induced core subgraph with
    no validation nodes), validation accuracy is undefined. In that case early
    stopping is disabled, all *epochs* run, and the final weights are kept.

    Args:
        model: module called as ``model(data.x, data.edge_index)``.
        data: graph object exposing ``x``, ``edge_index``, ``y``,
            ``train_mask`` and ``val_mask``.
        epochs: maximum number of epochs.
        lr: Adam learning rate.
        wd: Adam weight decay.
        patience: stop after this many consecutive epochs without a strict
            improvement in validation accuracy.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best, best_state, bad = -1.0, None, 0
    # accuracy() over an empty mask is NaN, and NaN never compares greater
    # than `best`. Without this guard, training would silently stop after
    # `patience` epochs with no best state recorded.
    has_val = bool(data.val_mask.any())

    with Progress(
        SpinnerColumn(),
        "[progress.description]{task.description}",
        TimeElapsedColumn(),
        transient=True,
    ) as progress:
        task = progress.add_task("Training", total=epochs)

        for ep in range(1, epochs + 1):
            model.train()
            opt.zero_grad(set_to_none=True)
            out = model(data.x, data.edge_index)
            loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            opt.step()

            if not has_val:
                progress.update(task, advance=1, description=f"Epoch {ep} | loss={loss.item():.4f}")
                continue

            # Validate with dropout disabled. Otherwise model selection would
            # be driven by a stochastic forward pass.
            model.eval()
            with torch.no_grad():
                val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask)

            if val > best:
                best, bad = val, 0
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                bad += 1
                if bad >= patience:
                    break

            progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
|
|
|
|
def subset_data(data, nodes_idx: torch.Tensor):
    """Build a new graph object for the subgraph induced by *nodes_idx*.

    Node features, labels and the train/val/test masks are gathered in the
    order given by *nodes_idx*. ``edge_index`` keeps only edges whose
    endpoints are both selected, relabelled into the subgraph's own node
    numbering (``relabel_nodes=True``). Attributes of *data* other than
    ``x``, ``y``, the three masks and ``edge_index`` are not carried over.
    """
    idx = nodes_idx.to(torch.long)
    sub = type(data)()
    for attr in ("x", "y", "train_mask", "val_mask", "test_mask"):
        setattr(sub, attr, getattr(data, attr)[idx])
    sub.edge_index, _ = subgraph(idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
    sub.num_nodes = sub.x.size(0)
    return sub
|
|
|
|
| |
| |
| |
def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor:
    """Spread seed logits over the whole graph with a parameter-free APPNP pass.

    Args:
        logits_seed: ``[N, C]`` logits. Rows for nodes outside the core are
            expected to be zero.
        edge_index: graph connectivity used for propagation.
        K: number of propagation steps.
        alpha: APPNP teleport probability.

    Returns:
        The propagated logits for all nodes.
    """
    propagator = APPNP(K=K, alpha=alpha)
    propagated = propagator(logits_seed, edge_index)
    return propagated
|
|
|
|
| |
| |
| |
# (label, split key) pairs used for the accuracy rows of every results table.
_ACCURACY_ROWS = (
    ("Train Accuracy", "train"),
    ("Validation Accuracy", "val"),
    ("Test Accuracy", "test"),
)


def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of this script."""
    p = argparse.ArgumentParser()
    p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
    p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON")
    p.add_argument("--hidden", type=int, default=64)
    p.add_argument("--dropout", type=float, default=0.5)
    p.add_argument("--epochs", type=int, default=200)
    p.add_argument("--lr", type=float, default=0.01)
    p.add_argument("--wd", type=float, default=5e-4)
    p.add_argument("--patience", type=int, default=100)
    p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward",
                   help="How to evaluate the core model on the full graph.")
    p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport prob (Mode B).")
    p.add_argument("--K", type=int, default=10, help="APPNP steps (Mode B).")
    p.add_argument("--expand_core_with_train", action="store_true",
                   help="Expand LRMC core with all training nodes (C' = C ∪ train_idx).")
    p.add_argument("--warm_ft_epochs", type=int, default=0,
                   help="If >0, run a short finetune on the FULL graph starting from the core model.")
    p.add_argument("--warm_ft_lr", type=float, default=0.005)
    p.add_argument("--runs", type=int, default=1,
                   help="Number of runs with different seeds to average results.")
    p.add_argument("-o", "--output_json", type=str, default=None,
                   help="If set, save all computed metrics and settings to this JSON file.")
    args = p.parse_args()
    if args.runs < 1:
        p.error("--runs must be >= 1")
    return args


def _build_model(ds, args) -> GCN2:
    """Return a freshly initialised GCN2 sized for *ds* with the CLI hyperparameters."""
    return GCN2(in_dim=ds.num_node_features,
                hid=args.hidden,
                out_dim=ds.num_classes,
                dropout=args.dropout)


def _train_kwargs(args) -> dict:
    """Keyword arguments for ``train`` shared by the baseline and core models."""
    return dict(epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)


def _timed_train(model, data, **train_kwargs) -> float:
    """Run ``train`` on *model* and return the wall-clock seconds it took.

    Only the training loop is timed, never model construction or the final
    evaluation. This keeps baseline, core and warm-start times comparable,
    so the reported speedup is not biased by the baseline's full-graph eval.
    """
    start = time.perf_counter()
    train(model, data, **train_kwargs)
    return time.perf_counter() - start


def _split_accuracies(logits: Tensor, data) -> Dict[str, float]:
    """Return train/val/test accuracy of precomputed full-graph *logits*."""
    return {
        "train": accuracy(logits, data.y, data.train_mask),
        "val": accuracy(logits, data.y, data.val_mask),
        "test": accuracy(logits, data.y, data.test_mask),
    }


@torch.no_grad()
def _core_logits_on_full_graph(model, data, data_C, C_idx: Tensor, num_classes: int, args) -> Tensor:
    """Produce full-graph logits from a model trained on the core subgraph.

    ``forward`` mode runs the model directly on the full graph.

    ``appnp`` mode runs it on the core subgraph only, writes those logits into
    the core rows of an otherwise-zero ``[num_nodes, num_classes]`` matrix,
    and APPNP-propagates that matrix over the full graph.

    Runs under ``no_grad`` because these logits are only used for evaluation.
    """
    model.eval()
    if args.core_mode == "forward":
        return model(data.x, data.edge_index)
    logits_C = model(data_C.x, data_C.edge_index)
    logits_seed = torch.zeros(data.num_nodes, num_classes, device=logits_C.device)
    logits_seed[C_idx] = logits_C
    return appnp_seed_propagate(logits_seed, data.edge_index, K=args.K, alpha=args.alpha)


def _speedup(baseline_s: float, core_s: float) -> float:
    """Baseline-to-core training time ratio (``inf`` if the core time is not positive)."""
    return baseline_s / core_s if core_s > 0 else float("inf")


def _warm_finetune(core_model, data, ds, args):
    """Finetune a copy of *core_model* on the full graph.

    Patience exceeds the epoch budget, so early stopping never triggers.
    ``train`` still restores the best-validation weights at the end.

    Returns:
        ``(metrics, seconds)``, where *seconds* covers training only.
    """
    warm = _build_model(ds, args)
    warm.load_state_dict(core_model.state_dict())
    seconds = _timed_train(warm, data,
                           epochs=args.warm_ft_epochs,
                           lr=args.warm_ft_lr,
                           wd=args.wd,
                           patience=args.warm_ft_epochs + 1)
    return eval_all(warm, data), seconds


def _metric_table(rows, value_header: str = "Value") -> Table:
    """Two-column (metric, value) rich table with the script's standard styling."""
    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="cyan")
    table.add_column(value_header, style="magenta")
    for label, value in rows:
        table.add_row(label, value)
    return table


def _comparison_table(title: str) -> Table:
    """Empty four-column baseline-vs-core comparison table."""
    table = Table(title=title, show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="cyan")
    table.add_column("Baseline", style="magenta")
    table.add_column("L-RMC-core", style="green")
    table.add_column("Speedup", style="yellow")
    return table


def _fmt(values, prec: int = 4) -> str:
    """Format *values* as ``mean ± std`` (one value as-is, ``n/a`` when empty)."""
    if not values:
        return "n/a"
    if len(values) == 1:
        return f"{values[0]:.{prec}f}"
    return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}"


def _stats(values) -> dict:
    """JSON summary: raw values, count, mean and sample std (``None`` if < 2 values)."""
    summary = {"values": [float(v) for v in values], "count": int(len(values))}
    if values:
        summary["mean"] = float(mean(values))
        summary["std"] = float(stdev(values)) if len(values) >= 2 else None
    return summary


def _save_results(results: dict, output_json) -> None:
    """Write *results* as indented JSON to *output_json*; no-op when it is falsy."""
    if not output_json:
        return
    out_path = Path(output_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)


def _run_single(args, ds, data, C_idx: Tensor, results: dict) -> None:
    """One seed-0 run: baseline, core model, optional warm start.

    Fills ``results["single_run"]`` and ``results["core_empty"]``.
    """
    set_seed(0)
    base = _build_model(ds, args)
    base_time = _timed_train(base, data, **_train_kwargs(args))
    base_metrics = eval_all(base, data)

    console.print("\n[bold]Baseline (trained on full graph):[/bold]")
    console.print(_metric_table(
        [(label, f"{base_metrics[key]:.4f}") for label, key in _ACCURACY_ROWS]
        + [("Time (s)", f"{base_time:.2f}")]
    ))

    results["single_run"] = {
        "baseline": {
            "train": float(base_metrics["train"]),
            "val": float(base_metrics["val"]),
            "test": float(base_metrics["test"]),
            "time_s": float(base_time),
        }
    }

    if C_idx.numel() == 0:
        console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]")
        results["core_empty"] = True
        return
    results["core_empty"] = False

    data_C = subset_data(data, C_idx)
    mC = _build_model(ds, args)
    core_time = _timed_train(mC, data_C, **_train_kwargs(args))
    core_metrics = _split_accuracies(
        _core_logits_on_full_graph(mC, data, data_C, C_idx, ds.num_classes, args), data)
    speedup = _speedup(base_time, core_time)

    console.print("\n[bold]LRMC‑core model (trained on core, evaluated on full graph):[/bold]")
    console.print(_metric_table(
        [(label, f"{core_metrics[key]:.4f}") for label, key in _ACCURACY_ROWS]
        + [("Core Training Time (s)", f"{core_time:.2f}"),
           ("Speedup vs. Baseline", f"{speedup:.2f}×")]
    ))

    results["single_run"]["core_model"] = {
        "mode": str(args.core_mode),
        "train": float(core_metrics["train"]),
        "val": float(core_metrics["val"]),
        "test": float(core_metrics["test"]),
        "core_train_time_s": float(core_time),
        "speedup_vs_baseline": float(speedup),
    }

    console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]")
    comparison = _comparison_table("Performance Comparison")
    for metric in ("train", "val", "test"):
        comparison.add_row(
            f"{metric.capitalize()} Accuracy",
            f"{base_metrics[metric]:.4f}",
            f"{core_metrics[metric]:.4f}",
            "",
        )
    comparison.add_row("Training Time (s)", f"{base_time:.2f}", f"{core_time:.2f}", f"{speedup:.2f}x")
    comparison.add_row("Speedup", "1x", f"{speedup:.2f}x", "")
    console.print(comparison)

    if args.warm_ft_epochs > 0:
        warm_metrics, warm_time = _warm_finetune(mC, data, ds, args)
        console.print("\n[bold]Warm‑start finetune (start from core model, train on FULL graph):[/bold]")
        console.print(_metric_table(
            [(label, f"{warm_metrics[key]:.4f}") for label, key in _ACCURACY_ROWS]
            + [("Finetune Time (s)", f"{warm_time:.2f}"),
               ("Total (core train + warm)", f"{(core_time + warm_time):.2f}s")]
        ))
        results["single_run"]["warm_finetune"] = {
            "train": float(warm_metrics["train"]),
            "val": float(warm_metrics["val"]),
            "test": float(warm_metrics["test"]),
            "finetune_time_s": float(warm_time),
            "total_time_s": float(core_time + warm_time),
        }


def _run_multi(args, ds, data, C_idx: Tensor, results: dict) -> None:
    """Repeat baseline/core/warm-start runs over seeds ``0..runs-1`` and report mean ± std.

    Fills ``results["multi_run"]`` and ``results["core_empty"]``.
    """
    runs = args.runs
    console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]")

    splits = ("train", "val", "test")
    base_acc = {s: [] for s in splits}
    core_acc = {s: [] for s in splits}
    warm_acc = {s: [] for s in splits}
    base_time, core_time, speedups = [], [], []
    warm_time, warm_total_time = [], []

    data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None
    results["core_empty"] = data_C is None

    for r in range(runs):
        set_seed(r)

        base = _build_model(ds, args)
        base_s = _timed_train(base, data, **_train_kwargs(args))
        for split, acc in eval_all(base, data).items():
            base_acc[split].append(acc)
        base_time.append(base_s)

        if data_C is None:
            continue

        mC = _build_model(ds, args)
        core_s = _timed_train(mC, data_C, **_train_kwargs(args))
        logits_full = _core_logits_on_full_graph(mC, data, data_C, C_idx, ds.num_classes, args)
        for split, acc in _split_accuracies(logits_full, data).items():
            core_acc[split].append(acc)
        core_time.append(core_s)
        speedups.append(_speedup(base_s, core_s))

        if args.warm_ft_epochs > 0:
            wm, warm_s = _warm_finetune(mC, data, ds, args)
            for split, acc in wm.items():
                warm_acc[split].append(acc)
            warm_time.append(warm_s)
            warm_total_time.append(core_s + warm_s)

    console.print("\n[bold]Baseline (averaged over runs):[/bold]")
    console.print(_metric_table(
        [(label, _fmt(base_acc[key])) for label, key in _ACCURACY_ROWS]
        + [("Time (s)", _fmt(base_time, prec=2))],
        value_header="Mean ± Std",
    ))

    results["multi_run"] = {
        "runs": int(runs),
        "baseline": {
            "train": _stats(base_acc["train"]),
            "val": _stats(base_acc["val"]),
            "test": _stats(base_acc["test"]),
            "time_s": _stats(base_time),
        },
    }

    if data_C is None:
        console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]")
        return

    console.print("\n[bold]LRMC‑core (averaged over runs):[/bold]")
    console.print(_metric_table(
        [(label, _fmt(core_acc[key])) for label, key in _ACCURACY_ROWS]
        + [("Core Training Time (s)", _fmt(core_time, prec=2)),
           ("Speedup vs. Baseline", _fmt(speedups, prec=2))],
        value_header="Mean ± Std",
    ))

    results["multi_run"]["core_model"] = {
        "mode": str(args.core_mode),
        "train": _stats(core_acc["train"]),
        "val": _stats(core_acc["val"]),
        "test": _stats(core_acc["test"]),
        "core_train_time_s": _stats(core_time),
        "speedup_vs_baseline": _stats(speedups),
    }

    console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]")
    comparison = _comparison_table("Performance Comparison (Mean ± Std)")
    for label, key in _ACCURACY_ROWS:
        comparison.add_row(label, _fmt(base_acc[key]), _fmt(core_acc[key]), "")
    comparison.add_row("Training Time (s)", _fmt(base_time, prec=2), _fmt(core_time, prec=2),
                       _fmt(speedups, prec=2))
    comparison.add_row("Speedup", "1x", _fmt(speedups, prec=2), "")
    console.print(comparison)

    if args.warm_ft_epochs > 0 and warm_time:
        console.print("\n[bold]Warm‑start finetune (averaged over runs):[/bold]")
        console.print(_metric_table(
            [(label, _fmt(warm_acc[key])) for label, key in _ACCURACY_ROWS]
            + [("Finetune Time (s)", _fmt(warm_time, prec=2)),
               ("Total (core train + warm)", _fmt(warm_total_time, prec=2))],
            value_header="Mean ± Std",
        ))
        results["multi_run"]["warm_finetune"] = {
            "train": _stats(warm_acc["train"]),
            "val": _stats(warm_acc["val"]),
            "test": _stats(warm_acc["test"]),
            "finetune_time_s": _stats(warm_time),
            "total_time_s": _stats(warm_total_time),
        }


def main():
    """CLI entry point for the baseline vs. L-RMC-core comparison.

    Steps:
    1. Load the Planetoid dataset and the top-scoring L-RMC core.
    2. Report the core's coverage of the train/val/test splits.
    3. Run the single-seed (``--runs 1``) or multi-seed comparison.
    4. Optionally save all settings and metrics to ``--output_json``.
    """
    args = _parse_args()

    ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = ds[0]
    # Edge count assumes each undirected edge is stored in both directions.
    n, e = data.num_nodes, data.edge_index.size(1) // 2
    console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]")

    results = {
        "args": {k: v for k, v in vars(args).items() if k != "output_json"},
        "dataset": {
            "name": args.dataset,
            "num_nodes": int(n),
            "num_edges": int(e),
        },
    }

    core_mask = load_top1_assignment(args.seeds, n)
    if args.expand_core_with_train:
        core_mask = core_mask | data.train_mask
    C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1)
    frac = 100.0 * C_idx.numel() / n
    cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask)

    console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]")

    results["core"] = {
        "source": str(args.seeds),
        "expanded_with_train": bool(args.expand_core_with_train),
        "size": int(cov["core_size"]),
        "fraction": float(frac / 100.0),
        "coverage": {
            "train_in_core": int(cov["train_in_core"]),
            "val_in_core": int(cov["val_in_core"]),
            "test_in_core": int(cov["test_in_core"]),
        },
    }

    cov_table = Table(title="LRMC Core Coverage")
    cov_table.add_column("Metric", style="cyan")
    cov_table.add_column("Count", style="magenta")
    for label, key in (("Core Size", "core_size"), ("Train in Core", "train_in_core"),
                       ("Val in Core", "val_in_core"), ("Test in Core", "test_in_core")):
        cov_table.add_row(label, str(cov[key]))
    console.print(cov_table)

    if args.runs == 1:
        _run_single(args, ds, data, C_idx, results)
    else:
        _run_multi(args, ds, data, C_idx, results)

    _save_results(results, args.output_json)
|
|
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|