"""
FrozenLake Video Dataset Generator — generate, eval, verify.

Uses generate_auto() which picks random (small grids) or guided (large grids)
strategy automatically.

Usage:
    python frozenlake_video_gen.py generate --output-dir frozenlake \\
        --sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8
    python frozenlake_video_gen.py eval result_videos/ --table-dir frozenlake/tables
    python frozenlake_video_gen.py verify results.json --table-dir frozenlake/tables
"""

import json
import csv
import hashlib
import random
import re
import argparse
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Sequence

import cv2
import numpy as np
from tqdm import tqdm

from frozenlake_processor import FrozenLakeProcessor


# ==================== Checkpoint ====================

@dataclass
class GenerationState:
    """Resumable progress of a ``generate`` run, persisted in ``metadata.json``."""

    params_hash: str  # hash of the generation params this state belongs to
    size_progress: Dict[int, int]  # grid size -> puzzles generated so far
    seen_fingerprints: List[str]  # fingerprints of emitted puzzles (de-duplication)
    all_samples: List[Dict]  # sample records written so far
    completed: bool = False  # True once the split files have been written
    # JSON-encoded ``random.getstate()`` at the last checkpoint; None for fresh
    # runs and for checkpoints written before this field existed.
    rng_state: Optional[list] = None

    def to_dict(self) -> Dict:
        """Return all fields as a JSON-serialisable dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from :meth:`to_dict` output after a JSON round-trip."""
        data = dict(d)
        if "size_progress" in data:
            # JSON object keys are always strings; restore the int grid sizes.
            data["size_progress"] = {int(k): v for k, v in data["size_progress"].items()}
        return cls(**data)


def _params_hash(params: Dict) -> str:
    """Return a short stable hash of *params*, ignoring ``output_dir``."""
    key = {k: v for k, v in params.items() if k != "output_dir"}
    return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12]


def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load the saved state from ``output_dir/metadata.json``.

    Returns None (meaning "start fresh") when there is no checkpoint, when it
    cannot be parsed, or when it was written with different params.
    """
    meta = output_dir / "metadata.json"
    if not meta.exists():
        return None
    try:
        with open(meta, encoding="utf-8") as f:
            data = json.load(f)
        state = GenerationState.from_dict(data["state"])
    except (json.JSONDecodeError, KeyError, TypeError, ValueError) as err:
        print(f"⚠️ Unreadable checkpoint {meta} ({err}), starting fresh")
        return None
    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Params changed ({state.params_hash} → {expected}), starting fresh")
        return None
    if state.completed:
        print("✓ Already completed")
        return state
    print(f"✓ Resuming: {sum(state.size_progress.values())} done")
    return state


def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
    """Write *params* and *state* to ``output_dir/metadata.json`` via a temp file."""
    meta = output_dir / "metadata.json"
    tmp = meta.with_suffix(".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
    # Path.replace overwrites an existing target on all platforms; Path.rename
    # raises on Windows once metadata.json already exists.
    tmp.replace(meta)


def _rng_snapshot() -> list:
    """Return ``random.getstate()`` in a JSON-serialisable form."""
    version, internal, gauss_next = random.getstate()
    return [version, list(internal), gauss_next]


def _rng_restore(snapshot: list) -> None:
    """Restore the ``random`` module state produced by :func:`_rng_snapshot`."""
    version, internal, gauss_next = snapshot
    random.setstate((version, tuple(internal), gauss_next))


# ==================== Video I/O ====================

def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode RGB *frames* into an ``mp4v`` video at *path*.

    The frame size is taken from the first frame.

    Raises:
        ValueError: if *frames* is empty.
        OSError: if OpenCV cannot open a writer for *path*.
    """
    if len(frames) == 0:
        raise ValueError(f"save_video_cv2: no frames to write to {path}")
    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    if not writer.isOpened():
        raise OSError(f"save_video_cv2: cannot open video writer for {path}")
    try:
        for frame in frames:
            # OpenCV expects BGR channel order.
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the last decodable frame of *video_path* as RGB, or None on failure."""
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            return None
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total > 0:
            cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
            ret, frame = cap.read()
            if ret and frame is not None:
                return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # CAP_PROP_FRAME_COUNT can be inaccurate for some containers/codecs;
            # rewind and fall back to a sequential scan.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        last = None
        while True:
            ret, frame = cap.read()
            if not ret or frame is None:
                break
            last = frame
        if last is None:
            return None
        return cv2.cvtColor(last, cv2.COLOR_BGR2RGB)
    finally:
        cap.release()


def _normalise_list(val, sizes, name="parameter"):
    """Broadcast an int to one value per size, or check a sequence matches *sizes*.

    Raises:
        ValueError: if *val* is a sequence whose length differs from *sizes*.
    """
    if isinstance(val, int):
        return [val] * len(sizes)
    if len(val) != len(sizes):
        raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})")
    return list(val)


# ==================== Generate ====================

def _stratified_split(samples: List[Dict], train_ratio: float, seed: int):
    """Split *samples* into ``(train, test)`` so each grid size keeps *train_ratio*.

    Reseeds ``random`` with *seed* so the split is reproducible.
    """
    random.seed(seed)
    by_size: Dict[int, List[Dict]] = {}
    for s in samples:
        by_size.setdefault(s["grid_size"], []).append(s)
    train_samples, test_samples = [], []
    for sz in sorted(by_size):
        group = by_size[sz]
        random.shuffle(group)
        cut = int(len(group) * train_ratio)
        train_samples.extend(group[:cut])
        test_samples.extend(group[cut:])
    random.shuffle(train_samples)
    random.shuffle(test_samples)
    return train_samples, test_samples


def _write_split_files(out: Path, train_samples: List[Dict], test_samples: List[Dict]):
    """Write ``{train,test}.jsonl`` (full records) and ``{train,test}.csv`` files."""
    splits = (("train", train_samples), ("test", test_samples))
    for name, samples in splits:
        with open(out / f"{name}.jsonl", "w", encoding="utf-8") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")
    for name, samples in splits:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["input_image", "video", "prompt"])
            for s in samples:
                w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])


def generate_dataset(
    output_dir: str = "frozenlake",
    sizes: Sequence[int] = (8, 16, 32),
    num_per_size: Sequence[int] = (100, 500, 1000),
    p: float = 0.8,
    min_path_ratio: float = 0.1,
    img_size: int = 512,
    prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
    train_ratio: float = 0.9,
    n_start: int = 2,
    m_end: int = 3,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    use_gym: bool = True,
    checkpoint_interval: int = 50,
):
    """Generate unique puzzles with renders, videos, tables and train/test splits.

    For each puzzle ``<base>`` this writes ``images/<base>.png``,
    ``videos/<base>.mp4`` and ``tables/<base>.txt`` under *output_dir*. It then
    writes ``path.json`` (image name -> UDRL path), a per-size stratified
    ``train``/``test`` split as JSONL and CSV, and ``metadata.json``.
    ``metadata.json`` is also checkpointed every *checkpoint_interval* puzzles,
    so a rerun with identical params resumes instead of starting over.

    Args:
        output_dir: Root output directory (created if missing).
        sizes: Grid sizes to generate.
        num_per_size: Puzzles per size. An int, or a one-element sequence,
            applies to every size.
        p: Forwarded to ``generate_auto``.
        min_path_ratio: Minimum path length as a fraction of the grid's cells.
        img_size: Passed to ``FrozenLakeProcessor``.
        prompt: Text stored with every sample.
        train_ratio: Fraction of each size's puzzles placed in the train split.
        n_start, m_end, frames: Forwarded to ``generate_video_frames``.
        fps: Output video frame rate.
        seed: Seed for generation (the split uses ``seed + 1``).
        use_gym: Forwarded to the processor's render/video calls.
        checkpoint_interval: Puzzles between checkpoint saves.
    """
    params = {
        "sizes": sizes, "num_per_size": num_per_size, "p": p,
        "min_path_ratio": min_path_ratio, "img_size": img_size, "prompt": prompt,
        "train_ratio": train_ratio, "n_start": n_start, "m_end": m_end,
        "frames": frames, "fps": fps, "seed": seed, "use_gym": use_gym,
    }
    out = Path(output_dir)
    img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
    for d in (img_dir, vid_dir, tbl_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A single count (int or one-element sequence) applies to every size.
    if isinstance(num_per_size, int):
        counts = num_per_size
    elif len(num_per_size) == 1:
        counts = num_per_size[0]
    else:
        counts = num_per_size
    num_list = _normalise_list(counts, sizes, "num_per_size")
    num_w = len(str(max(num_list, default=0)))
    proc = FrozenLakeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[],
            all_samples=[],
        )
        print(f"Fresh generation: sizes={list(sizes)}, counts={num_list}, p={p}")
    elif state.rng_state is not None:
        # Exact continuation, assuming the processor draws from the module-level
        # `random` generator (the original skip-ahead assumes the same).
        _rng_restore(state.rng_state)
    else:
        # Older checkpoint without an RNG snapshot: approximate skip-ahead.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    def _sync_state() -> None:
        """Copy in-memory progress into ``state`` before it is saved."""
        state.size_progress = dict(progress)
        state.seen_fingerprints = list(seen)
        state.all_samples = all_samples
        state.rng_state = _rng_snapshot()

    with tqdm(total=sum(num_list), initial=sum(progress.values()),
              desc="Total", unit="puzzle") as pbar:
        for grid_size, target in zip(sizes, num_list):
            generated = progress.get(grid_size, 0)
            if generated >= target:
                continue
            min_len = max(1, int(grid_size * grid_size * min_path_ratio))
            with tqdm(total=target, initial=generated, desc=f"Size {grid_size:3d}",
                      unit="puzzle", leave=False) as pbar_sz:
                # Bounded attempt budget: generation can fail or produce duplicates.
                for _ in range((target - generated) * 5):
                    if generated >= target:
                        break
                    try:
                        desc, path = proc.generate_auto(grid_size, p=p, min_path_len=min_len)
                    except RuntimeError:
                        continue
                    fp = proc.fingerprint(desc)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{grid_size}_{generated:0{num_w}d}"
                    proc.render(desc, use_gym=use_gym).save(str(img_dir / f"{base}.png"))
                    vid_frames = proc.generate_video_frames(
                        desc, path, n_start=n_start, m_end=m_end,
                        frames=frames, use_gym=use_gym,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / f"{base}.mp4"), fps=fps)
                    proc.save_table(str(tbl_dir / f"{base}.txt"), desc)
                    udrl = proc.path_to_udrl(path)
                    all_samples.append({
                        "prompt": prompt,
                        "image": f"{base}.png",
                        "video": f"{base}.mp4",
                        "table": f"{base}.txt",
                        "grid_size": grid_size,
                        "grid_desc": desc,
                        "start": list(proc.find_start(desc)),
                        "path_udrl": udrl,
                        "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[grid_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)
                    if since_ckpt >= checkpoint_interval:
                        _sync_state()
                        save_checkpoint(out, state, params)
                        since_ckpt = 0
            if generated < target:
                tqdm.write(f"⚠️ Size {grid_size}: attempt budget exhausted at "
                           f"{generated}/{target} puzzles")
            tqdm.write(f"Size {grid_size}: {generated} puzzles")

    with open(out / "path.json", "w", encoding="utf-8") as f:
        json.dump(dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)), f, indent=4)

    # Stratified split: each size is proportionally represented in the test set.
    train_samples, test_samples = _stratified_split(all_samples, train_ratio, seed + 1)
    _write_split_files(out, train_samples, test_samples)

    _sync_state()
    state.completed = True
    save_checkpoint(out, state, params)

    lengths = [s["path_length"] for s in all_samples]
    print(f"\n✓ Complete: {out}/ | {len(all_samples)} puzzles "
          f"(train={len(train_samples)}, test={len(test_samples)})")
    if lengths:
        print(f" Paths: avg={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}")
# ==================== Eval ====================

def _natural_key(path: Path) -> list:
    """Sort key that orders digit runs in a file stem numerically (``a2`` < ``a10``)."""
    # re.split with a capturing group places the digit runs at the odd indices,
    # so only those are converted (str.isdigit accepts chars int() rejects).
    parts = re.split(r"(\d+)", path.stem)
    return [int(part) if i % 2 else part for i, part in enumerate(parts)]


def _table_path(tbl_root: Path, name: str) -> Path:
    """Map a result key such as ``size8_001.png`` to ``tbl_root/size8_001.txt``."""
    stem = name[: -len(".png")] if name.endswith(".png") else name
    return tbl_root / f"{stem}.txt"


def _print_gt_length_breakdown(gt_json: str, extracted: Dict[str, str],
                               verdicts: Dict[str, bool]) -> None:
    """Print accuracy bucketed by ground-truth path length (10-step buckets).

    Best-effort: an unreadable or malformed *gt_json* only prints a warning.
    Predictions without a verdict (their table could not be loaded) count as
    incorrect.
    """
    try:
        with open(gt_json, encoding="utf-8") as f:
            gt = json.load(f)
    except (OSError, ValueError) as err:
        print(f"⚠️ Skipping GT breakdown, cannot read {gt_json}: {err}")
        return
    if not isinstance(gt, dict):
        print(f"⚠️ Skipping GT breakdown, {gt_json} is not a JSON object")
        return
    bins: Dict[int, Dict[str, int]] = {}
    for name in extracted:
        if name not in gt:
            continue
        lo = (len(gt[name]) // 10) * 10
        b = bins.setdefault(lo, {"total": 0, "correct": 0})
        b["total"] += 1
        if verdicts.get(name, False):
            b["correct"] += 1
    if not bins:
        return
    print("\nBy GT path length:")
    for lo in sorted(bins):  # numeric order, so 1000+ sorts after 990
        b = bins[lo]
        print(f" {lo:3d}-{lo+9:3d}: {b['correct']}/{b['total']} "
              f"({b['correct']/b['total']*100:.1f}%)")


def eval_videos(
    video_dir: str,
    table_dir: str,
    output_json: Optional[str] = None,
    gt_json: Optional[str] = None,
    use_gym: bool = True,
):
    """Extract predicted paths from result videos and score them against the tables.

    For each ``*.mp4`` in *video_dir*, the last frame is decoded and converted to a
    path with ``extract_path_from_pixels``. Predictions are written to
    *output_json* (default ``<video_dir>/0_result.json``), keyed by
    ``<stem>.png``. Each prediction is then checked with ``verify_path_gym``, or
    ``verify_path_sim`` when *use_gym* is False.

    Accuracy is printed overall and per grid size (the table's row count). If
    *gt_json* is given and maps the same keys to ground-truth paths, accuracy is
    also printed per 10-step bucket of ground-truth path length.
    """
    proc = FrozenLakeProcessor()
    vid_root, tbl_root = Path(video_dir), Path(table_dir)
    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    videos = sorted(vid_root.glob("*.mp4"), key=_natural_key)
    if not videos:
        print(f"No .mp4 in {vid_root}")
        return

    extracted: Dict[str, str] = {}
    missing_tbl = missing_frame = 0
    for vp in tqdm(videos, desc="Extracting"):
        desc = proc.load_table(str(tbl_root / f"{vp.stem}.txt"))
        if desc is None:
            missing_tbl += 1
            continue
        start = proc.find_start(desc)
        if start is None:  # a table without a start cell counts as missing
            missing_tbl += 1
            continue
        last_frame = extract_last_frame(str(vp))
        if last_frame is None:
            missing_frame += 1
            continue
        extracted[f"{vp.stem}.png"] = proc.extract_path_from_pixels(
            last_frame, len(desc), len(desc[0]), start, desc)

    with open(output_json, "w", encoding="utf-8") as f:
        json.dump(extracted, f, indent=4)

    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    # Verify each prediction once; the GT breakdown reuses these verdicts.
    verdicts: Dict[str, bool] = {}
    size_stats: Dict[int, Dict[str, int]] = {}
    top: List[Dict] = []
    for name, udrl in extracted.items():
        desc = proc.load_table(str(_table_path(tbl_root, name)))
        if desc is None:
            continue
        ok = bool(verify_fn(desc, udrl))
        verdicts[name] = ok
        stats = size_stats.setdefault(len(desc), {"total": 0, "correct": 0})
        stats["total"] += 1
        if ok:
            stats["correct"] += 1
            top.append({"name": name, "length": len(udrl)})
    total = sum(s["total"] for s in size_stats.values())
    correct = sum(s["correct"] for s in size_stats.values())

    acc = correct / total * 100 if total else 0
    print(f"\n{'='*50}\nEval: {correct}/{total} ({acc:.2f}%) | "
          f"missing_tbl={missing_tbl} bad_frame={missing_frame}")
    for sz in sorted(size_stats):
        s = size_stats[sz]
        print(f" Size {sz:3d}: {s['correct']}/{s['total']} "
              f"({s['correct']/s['total']*100:.1f}%)")
    # Longest correctly solved puzzles.
    top.sort(key=lambda x: x["length"], reverse=True)
    for i, item in enumerate(top[:3]):
        print(f" Top {i+1}: {item['name']} (len={item['length']})")

    if gt_json:
        _print_gt_length_breakdown(gt_json, extracted, verdicts)
    print(f"{'='*50}")


def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
    """Score a ``{"<stem>.png": path}`` JSON file against the tables in *table_dir*.

    Entries whose table cannot be loaded are skipped and reported separately.
    """
    proc = FrozenLakeProcessor()
    with open(json_file, encoding="utf-8") as f:
        solutions = json.load(f)
    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    tbl_root = Path(table_dir)
    correct = skipped = valid = 0
    for name, udrl in solutions.items():
        desc = proc.load_table(str(_table_path(tbl_root, name)))
        if desc is None:
            skipped += 1
            continue
        valid += 1
        if verify_fn(desc, udrl):
            correct += 1
    acc = correct / valid * 100 if valid else 0
    print(f"\nVerification: {correct}/{valid} ({acc:.2f}%)")
    if skipped:
        print(f" Skipped (no table): {skipped}")


# ==================== CLI ====================

def parse_args():
    """Build the ``generate`` / ``eval`` / ``verify`` CLI and parse ``sys.argv``."""
    p = argparse.ArgumentParser(description="FrozenLake video dataset")
    sub = p.add_subparsers(dest="command")

    gen = sub.add_parser("generate")
    gen.add_argument("--output-dir", default="frozenlake")
    gen.add_argument("--sizes", type=int, nargs="+", default=[8, 12, 16, 32])
    gen.add_argument("--num-per-size", type=int, nargs="+", default=[1000, 2000, 5000, 10000])
    gen.add_argument("--p", type=float, default=0.5)
    gen.add_argument("--min-path-ratio", type=float, default=0.1)
    gen.add_argument("--img-size", type=int, default=1024)
    gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
    gen.add_argument("--train-ratio", type=float, default=0.9)
    gen.add_argument("--n-start", type=int, default=2)
    gen.add_argument("--m-end", type=int, default=3)
    gen.add_argument("--frames", type=int, default=None)
    gen.add_argument("--fps", type=int, default=10)
    gen.add_argument("--seed", type=int, default=42)
    gen.add_argument("--no-gym", action="store_true")
    gen.add_argument("--checkpoint-interval", type=int, default=50)

    ev = sub.add_parser("eval")
    ev.add_argument("video_dir")
    ev.add_argument("--table-dir", required=True)
    ev.add_argument("--output-json", default=None)
    ev.add_argument("--gt-json", default=None)
    ev.add_argument("--no-gym", action="store_true")

    ver = sub.add_parser("verify")
    ver.add_argument("json_file")
    ver.add_argument("--table-dir", required=True)
    ver.add_argument("--no-gym", action="store_true")
    return p.parse_args()


if __name__ == "__main__":
    args = parse_args()
    if args.command == "generate":
        # CLI dest names match generate_dataset's keyword parameters.
        kw = {k: v for k, v in vars(args).items() if k not in ("command", "no_gym")}
        kw["use_gym"] = not args.no_gym
        generate_dataset(**kw)
    elif args.command == "eval":
        eval_videos(args.video_dir, args.table_dir, args.output_json,
                    args.gt_json, not args.no_gym)
    elif args.command == "verify":
        verify_results(args.json_file, args.table_dir, not args.no_gym)
    else:
        print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")