| | """ |
| | Maze Video Dataset Generator — generates maze puzzle images and solution videos |
| | with checkpoint/resume support, train/test splitting, and JSONL metadata. |
| | |
| | Includes an ``eval`` subcommand that takes a directory of result videos, |
| | extracts the last frame from each, parses the red path, and verifies it |
| | against the ground-truth maze text files. |
| | |
| | Usage: |
| | # Generate |
| | python maze_video_gen.py generate --output-dir maze --sizes 8 16 32 \ |
| | --num-per-size 100 500 1000 --min-path-ratio 0.3 \ |
| | --n-start 5 --m-end 5 --frames 50 --fps 10 --seed 42 |
| | |
| | # Evaluate result videos |
| | python maze_video_gen.py eval result_videos/ --text-dir maze/texts |
| | |
| | # Evaluate with backtracking path extraction |
| | python maze_video_gen.py eval result_videos/ --text-dir maze/texts --recursive |
| | |
| | # Verify a pre-extracted JSON |
| | python maze_video_gen.py verify results.json --text-dir maze/texts |
| | """ |
| | import json |
| | import csv |
| | import hashlib |
| | import random |
| | import re |
| | import argparse |
| | from dataclasses import dataclass, asdict |
| | from pathlib import Path |
| | from typing import Dict, List, Optional |
| |
|
| | import cv2 |
| | import numpy as np |
| | from tqdm import tqdm |
| |
|
| | from maze_processor import MazeProcessor |
| |
|
| |
|
| | |
| |
|
@dataclass
class GenerationState:
    """Checkpointable snapshot of dataset-generation progress.

    Serialised into ``metadata.json`` so an interrupted run can resume.
    """
    params_hash: str                 # digest of the generation parameters
    size_progress: Dict[int, int]    # maze size -> count generated so far
    seen_fingerprints: List[str]     # dedup fingerprints of emitted mazes
    all_samples: List[Dict]          # metadata record for every sample
    completed: bool = False          # True once the full dataset was written

    def to_dict(self) -> Dict:
        """Serialise to a plain (JSON-compatible) dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state object from a checkpoint dict."""
        return cls(**d)
| |
|
| |
|
| | def _params_hash(params: Dict) -> str: |
| | """Deterministic hash of generation parameters (excluding output_dir).""" |
| | key = {k: v for k, v in params.items() if k != "output_dir"} |
| | return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12] |
| |
|
| |
|
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Return the saved GenerationState if present and compatible, else None.

    A checkpoint is only honoured when its parameter hash matches the
    current run's parameters; otherwise generation must start over.
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None

    with open(meta_path) as fh:
        payload = json.load(fh)
    state = GenerationState.from_dict(payload["state"])

    current_hash = _params_hash(params)
    if state.params_hash != current_hash:
        # Any parameter change invalidates the checkpoint entirely.
        print(f"⚠️ Parameters changed ({state.params_hash} → {current_hash}), starting fresh")
        return None

    if state.completed:
        print("✓ Generation already completed")
    else:
        done = sum(state.size_progress.values())
        print(f"✓ Resuming from checkpoint: {done} mazes generated")
    return state
| |
|
| |
|
def save_checkpoint(output_dir: Path, state: "GenerationState", params: Dict):
    """Atomically persist *state* and *params* to ``output_dir/metadata.json``.

    Writes to a temporary sibling file first and then swaps it into place,
    so a crash mid-write never leaves a truncated checkpoint behind.

    Args:
        output_dir: Dataset root directory (must already exist).
        state: Current generation state; serialised via ``to_dict()``.
        params: Generation parameters, stored alongside the state.
    """
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace (unlike .rename) atomically overwrites an existing
    # destination on every platform; .rename raises FileExistsError on
    # Windows once metadata.json already exists.
    tmp.replace(meta)
| |
|
| |
|
| | |
| |
|
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode a list of RGB images (PIL Images or arrays) as an mp4 file.

    Args:
        frames: Non-empty list of same-sized RGB images.
        path: Output .mp4 file path.
        fps: Frames per second of the output video.

    Raises:
        ValueError: If *frames* is empty.
        IOError: If the mp4v video writer cannot be opened for *path*.
    """
    if not frames:
        raise ValueError("save_video_cv2 requires at least one frame")
    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(
        str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h)
    )
    if not writer.isOpened():
        # Fail loudly instead of silently emitting a broken/empty file.
        raise IOError(f"Could not open video writer for {path}")
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Guarantee the container is finalised even if a write raises.
        writer.release()
| |
|
| |
|
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """
    Extract the last frame from a video file as an RGB numpy array.

    Tries a direct seek to the final frame first; when that fails — or when
    the container does not report a frame count at all (the original code
    returned the FIRST frame in that case) — falls back to scanning the
    whole video and keeping the last decodable frame.

    Returns:
        (H, W, 3) uint8 RGB array, or None on failure.
    """
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None

    frame = None
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total > 0:
        # Fast path: jump straight to the reported last frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
        ret, frame = cap.read()
        if not ret:
            # Reported frame counts are unreliable for some codecs.
            frame = None

    if frame is None:
        # Slow path: sequential scan, remembering the last readable frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        while True:
            ret, candidate = cap.read()
            if not ret or candidate is None:
                break
            frame = candidate

    cap.release()
    if frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
| |
|
| |
|
| | |
| |
|
| | def _normalise_list(val, sizes, name="parameter"): |
| | """Broadcast a single int to a list, or validate list length.""" |
| | if isinstance(val, int): |
| | return [val] * len(sizes) |
| | if len(val) != len(sizes): |
| | raise ValueError(f"{name} length ({len(val)}) != sizes length ({len(sizes)})") |
| | return list(val) |
| |
|
| |
|
| | |
| |
|
def generate_dataset(
    output_dir: str = "maze",
    # NOTE(review): mutable default lists are shared across calls; safe here
    # only because they are never mutated — consider None sentinels.
    sizes: List[int] = [8, 16, 32],
    num_per_size: list = [100, 500, 1000],
    min_path_ratio: float = 0.3,
    img_size: int = 1024,
    prompt: str = "Draw a continuous red line from the yellow dot to the blue dot, avoiding all walls.",
    train_ratio: float = 0.9,
    n_start: int = 5,
    m_end: int = 5,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    checkpoint_interval: int = 50,
):
    """
    Generate maze video dataset with checkpoint/resume support.

    The *frames* parameter controls content frames per video:
    - None → one content frame per path step (variable length)
    - N > 0 → exactly N content frames (slow-mo / fast-fwd as needed)

    Directory layout::

        output_dir/
        images/ — puzzle PNG (no solution line)
        videos/ — solution MP4 (progressive red line)
        texts/ — maze text files (bitmask format)
        train.jsonl / test.jsonl
        train.csv / test.csv
        path.json — UDRL answer key
        metadata.json — checkpoint state
    """
    # Everything that influences the output goes into the checkpoint hash
    # (output_dir itself is excluded by _params_hash).
    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed,
    }

    out = Path(output_dir)
    img_dir = out / "images"
    vid_dir = out / "videos"
    txt_dir = out / "texts"
    for d in (img_dir, vid_dir, txt_dir):
        d.mkdir(parents=True, exist_ok=True)

    # Resume if a compatible checkpoint exists; nothing to do if finished.
    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A single count broadcasts to all sizes; otherwise lengths must match.
    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
    max_puzzles = max(num_list)
    num_w = len(str(max_puzzles))  # zero-pad width for stable filenames
    proc = MazeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[],
            all_samples=[],
        )
        print(f"Starting fresh generation: sizes={sizes}, counts={num_list}")
        print(f" frames={'auto (1 per step)' if frames is None else frames}, "
              f"n_start={n_start}, m_end={m_end}, fps={fps}")
    else:
        # Best-effort RNG fast-forward on resume: burn a heuristic 10 draws
        # per already-generated maze. NOTE(review): this is NOT an exact
        # replay of the consumed random state — confirm acceptable.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # JSON round-trips dict keys as strings; restore int maze sizes.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0  # samples generated since the last checkpoint flush

    total_target = sum(num_list)
    total_done = sum(progress.values())

    with tqdm(total=total_target, initial=total_done, desc="Total", unit="maze") as pbar:
        for maze_size, target in zip(sizes, num_list):
            generated = progress.get(maze_size, 0)
            if generated >= target:
                continue

            # Minimum solution length scales with maze area.
            min_len = max(1, int(maze_size * maze_size * min_path_ratio))
            # Cap attempts so duplicate-heavy settings cannot loop forever.
            max_attempts = (target - generated) * 20

            with tqdm(
                total=target, initial=generated, desc=f"Size {maze_size:3d}",
                unit="maze", leave=False,
            ) as pbar_sz:
                for _ in range(max_attempts):
                    if generated >= target:
                        break

                    try:
                        grid, start, end, path = proc.generate(
                            maze_size, min_path_len=min_len
                        )
                    except RuntimeError:
                        # Generator could not satisfy min_path_len; retry.
                        continue

                    # Skip mazes identical to one already emitted.
                    fp = proc.fingerprint(grid, start, end)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    idx = generated
                    base = f"size{maze_size}_{idx:0{num_w}d}"
                    img_name = f"{base}.png"
                    vid_name = f"{base}.mp4"
                    txt_name = f"{base}.txt"

                    # Static puzzle image (no solution drawn).
                    puzzle_img = proc.render(grid, start, end)
                    puzzle_img.save(str(img_dir / img_name))

                    # Solution video: hold frames, then progressive red line.
                    vid_frames = proc.generate_video_frames(
                        grid, start, end, path,
                        n_start=n_start, m_end=m_end, frames=frames,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / vid_name), fps=fps)

                    # Ground-truth maze description (bitmask text format).
                    proc.save_text(str(txt_dir / txt_name), grid, start, end)

                    udrl = proc.path_to_udrl(path)

                    all_samples.append({
                        "prompt": prompt,
                        "image": img_name,
                        "video": vid_name,
                        "text": txt_name,
                        "maze_size": maze_size,
                        "start": list(start),
                        "end": list(end),
                        "path_udrl": udrl,
                        "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[maze_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    # Periodic checkpoint so a crash loses little work.
                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            tqdm.write(
                f"Size {maze_size}: {generated} mazes, "
                f"{sum(1 for s in all_samples if s['maze_size'] == maze_size)} samples"
            )

    # Answer key: puzzle image name -> UDRL solution string.
    path_answers = {s["image"]: s["path_udrl"] for s in all_samples}
    with open(out / "path.json", "w") as f:
        json.dump(dict(sorted(path_answers.items())), f, indent=4)

    # Stratified train/test split: shuffle within each size, then split so
    # every size is represented proportionally in both partitions.
    random.seed(seed + 1)
    by_size: Dict[int, List[Dict]] = {}
    for s in all_samples:
        by_size.setdefault(s["maze_size"], []).append(s)

    train_samples, test_samples = [], []
    for sz in sorted(by_size):
        group = by_size[sz]
        random.shuffle(group)
        sz_split = int(len(group) * train_ratio)
        train_samples.extend(group[:sz_split])
        test_samples.extend(group[sz_split:])

    # Shuffle across sizes so files are not grouped by size.
    random.shuffle(train_samples)
    random.shuffle(test_samples)
    split = len(train_samples)

    def _write_jsonl(samples, path):
        # One JSON object per line (standard JSONL).
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _write_jsonl(train_samples, out / "train.jsonl")
    _write_jsonl(test_samples, out / "test.jsonl")

    # CSV mirror of the split for tools that prefer tabular input.
    for name, samples in [("train", train_samples), ("test", test_samples)]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["input_image", "video", "prompt"])
            for s in samples:
                writer.writerow([
                    f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]
                ])

    # Final checkpoint marks the run complete.
    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    print(f"\n✓ Dataset complete: {out}/")
    print(f" Sizes: {sizes}")
    print(f" Mazes: {len(all_samples)}")
    print(f" Train: {split}, Test: {len(all_samples) - split}")
    lengths = [s["path_length"] for s in all_samples]
    fcounts = [s["frame_count"] for s in all_samples]
    print(f" Path lengths: avg={np.mean(lengths):.1f}, "
          f"min={min(lengths)}, max={max(lengths)}")
    print(f" Frame counts: avg={np.mean(fcounts):.1f}, "
          f"min={min(fcounts)}, max={max(fcounts)}")
| |
|
| |
|
| | |
| |
|
def eval_videos(
    video_dir: str,
    text_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
    strict: bool = True,
    recursive: bool = False,
):
    """
    Evaluate a directory of result videos against ground-truth mazes.

    Pipeline per video:
    1. Extract last frame from .mp4
    2. Detect red path via pixel analysis
    3. Convert to UDRL action string
    4. Verify against maze .txt (wall-respecting walk from start to end)

    Matching convention:
    Video ``<stem>.mp4`` → Text ``<stem>.txt`` in *text_dir*
    (a ``_gen`` suffix on the video stem is stripped first).

    Args:
        video_dir: Directory containing result .mp4 files.
        text_dir: Directory containing ground-truth maze .txt files.
        output_json: Path to save extracted paths as JSON.
        gt_json: Optional ground-truth answer JSON for accuracy by path length.
        strict: Strict verification mode.
        recursive: Use backtracking DFS for red-path extraction instead of greedy.
    """
    proc = MazeProcessor()
    vid_root = Path(video_dir)
    txt_root = Path(text_dir)

    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Natural sort so "size8_2" sorts before "size8_10".
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )

    if not videos:
        print(f"No .mp4 files found in {vid_root}")
        return

    print(f"Found {len(videos)} result videos in {vid_root}")
    print(f"Text dir: {txt_root}")
    print(f"Mode: {'recursive (backtracking)' if recursive else 'greedy'}, "
          f"strict={'yes' if strict else 'no'}")

    extracted: Dict[str, str] = {}
    missing_txt = 0
    missing_frame = 0

    for vpath in tqdm(videos, desc="Extracting paths"):
        # Result videos may carry a "_gen" suffix; ground truth does not.
        stem = vpath.stem.replace('_gen', '')
        txt_path = txt_root / f"{stem}.txt"

        if not txt_path.exists():
            missing_txt += 1
            continue

        maze = proc.load_text(str(txt_path))
        if maze is None:
            missing_txt += 1
            continue

        last_frame = extract_last_frame(str(vpath))
        if last_frame is None:
            missing_frame += 1
            continue

        udrl = proc.extract_path_from_pixels(
            last_frame,
            grid_raw=maze["grid_raw"],
            size=maze["size"],
            start=maze["start"],
            recursive=recursive,
            end=maze["end"],
            strict=strict,
        )
        extracted[f"{stem}.png"] = udrl

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(extracted, f, indent=4)
    print(f"\nExtracted paths saved to: {output_json}")

    # Verify each prediction exactly once and derive every statistic below
    # from these records (previously each maze .txt was re-parsed and
    # re-verified up to three times).
    records: List[Dict] = []
    for name, udrl in extracted.items():
        stem = name.replace(".png", "")
        maze = proc.load_text(str(txt_root / f"{stem}.txt"))
        if maze is None:
            continue
        ok = proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl, strict=strict)
        records.append({"name": name, "size": maze["size"], "ok": ok, "length": len(udrl)})

    total_valid = len(records)
    correct = sum(1 for r in records if r["ok"])
    correctly_solved: List[Dict] = [
        {"name": r["name"], "length": r["length"]} for r in records if r["ok"]
    ]

    acc = (correct / total_valid * 100) if total_valid else 0

    print(f"\n{'=' * 50}")
    print("Evaluation Summary")
    print(f"{'=' * 50}")
    print(f"Total Videos : {len(videos)}")
    print(f"Missing .txt : {missing_txt}")
    print(f"Failed Frame Read : {missing_frame}")
    print(f"Evaluated : {total_valid}")
    print(f"Correctly Solved : {correct}")
    print(f"Accuracy : {acc:.2f}%")
    print(f"Extraction Mode : {'recursive' if recursive else 'greedy'}")
    print(f"{'-' * 50}")

    # Per-maze-size accuracy breakdown.
    size_stats: Dict[int, Dict[str, int]] = {}
    for r in records:
        stats = size_stats.setdefault(r["size"], {"total": 0, "correct": 0})
        stats["total"] += 1
        if r["ok"]:
            stats["correct"] += 1

    if size_stats:
        print("\nAccuracy by maze size:")
        for sz in sorted(size_stats):
            s = size_stats[sz]
            sz_acc = s["correct"] / s["total"] * 100 if s["total"] else 0
            print(f" Size {sz:3d}: {s['correct']:4d}/{s['total']:4d} ({sz_acc:.2f}%)")

    # Highlight the longest mazes that were still solved correctly.
    correctly_solved.sort(key=lambda x: x["length"], reverse=True)
    if correctly_solved:
        print(f"\nTop 3 Longest Correct Paths:")
        for i, item in enumerate(correctly_solved[:3]):
            print(f" {i+1}. {item['name']} (length: {item['length']})")

    if gt_json:
        _compare_with_gt(extracted, gt_json, txt_root, proc, strict=strict)

    print(f"{'=' * 50}")
|
| |
|
def _compare_with_gt(
    extracted: Dict[str, str],
    gt_json_path: str,
    txt_root: Path,
    proc: MazeProcessor,
    strict: bool = True,
):
    """Print accuracy binned by ground-truth path length (bins of width 10)."""
    try:
        with open(gt_json_path) as fh:
            gt = json.load(fh)
    except Exception:
        # Best-effort feature: a missing/corrupt GT file only warns.
        print(f" Warning: could not load ground-truth JSON: {gt_json_path}")
        return

    bins: Dict[str, Dict[str, int]] = {}
    for name, pred_udrl in extracted.items():
        if name not in gt:
            continue
        gt_len = len(gt[name])

        # Bucket by ground-truth length: [0-9], [10-19], ...
        lo = gt_len - gt_len % 10
        label = f"{lo:3d}-{lo + 9:3d}"
        bucket = bins.setdefault(label, {"total": 0, "correct": 0})
        bucket["total"] += 1

        maze = proc.load_text(str(txt_root / (name.replace(".png", "") + ".txt")))
        if maze and proc.verify_path(maze["grid"], maze["start"], maze["end"], pred_udrl, strict=strict):
            bucket["correct"] += 1

    if not bins:
        return
    print("\nAccuracy by GT path length:")
    for label in sorted(bins):
        stats = bins[label]
        pct = stats["correct"] / stats["total"] * 100 if stats["total"] else 0
        print(f" Length {label}: {stats['correct']:4d}/{stats['total']:4d} ({pct:.2f}%)")
| |
|
| |
|
| | |
| |
|
def verify_results(json_file: str, text_dir: str, strict: bool = True):
    """
    Check pre-extracted UDRL path strings against their ground-truth mazes.

    Args:
        json_file: Path to JSON with {name: udrl_string} predictions.
        text_dir: Directory containing maze .txt files.
        strict: Strict verification mode.
    """
    proc = MazeProcessor()
    txt_root = Path(text_dir)

    with open(Path(json_file)) as fh:
        solutions = json.load(fh)

    correct, skipped, valid = 0, 0, 0

    for name, udrl in solutions.items():
        maze = proc.load_text(str(txt_root / (name.replace(".png", "") + ".txt")))
        if maze is None:
            # No ground truth for this prediction; excluded from accuracy.
            skipped += 1
            continue
        valid += 1
        if proc.verify_path(maze["grid"], maze["start"], maze["end"], udrl, strict=strict):
            correct += 1

    acc = (correct / valid * 100) if valid else 0
    print(f"\n{'='*40}")
    print(f"Verification: {correct}/{valid} correct ({acc:.2f}%)")
    if skipped:
        print(f"Skipped: {skipped}")
    print(f"{'='*40}")
| |
|
| |
|
| | |
| |
|
def parse_args():
    """Build the three-subcommand CLI (generate / eval / verify) and parse argv."""
    parser = argparse.ArgumentParser(
        description="Maze video dataset: generate, eval, verify"
    )
    sub = parser.add_subparsers(dest="command", help="Sub-command")

    # --- generate ---------------------------------------------------------
    gen = sub.add_parser("generate", help="Generate dataset")
    gen.add_argument("--output-dir", type=str, default="maze")
    gen.add_argument("--sizes", type=int, nargs="+", default=[8, 12, 16, 32])
    gen.add_argument("--num-per-size", type=int, nargs="+",
                     default=[1000, 1000, 1000, 2000])
    gen.add_argument("--min-path-ratio", type=float, default=0.1,
                     help="Min path length as fraction of size²")
    gen.add_argument("--img-size", type=int, default=1024)
    gen.add_argument("--prompt", type=str,
                     default="Draw a continuous red line from the yellow dot "
                             "to the blue dot, avoiding all walls.")
    gen.add_argument("--train-ratio", type=float, default=0.9)
    gen.add_argument("--n-start", type=int, default=2,
                     help="Hold frames at video start (blank puzzle)")
    gen.add_argument("--m-end", type=int, default=3,
                     help="Hold frames at video end (completed solution)")
    gen.add_argument("--frames", type=int, default=None,
                     help="Content frames per video (None=auto 1 per step)")
    gen.add_argument("--fps", type=int, default=10)
    gen.add_argument("--seed", type=int, default=42)
    gen.add_argument("--checkpoint-interval", type=int, default=50)

    # --- eval -------------------------------------------------------------
    ev = sub.add_parser(
        "eval", help="Evaluate result videos (last frame → extract → verify)")
    ev.add_argument("video_dir", type=str,
                    help="Directory containing result .mp4 files")
    ev.add_argument("--text-dir", type=str, required=True,
                    help="Directory with ground-truth maze .txt files")
    ev.add_argument("--output-json", type=str, default=None,
                    help="Output JSON for extracted paths (default: video_dir/0_result.json)")
    ev.add_argument("--gt-json", type=str, default=None,
                    help="Optional ground-truth path.json for length-binned accuracy")
    ev.add_argument("--strict", action="store_true",
                    help="Strict verification (exact UDRL match) vs leniency on no-op moves")
    ev.add_argument("--recursive", action="store_true",
                    help="Use backtracking DFS for path extraction instead of greedy")

    # --- verify -----------------------------------------------------------
    ver = sub.add_parser("verify",
                         help="Verify a pre-extracted JSON of UDRL paths")
    ver.add_argument("json_file", type=str)
    ver.add_argument("--text-dir", type=str, required=True,
                     help="Directory with maze .txt files")
    ver.add_argument("--strict", action="store_true",
                     help="Strict verification (exact UDRL match) vs leniency on no-op moves")
    return parser.parse_args()
| |
|
| |
|
| | if __name__ == "__main__": |
| | args = parse_args() |
| |
|
| | if args.command == "generate": |
| | kwargs = {k: v for k, v in vars(args).items() if k != "command"} |
| | generate_dataset(**kwargs) |
| |
|
| | elif args.command == "eval": |
| | eval_videos( |
| | video_dir=args.video_dir, |
| | text_dir=args.text_dir, |
| | output_json=args.output_json, |
| | gt_json=args.gt_json, |
| | strict=args.strict, |
| | recursive=args.recursive, |
| | ) |
| |
|
| | elif args.command == "verify": |
| | verify_results(args.json_file, args.text_dir, strict=args.strict) |
| |
|
| | else: |
| | print("Usage: python maze_video_gen.py {generate|eval|verify} [options]") |
| | print(" python maze_video_gen.py generate --help") |
| | print(" python maze_video_gen.py eval --help") |
| | print(" python maze_video_gen.py verify --help") |