| """ |
| FrozenLake Video Dataset Generator — generate, eval, verify. |
| |
| Uses generate_auto() which picks random (small grids) or guided (large grids) |
| strategy automatically. |
| |
| Usage: |
| python frozenlake_video_gen.py generate --output-dir frozenlake \ |
| --sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8 |
| python frozenlake_video_gen.py eval result_videos/ --table-dir frozenlake/tables |
| python frozenlake_video_gen.py verify results.json --table-dir frozenlake/tables |
| """ |
| import json |
| import csv |
| import hashlib |
| import random |
| import re |
| import argparse |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Dict, List, Optional |
|
|
| import cv2 |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from frozenlake_processor import FrozenLakeProcessor |
|
|
|
|
| |
|
|
@dataclass
class GenerationState:
    """Resumable progress snapshot stored under the ``"state"`` key of metadata.json.

    The instance is serialised with :meth:`to_dict` and rebuilt with
    :meth:`from_dict`, so every field must be JSON-compatible.
    """

    # Short digest of the generation parameters (see ``_params_hash``).
    params_hash: str
    # Number of puzzles already generated, keyed by grid size.
    size_progress: Dict[int, int]
    # Fingerprints of grids generated so far, used for de-duplication.
    seen_fingerprints: List[str]
    # Per-sample metadata records collected so far.
    all_samples: List[Dict]
    # True once the whole dataset (including split files) has been written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Return a plain-dict copy of the state suitable for ``json.dump``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from a dict produced by :meth:`to_dict`.

        JSON object keys are always strings, so a state that went through
        ``json.dump``/``json.load`` comes back with ``size_progress`` keyed by
        ``"8"`` instead of ``8``. The keys are coerced back to ``int`` here so
        the field matches its ``Dict[int, int]`` annotation. The input dict is
        not modified.
        """
        data = dict(d)
        if "size_progress" in data:
            data["size_progress"] = {
                int(size): count for size, count in data["size_progress"].items()
            }
        return cls(**data)
|
|
|
|
| def _params_hash(params: Dict) -> str: |
| key = {k: v for k, v in params.items() if k != "output_dir"} |
| return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12] |
|
|
|
|
def load_checkpoint(output_dir: Path, params: Dict) -> Optional["GenerationState"]:
    """Load the saved generation state from ``output_dir/metadata.json``.

    Returns:
        ``None`` when no metadata file exists or when the stored parameter
        hash differs from the hash of ``params`` (the caller then starts
        fresh). Otherwise the stored state, whether completed or partial.
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    with open(meta_path) as fh:
        payload = json.load(fh)
    state = GenerationState.from_dict(payload["state"])

    # A checkpoint is only reusable for the exact parameter set that made it.
    current_hash = _params_hash(params)
    if state.params_hash != current_hash:
        print(f"⚠️ Params changed ({state.params_hash} → {current_hash}), starting fresh")
        return None

    if state.completed:
        print("✓ Already completed")
    else:
        print(f"✓ Resuming: {sum(state.size_progress.values())} done")
    return state
|
|
|
|
def save_checkpoint(output_dir: Path, state: "GenerationState", params: Dict):
    """Write ``params`` and ``state`` to ``output_dir/metadata.json``.

    The payload is written to a sibling ``metadata.tmp`` file first and then
    moved over the real file, so an interrupted write does not leave a
    truncated ``metadata.json`` behind.

    Args:
        output_dir: Existing directory that holds the dataset.
        state: Object exposing ``to_dict()`` that returns JSON-compatible data.
        params: Generation parameters stored alongside the state.
    """
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.rename() raises FileExistsError on Windows when the target already
    # exists, which would break every checkpoint after the first one.
    # Path.replace() overwrites the target on all platforms.
    tmp.replace(meta)
|
|
|
|
| |
|
|
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode ``frames`` into an ``mp4v`` video at ``path``.

    Args:
        frames: Non-empty sequence of RGB images convertible with
            ``np.array``. The video size is taken from the first frame.
        path: Destination file path.
        fps: Frames per second passed to the video writer.

    Raises:
        ValueError: If ``frames`` is empty.
        OSError: If the video writer cannot be opened for ``path``.
    """
    if len(frames) == 0:
        raise ValueError("save_video_cv2: no frames to write")

    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    try:
        # An unopened writer would otherwise drop every frame silently and
        # leave the dataset pointing at a video that was never written.
        if not writer.isOpened():
            raise OSError(f"Could not open video writer for {path}")
        for frame in frames:
            # Frames are RGB; OpenCV writes BGR.
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Release the writer even when a frame conversion fails.
        writer.release()
|
|
|
|
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the final frame of a video as an RGB array.

    The function first seeks to the last frame using the frame count the
    container reports. If the count is unavailable or the seek-and-read
    fails, it rewinds and reads the stream sequentially, keeping the last
    frame that decodes.

    Returns:
        The last decodable frame converted from BGR to RGB, or ``None`` when
        the video cannot be opened or yields no frame at all.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None

        last = None
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
            ok, candidate = cap.read()
            if ok and candidate is not None:
                last = candidate

        if last is None:
            # Fallback: the reported count was missing or the direct read
            # failed, so walk the stream from the start.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            while True:
                ok, candidate = cap.read()
                if not ok or candidate is None:
                    break
                last = candidate
    finally:
        # Release the capture on every path, including the early return.
        cap.release()

    if last is None:
        return None
    return cv2.cvtColor(last, cv2.COLOR_BGR2RGB)
|
|
|
|
| def _normalise_list(val, sizes, name="parameter"): |
| if isinstance(val, int): |
| return [val] * len(sizes) |
| if len(val) != len(sizes): |
| raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})") |
| return list(val) |
|
|
|
|
| |
|
|
def generate_dataset(
    output_dir: str = "frozenlake",
    sizes: Optional[List[int]] = None,
    num_per_size: Optional[list] = None,
    p: float = 0.8,
    min_path_ratio: float = 0.1,
    img_size: int = 512,
    prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
    train_ratio: float = 0.9,
    n_start: int = 2,
    m_end: int = 3,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    use_gym: bool = True,
    checkpoint_interval: int = 50,
):
    """Generate a FrozenLake image/video/table dataset with resumable progress.

    For every grid size, unique puzzles are generated through
    ``FrozenLakeProcessor.generate_auto`` and written as ``images/<base>.png``,
    ``videos/<base>.mp4`` and ``tables/<base>.txt`` under ``output_dir``.
    Afterwards ``path.json`` and the ``train``/``test`` ``.jsonl`` and ``.csv``
    files are written. Progress is checkpointed to ``metadata.json``.

    Args:
        output_dir: Root directory of the dataset.
        sizes: Grid sizes to generate. Defaults to ``[8, 16, 32]``.
        num_per_size: Target puzzle count per size. May be an int, a
            one-element list (applied to every size), or a list matching
            ``sizes``. Defaults to ``[100, 500, 1000]``.
        p: Forwarded to ``generate_auto`` as ``p``.
        min_path_ratio: Minimum path length as a fraction of the cell count
            (``grid_size ** 2``), at least 1.
        img_size: Forwarded to ``FrozenLakeProcessor(img_size=...)``.
        prompt: Text stored with every sample.
        train_ratio: Fraction of each size group assigned to the train split.
        n_start, m_end, frames: Forwarded to ``generate_video_frames``.
        fps: Frame rate of the written videos.
        seed: Seed for the global ``random`` module (``seed + 1`` is used for
            the train/test split).
        use_gym: Forwarded to ``render`` and ``generate_video_frames``.
        checkpoint_interval: Save a checkpoint after this many new puzzles.

    Raises:
        ValueError: If ``num_per_size`` does not match ``sizes`` in length.
    """
    # Resolve defaults here rather than in the signature to avoid shared
    # mutable default arguments.
    if sizes is None:
        sizes = [8, 16, 32]
    if num_per_size is None:
        num_per_size = [100, 500, 1000]

    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed, "use_gym": use_gym,
    }

    out = Path(output_dir)
    img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
    for d in (img_dir, vid_dir, tbl_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # An int or a one-element list is applied to every size.
    if isinstance(num_per_size, int):
        raw_counts = num_per_size
    elif len(num_per_size) == 1:
        raw_counts = num_per_size[0]
    else:
        raw_counts = num_per_size
    num_list = _normalise_list(raw_counts, sizes, "num_per_size")
    # Zero-pad width for file names; default=0 keeps an empty size list valid.
    num_w = len(str(max(num_list, default=0)))
    proc = FrozenLakeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[], all_samples=[],
        )
        print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
    else:
        # Re-seed and burn some draws so a resumed run does not restart from
        # the seed's initial state.
        # NOTE(review): 10 draws per finished puzzle is a heuristic; a resumed
        # run is not guaranteed to match an uninterrupted one.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # Keys come back as strings after a JSON round trip.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    with tqdm(total=sum(num_list), initial=sum(progress.values()),
              desc="Total", unit="puzzle") as pbar:
        for grid_size, target in zip(sizes, num_list):
            generated = progress.get(grid_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(grid_size * grid_size * min_path_ratio))

            with tqdm(total=target, initial=generated,
                      desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
                # Bounded retry budget: at most 5 attempts per missing puzzle.
                for _ in range((target - generated) * 5):
                    if generated >= target:
                        break
                    try:
                        desc, path = proc.generate_auto(
                            grid_size, p=p, min_path_len=min_len
                        )
                    except RuntimeError:
                        continue

                    # Skip grids that were already produced.
                    fp = proc.fingerprint(desc)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{grid_size}_{generated:0{num_w}d}"

                    proc.render(desc, use_gym=use_gym).save(str(img_dir / f"{base}.png"))
                    vid_frames = proc.generate_video_frames(
                        desc, path, n_start=n_start, m_end=m_end,
                        frames=frames, use_gym=use_gym,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / f"{base}.mp4"), fps=fps)
                    proc.save_table(str(tbl_dir / f"{base}.txt"), desc)

                    udrl = proc.path_to_udrl(path)
                    all_samples.append({
                        "prompt": prompt, "image": f"{base}.png",
                        "video": f"{base}.mp4", "table": f"{base}.txt",
                        "grid_size": grid_size, "grid_desc": desc,
                        "start": list(proc.find_start(desc)),
                        "path_udrl": udrl, "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[grid_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            tqdm.write(f"Size {grid_size}: {generated} puzzles")
            if generated < target:
                # The retry budget ran out before the target was reached.
                tqdm.write(
                    f"Size {grid_size}: only {generated}/{target} puzzles "
                    f"generated before the attempt limit was reached"
                )

    with open(out / "path.json", "w") as f:
        json.dump(dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)), f, indent=4)

    # Per-size train/test split, seeded separately from generation.
    random.seed(seed + 1)
    by_size: Dict[int, List[Dict]] = {}
    for s in all_samples:
        # Samples store the size under "grid_size"; the former "maze_size"
        # lookup raised KeyError on every run.
        by_size.setdefault(s["grid_size"], []).append(s)

    train_samples, test_samples = [], []
    for sz in sorted(by_size):
        group = by_size[sz]
        random.shuffle(group)
        sz_split = int(len(group) * train_ratio)
        train_samples.extend(group[:sz_split])
        test_samples.extend(group[sz_split:])

    random.shuffle(train_samples)
    random.shuffle(test_samples)
    split = len(train_samples)

    def _write_jsonl(samples, path):
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _write_jsonl(train_samples, out / "train.jsonl")
    _write_jsonl(test_samples, out / "test.jsonl")

    for name, samples in [("train", train_samples), ("test", test_samples)]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["input_image", "video", "prompt"])
            for s in samples:
                w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])

    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    lengths = [s["path_length"] for s in all_samples]
    print(f"\n✓ Complete: {out}/ | {len(all_samples)} puzzles "
          f"(train={split}, test={len(all_samples)-split})")
    # min()/max() raise on an empty list, so skip the stats for an empty dataset.
    if lengths:
        print(f" Paths: avg={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}")
|
|
|
|
| |
|
|
def eval_videos(
    video_dir: str, table_dir: str,
    output_json: Optional[str] = None, gt_json: Optional[str] = None,
    use_gym: bool = True,
):
    """Extract paths from result videos, verify them, and print a report.

    For each ``*.mp4`` in ``video_dir`` the matching ``<stem>.txt`` table is
    loaded from ``table_dir``, the last video frame is decoded, and a path is
    extracted from its pixels. All extracted paths are written to
    ``output_json`` keyed by ``<stem>.png`` and then verified.

    Args:
        video_dir: Directory containing the ``.mp4`` files to evaluate.
        table_dir: Directory containing the grid tables.
        output_json: Where to write the extracted paths. Defaults to
            ``<video_dir>/0_result.json``.
        gt_json: Optional JSON file mapping names to ground-truth paths. When
            given, accuracy is also reported in buckets of 10 by ground-truth
            path length.
        use_gym: Verify with ``verify_path_gym`` when true, otherwise with
            ``verify_path_sim``.
    """
    proc = FrozenLakeProcessor()
    vid_root, tbl_root = Path(video_dir), Path(table_dir)
    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Sort so that digit runs compare numerically (size8_2 before size8_10).
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )
    if not videos:
        print(f"No .mp4 in {vid_root}")
        return

    extracted: Dict[str, str] = {}
    # Tables are kept per result name so later passes do not re-read them.
    tables: Dict[str, list] = {}
    missing_tbl = missing_frame = 0

    for vp in tqdm(videos, desc="Extracting"):
        desc = proc.load_table(str(tbl_root / f"{vp.stem}.txt"))
        if desc is None:
            missing_tbl += 1
            continue
        start = proc.find_start(desc)
        if start is None:
            missing_tbl += 1
            continue
        lf = extract_last_frame(str(vp))
        if lf is None:
            missing_frame += 1
            continue
        name = f"{vp.stem}.png"
        extracted[name] = proc.extract_path_from_pixels(
            lf, len(desc), len(desc[0]), start, desc)
        tables[name] = desc

    with open(output_json, "w") as f:
        json.dump(extracted, f, indent=4)

    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    correct = total = 0
    size_stats: Dict[int, Dict[str, int]] = {}
    top: List[Dict] = []
    # Verification results are reused by the ground-truth breakdown below.
    verdicts: Dict[str, bool] = {}

    for name, udrl in extracted.items():
        desc = tables[name]
        total += 1
        sz = len(desc)
        stats = size_stats.setdefault(sz, {"total": 0, "correct": 0})
        stats["total"] += 1
        ok = bool(verify_fn(desc, udrl))
        verdicts[name] = ok
        if ok:
            correct += 1
            stats["correct"] += 1
            top.append({"name": name, "length": len(udrl)})

    acc = correct / total * 100 if total else 0
    print(f"\n{'='*50}\nEval: {correct}/{total} ({acc:.2f}%) | "
          f"missing_tbl={missing_tbl} bad_frame={missing_frame}")
    for sz in sorted(size_stats):
        s = size_stats[sz]
        print(f" Size {sz:3d}: {s['correct']}/{s['total']} "
              f"({s['correct']/s['total']*100:.1f}%)")
    # Report the three longest correct paths.
    top.sort(key=lambda x: x["length"], reverse=True)
    for i, item in enumerate(top[:3]):
        print(f" Top {i+1}: {item['name']} (len={item['length']})")

    if gt_json:
        try:
            with open(gt_json) as f:
                gt = json.load(f)
            bins: Dict[str, Dict[str, int]] = {}
            for name in extracted:
                if name not in gt:
                    continue
                # Bucket by ground-truth path length in steps of 10.
                lo = (len(gt[name]) // 10) * 10
                label = f"{lo:3d}-{lo+9:3d}"
                b = bins.setdefault(label, {"total": 0, "correct": 0})
                b["total"] += 1
                if verdicts[name]:
                    b["correct"] += 1
        except (OSError, ValueError, TypeError) as exc:
            # Unreadable or malformed ground truth: report it instead of
            # hiding the failure, and keep the main report intact.
            print(f"GT breakdown skipped ({gt_json}): {exc}")
        else:
            if bins:
                print("\nBy GT path length:")
                for label in sorted(bins):
                    b = bins[label]
                    print(f" {label}: {b['correct']}/{b['total']} "
                          f"({b['correct']/b['total']*100:.1f}%)")
    print(f"{'='*50}")
|
|
|
|
def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
    """Verify the paths stored in a results JSON file and print the accuracy.

    Args:
        json_file: JSON file mapping result names (``<stem>.png``) to paths.
        table_dir: Directory containing the ``<stem>.txt`` grid tables.
        use_gym: Verify with ``verify_path_gym`` when true, otherwise with
            ``verify_path_sim``.

    Entries for which ``load_table`` returns ``None`` are excluded from the
    accuracy and reported separately.
    """
    proc = FrozenLakeProcessor()
    with open(json_file) as f:
        solutions = json.load(f)
    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    tbl_root = Path(table_dir)
    correct = skipped = valid = 0
    for name, udrl in solutions.items():
        # Strip only a trailing ".png"; str.replace would also remove the
        # text from the middle of a name and point at the wrong table.
        stem = name[:-len(".png")] if name.endswith(".png") else name
        desc = proc.load_table(str(tbl_root / f"{stem}.txt"))
        if desc is None:
            skipped += 1
            continue
        valid += 1
        if verify_fn(desc, udrl):
            correct += 1
    acc = correct / valid * 100 if valid else 0
    print(f"\nVerification: {correct}/{valid} ({acc:.2f}%)")
    # The skipped count used to be computed and then dropped.
    if skipped:
        print(f"Skipped {skipped} entries with no loadable table")
|
|
|
|
| |
|
|
def parse_args():
    """Parse the command line into a namespace for generate, eval or verify.

    The selected sub-command is stored in ``command``; it is ``None`` when
    no sub-command is given.
    """
    parser = argparse.ArgumentParser(description="FrozenLake video dataset")
    commands = parser.add_subparsers(dest="command")

    # Options are listed in registration order as (flag, add_argument kwargs).
    generate_options = [
        ("--output-dir", {"default": "frozenlake"}),
        ("--sizes", {"type": int, "nargs": "+", "default": [8, 12, 16, 32]}),
        ("--num-per-size", {"type": int, "nargs": "+", "default": [1000, 2000, 5000, 10000]}),
        ("--p", {"type": float, "default": 0.5}),
        ("--min-path-ratio", {"type": float, "default": 0.1}),
        ("--img-size", {"type": int, "default": 1024}),
        ("--prompt", {"default": "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes."}),
        ("--train-ratio", {"type": float, "default": 0.9}),
        ("--n-start", {"type": int, "default": 2}),
        ("--m-end", {"type": int, "default": 3}),
        ("--frames", {"type": int, "default": None}),
        ("--fps", {"type": int, "default": 10}),
        ("--seed", {"type": int, "default": 42}),
        ("--no-gym", {"action": "store_true"}),
        ("--checkpoint-interval", {"type": int, "default": 50}),
    ]
    eval_options = [
        ("video_dir", {}),
        ("--table-dir", {"required": True}),
        ("--output-json", {"default": None}),
        ("--gt-json", {"default": None}),
        ("--no-gym", {"action": "store_true"}),
    ]
    verify_options = [
        ("json_file", {}),
        ("--table-dir", {"required": True}),
        ("--no-gym", {"action": "store_true"}),
    ]

    for command_name, options in (
        ("generate", generate_options),
        ("eval", eval_options),
        ("verify", verify_options),
    ):
        sub_parser = commands.add_parser(command_name)
        for flag, kwargs in options:
            sub_parser.add_argument(flag, **kwargs)

    return parser.parse_args()
|
|
|
|
if __name__ == "__main__":
    # Dispatch the parsed sub-command to its implementation.
    cli = parse_args()
    command = cli.command
    if command == "generate":
        # Every remaining CLI option maps one-to-one onto a keyword argument
        # of generate_dataset; --no-gym is inverted into use_gym.
        options = dict(vars(cli))
        del options["command"]
        no_gym = options.pop("no_gym")
        options["use_gym"] = not no_gym
        generate_dataset(**options)
    elif command == "eval":
        use_gym = not cli.no_gym
        eval_videos(cli.video_dir, cli.table_dir, cli.output_json, cli.gt_json, use_gym)
    elif command == "verify":
        use_gym = not cli.no_gym
        verify_results(cli.json_file, cli.table_dir, use_gym)
    else:
        print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")