| | """ |
| | FrozenLake Video Dataset Generator — generate, eval, verify. |
| | |
| | Uses generate_auto() which picks random (small grids) or guided (large grids) |
| | strategy automatically. |
| | |
| | Usage: |
| | python frozenlake_video_gen.py generate --output-dir frozenlake \ |
| | --sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8 |
| | python frozenlake_video_gen.py eval result_videos/ --table-dir frozenlake/tables |
| | python frozenlake_video_gen.py verify results.json --table-dir frozenlake/tables |
| | """ |
| | import json |
| | import csv |
| | import hashlib |
| | import random |
| | import re |
| | import argparse |
| | from dataclasses import dataclass, asdict |
| | from pathlib import Path |
| | from typing import Dict, List, Optional |
| |
|
| | import cv2 |
| | import numpy as np |
| | from tqdm import tqdm |
| |
|
| | from frozenlake_processor import FrozenLakeProcessor |
| |
|
| |
|
| | |
| |
|
@dataclass
class GenerationState:
    """Checkpointable snapshot of a dataset-generation run.

    Persisted inside ``metadata.json`` by :func:`save_checkpoint` and restored
    by :func:`load_checkpoint` so an interrupted generation run can resume.
    """

    # Short hash of the generation parameters (see _params_hash); a mismatch
    # on resume means the parameters changed and the checkpoint is discarded.
    params_hash: str
    # Grid size -> number of puzzles generated so far for that size.
    size_progress: Dict[int, int]
    # Fingerprints of every emitted grid, used to skip duplicate puzzles.
    seen_fingerprints: List[str]
    # Per-sample metadata records accumulated so far.
    all_samples: List[Dict]
    # True once the whole dataset (including train/test splits) was written.
    completed: bool = False

    def to_dict(self) -> Dict:
        """Serialise to a plain dict suitable for ``json.dump``."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: Dict) -> "GenerationState":
        """Rebuild a state from a checkpoint dict (inverse of ``to_dict``).

        NOTE(review): JSON round-trips turn ``size_progress`` keys into
        strings; callers re-int them (see generate_dataset) — confirm before
        relying on the key type here.
        """
        return cls(**d)
| |
|
| |
|
| | def _params_hash(params: Dict) -> str: |
| | key = {k: v for k, v in params.items() if k != "output_dir"} |
| | return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12] |
| |
|
| |
|
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
    """Load a resumable generation state from ``output_dir/metadata.json``.

    Returns ``None`` when no checkpoint exists or when the stored parameter
    hash no longer matches *params* (the run must start fresh).
    """
    meta_path = output_dir / "metadata.json"
    if not meta_path.exists():
        return None
    payload = json.loads(meta_path.read_text())
    state = GenerationState.from_dict(payload["state"])
    expected = _params_hash(params)
    if state.params_hash != expected:
        print(f"⚠️ Params changed ({state.params_hash} → {expected}), starting fresh")
        return None
    if state.completed:
        print("✓ Already completed")
    else:
        print(f"✓ Resuming: {sum(state.size_progress.values())} done")
    return state
| |
|
| |
|
| | def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict): |
| | meta = output_dir / "metadata.json" |
| | tmp = meta.with_suffix(".tmp") |
| | with open(tmp, "w") as f: |
| | json.dump({"params": params, "state": state.to_dict()}, f, indent=2) |
| | tmp.rename(meta) |
| |
|
| |
|
| | |
| |
|
def save_video_cv2(frames: list, path: str, fps: int = 10):
    """Encode a list of RGB frames into an mp4 file via OpenCV.

    Args:
        frames: Non-empty sequence of RGB images (PIL images or arrays);
            the first frame fixes the output resolution.
        path: Destination .mp4 path.
        fps: Output frame rate.
    """
    first = np.array(frames[0])
    h, w = first.shape[:2]
    writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    try:
        for frame in frames:
            # OpenCV expects BGR ordering on write.
            writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    finally:
        # Release even if a frame conversion/write raises, so the container
        # is finalized and the file handle is not leaked.
        writer.release()
| |
|
| |
|
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
    """Return the final frame of a video as an RGB array, or None on failure."""
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        return None
    frame_total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total > 0:
        # Seek straight to the last frame instead of decoding the whole video.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_total - 1)
    ok, bgr = cap.read()
    cap.release()
    if not ok or bgr is None:
        return None
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
| |
|
| |
|
| | def _normalise_list(val, sizes, name="parameter"): |
| | if isinstance(val, int): |
| | return [val] * len(sizes) |
| | if len(val) != len(sizes): |
| | raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})") |
| | return list(val) |
| |
|
| |
|
| | |
| |
|
def generate_dataset(
    output_dir: str = "frozenlake",
    sizes: List[int] = [8, 16, 32],
    num_per_size: list = [100, 500, 1000],
    p: float = 0.8,
    min_path_ratio: float = 0.1,
    img_size: int = 512,
    prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
    train_ratio: float = 0.9,
    n_start: int = 2,
    m_end: int = 3,
    frames: Optional[int] = None,
    fps: int = 10,
    seed: int = 42,
    use_gym: bool = True,
    checkpoint_interval: int = 50,
):
    """Generate a FrozenLake image/video dataset with resumable checkpoints.

    For each grid size, unique solvable grids are generated via
    ``FrozenLakeProcessor.generate_auto``, then rendered to a start image, a
    solution video, and a text table. Samples are split per-size into
    train/test JSONL and CSV files under *output_dir*.

    Args:
        output_dir: Root directory for images/, videos/, tables/ + metadata.
        sizes: Grid sizes to generate (list defaults are never mutated here).
        num_per_size: Target count per size; a single value applies to all.
        p: Tile-frozen probability forwarded to ``generate_auto``.
        min_path_ratio: Minimum path length as a fraction of grid cells.
        img_size: Rendered image resolution in pixels.
        prompt: Text prompt stored with every sample.
        train_ratio: Per-size fraction of samples assigned to the train split.
        n_start: Frames held on the start position of each video.
        m_end: Frames held on the goal position of each video.
        frames: Optional fixed frame count for videos.
        fps: Video frame rate.
        seed: RNG seed for generation; ``seed + 1`` seeds the split shuffle.
        use_gym: Render/verify via gym instead of the fallback simulator.
        checkpoint_interval: Save a checkpoint every N generated puzzles.
    """
    params = {
        "sizes": sizes, "num_per_size": num_per_size,
        "p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
        "prompt": prompt, "train_ratio": train_ratio,
        "n_start": n_start, "m_end": m_end, "frames": frames,
        "fps": fps, "seed": seed, "use_gym": use_gym,
    }

    out = Path(output_dir)
    img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
    for d in (img_dir, vid_dir, tbl_dir):
        d.mkdir(parents=True, exist_ok=True)

    state = load_checkpoint(out, params)
    if state and state.completed:
        return

    # A single-element num_per_size is broadcast across all sizes.
    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
    num_w = len(str(max(num_list)))  # zero-pad width for file names
    proc = FrozenLakeProcessor(img_size=img_size)

    if state is None:
        random.seed(seed)
        state = GenerationState(
            params_hash=_params_hash(params),
            size_progress={sz: 0 for sz in sizes},
            seen_fingerprints=[], all_samples=[],
        )
        print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
    else:
        # Approximate RNG replay so a resumed run does not regenerate the
        # same early grids (exact replay is not required; duplicates are
        # filtered by fingerprint anyway).
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()

    seen = set(state.seen_fingerprints)
    all_samples = list(state.all_samples)
    # JSON round-trips stringify int keys; normalise them back to ints.
    progress = {int(k): v for k, v in state.size_progress.items()}
    since_ckpt = 0

    with tqdm(total=sum(num_list), initial=sum(progress.values()),
              desc="Total", unit="puzzle") as pbar:
        for grid_size, target in zip(sizes, num_list):
            generated = progress.get(grid_size, 0)
            if generated >= target:
                continue

            min_len = max(1, int(grid_size * grid_size * min_path_ratio))

            with tqdm(total=target, initial=generated,
                      desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
                # Allow up to 5 attempts per remaining puzzle before giving up
                # (generation can fail or produce duplicates).
                for _ in range((target - generated) * 5):
                    if generated >= target:
                        break
                    try:
                        desc, path = proc.generate_auto(
                            grid_size, p=p, min_path_len=min_len
                        )
                    except RuntimeError:
                        continue

                    # Skip grids we have already emitted (dedup by fingerprint).
                    fp = proc.fingerprint(desc)
                    if fp in seen:
                        continue
                    seen.add(fp)

                    base = f"size{grid_size}_{generated:0{num_w}d}"

                    proc.render(desc, use_gym=use_gym).save(str(img_dir / f"{base}.png"))
                    vid_frames = proc.generate_video_frames(
                        desc, path, n_start=n_start, m_end=m_end,
                        frames=frames, use_gym=use_gym,
                    )
                    save_video_cv2(vid_frames, str(vid_dir / f"{base}.mp4"), fps=fps)
                    proc.save_table(str(tbl_dir / f"{base}.txt"), desc)

                    udrl = proc.path_to_udrl(path)
                    all_samples.append({
                        "prompt": prompt, "image": f"{base}.png",
                        "video": f"{base}.mp4", "table": f"{base}.txt",
                        "grid_size": grid_size, "grid_desc": desc,
                        "start": list(proc.find_start(desc)),
                        "path_udrl": udrl, "path_length": len(path),
                        "frame_count": len(vid_frames),
                    })

                    generated += 1
                    progress[grid_size] = generated
                    since_ckpt += 1
                    pbar_sz.update(1)
                    pbar.update(1)

                    if since_ckpt >= checkpoint_interval:
                        state.size_progress = progress
                        state.seen_fingerprints = list(seen)
                        state.all_samples = all_samples
                        save_checkpoint(out, state, params)
                        since_ckpt = 0

            tqdm.write(f"Size {grid_size}: {generated} puzzles")

    # Ground-truth paths keyed by image name, sorted for stable diffs.
    with open(out / "path.json", "w") as f:
        json.dump(dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)), f, indent=4)

    # Per-size shuffle + split so every grid size keeps the same train ratio.
    random.seed(seed + 1)
    by_size: Dict[int, List[Dict]] = {}
    for s in all_samples:
        # BUG FIX: samples store their size under "grid_size" (the original
        # grouped on a non-existent "maze_size" key, raising KeyError here).
        by_size.setdefault(s["grid_size"], []).append(s)

    train_samples, test_samples = [], []
    for sz in sorted(by_size):
        group = by_size[sz]
        random.shuffle(group)
        sz_split = int(len(group) * train_ratio)
        train_samples.extend(group[:sz_split])
        test_samples.extend(group[sz_split:])

    random.shuffle(train_samples)
    random.shuffle(test_samples)
    split = len(train_samples)

    def _write_jsonl(samples, path):
        # One JSON object per line (JSONL convention).
        with open(path, "w") as f:
            for s in samples:
                f.write(json.dumps(s) + "\n")

    _write_jsonl(train_samples, out / "train.jsonl")
    _write_jsonl(test_samples, out / "test.jsonl")

    # CSV manifests for training pipelines that consume flat file lists.
    for name, samples in [("train", train_samples), ("test", test_samples)]:
        with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["input_image", "video", "prompt"])
            for s in samples:
                w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])

    state.size_progress = progress
    state.seen_fingerprints = list(seen)
    state.all_samples = all_samples
    state.completed = True
    save_checkpoint(out, state, params)

    print(f"\n✓ Complete: {out}/ | {len(all_samples)} puzzles "
          f"(train={split}, test={len(all_samples)-split})")
    if all_samples:  # guard min()/max() against an empty run
        lengths = [s["path_length"] for s in all_samples]
        fcounts = [s["frame_count"] for s in all_samples]
        print(f"  Paths: avg={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}")
        print(f"  Frames: avg={np.mean(fcounts):.1f} min={min(fcounts)} max={max(fcounts)}")
| |
|
| |
|
| | |
| |
|
def eval_videos(
    video_dir: str, table_dir: str,
    output_json: Optional[str] = None, gt_json: Optional[str] = None,
    use_gym: bool = True,
):
    """Extract predicted paths from generated videos and score them.

    For each ``*.mp4`` in *video_dir*, the matching grid table is loaded from
    *table_dir*, the path is read off the video's last frame, and the result
    is verified against the grid. Extracted paths are dumped to *output_json*.

    Args:
        video_dir: Directory of result videos (``<name>.mp4``).
        table_dir: Directory of grid tables (``<name>.txt``).
        output_json: Where to write extracted paths; defaults to
            ``<video_dir>/0_result.json``.
        gt_json: Optional ground-truth path JSON for a length breakdown.
        use_gym: Verify via gym instead of the fallback simulator.
    """
    proc = FrozenLakeProcessor()
    vid_root, tbl_root = Path(video_dir), Path(table_dir)
    if output_json is None:
        output_json = str(vid_root / "0_result.json")

    # Natural sort so size8_2 comes before size8_10.
    videos = sorted(
        vid_root.glob("*.mp4"),
        key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
    )
    if not videos:
        print(f"No .mp4 in {vid_root}"); return

    extracted: Dict[str, str] = {}
    missing_tbl = missing_frame = 0

    for vp in tqdm(videos, desc="Extracting"):
        desc = proc.load_table(str(tbl_root / f"{vp.stem}.txt"))
        if desc is None:
            missing_tbl += 1; continue
        start = proc.find_start(desc)
        if start is None:
            missing_tbl += 1; continue
        lf = extract_last_frame(str(vp))
        if lf is None:
            missing_frame += 1; continue
        extracted[f"{vp.stem}.png"] = proc.extract_path_from_pixels(
            lf, len(desc), len(desc[0]), start, desc)

    with open(output_json, "w") as f:
        json.dump(extracted, f, indent=4)

    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    correct = total = 0
    size_stats: Dict[int, Dict[str, int]] = {}
    top: List[Dict] = []  # correctly solved puzzles, later ranked by length

    for name, udrl in extracted.items():
        desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
        if desc is None: continue
        total += 1
        sz = len(desc)
        size_stats.setdefault(sz, {"total": 0, "correct": 0})
        size_stats[sz]["total"] += 1
        if verify_fn(desc, udrl):
            correct += 1
            size_stats[sz]["correct"] += 1
            top.append({"name": name, "length": len(udrl)})

    acc = correct / total * 100 if total else 0
    print(f"\n{'='*50}\nEval: {correct}/{total} ({acc:.2f}%) | "
          f"missing_tbl={missing_tbl} bad_frame={missing_frame}")
    for sz in sorted(size_stats):
        s = size_stats[sz]
        print(f"  Size {sz:3d}: {s['correct']}/{s['total']} "
              f"({s['correct']/s['total']*100:.1f}%)")
    # Show the longest correctly-solved paths.
    top.sort(key=lambda x: x["length"], reverse=True)
    for i, item in enumerate(top[:3]):
        print(f"  Top {i+1}: {item['name']} (len={item['length']})")

    if gt_json:
        # BUG FIX: the original wrapped this whole section in a bare
        # `except Exception: pass`, silently hiding real bugs. Only file
        # read/parse problems are an acceptable best-effort failure here.
        try:
            with open(gt_json) as f:
                gt = json.load(f)
        except (OSError, json.JSONDecodeError) as e:
            print(f"⚠️ Skipping GT breakdown, could not read {gt_json}: {e}")
        else:
            # Bucket accuracy by ground-truth path length (bins of 10).
            bins: Dict[str, Dict[str, int]] = {}
            for name, pred in extracted.items():
                if name not in gt: continue
                lo = (len(gt[name]) // 10) * 10
                label = f"{lo:3d}-{lo+9:3d}"
                bins.setdefault(label, {"total": 0, "correct": 0})
                bins[label]["total"] += 1
                desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
                if desc and verify_fn(desc, pred):
                    bins[label]["correct"] += 1
            if bins:
                print("\nBy GT path length:")
                for label in sorted(bins):
                    b = bins[label]
                    print(f"  {label}: {b['correct']}/{b['total']} "
                          f"({b['correct']/b['total']*100:.1f}%)")
    print(f"{'='*50}")
| |
|
| |
|
def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
    """Re-verify a saved predictions JSON against the stored grid tables.

    Entries whose table file is missing are skipped; accuracy is reported
    over the remaining valid entries.
    """
    proc = FrozenLakeProcessor()
    with open(json_file) as f:
        solutions = json.load(f)
    verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
    tbl_root = Path(table_dir)

    correct, skipped, valid = 0, 0, 0
    for name, udrl in solutions.items():
        desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
        if desc is None:
            skipped += 1
            continue
        valid += 1
        if verify_fn(desc, udrl):
            correct += 1

    acc = correct / valid * 100 if valid else 0
    print(f"\nVerification: {correct}/{valid} ({acc:.2f}%)")
| |
|
| |
|
| | |
| |
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse command-line arguments for the generate/eval/verify subcommands.

    Args:
        argv: Optional explicit argument list (without the program name);
            defaults to ``sys.argv[1:]`` via argparse. Accepting it makes the
            parser unit-testable without touching ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(description="FrozenLake video dataset")
    sub = p.add_subparsers(dest="command")

    gen = sub.add_parser("generate")
    gen.add_argument("--output-dir", default="frozenlake")
    gen.add_argument("--sizes", type=int, nargs="+", default=[8, 12, 16, 32])
    gen.add_argument("--num-per-size", type=int, nargs="+", default=[1000, 2000, 5000, 10000])
    gen.add_argument("--p", type=float, default=0.5)
    gen.add_argument("--min-path-ratio", type=float, default=0.1)
    gen.add_argument("--img-size", type=int, default=1024)
    gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
    gen.add_argument("--train-ratio", type=float, default=0.9)
    gen.add_argument("--n-start", type=int, default=2)
    gen.add_argument("--m-end", type=int, default=3)
    gen.add_argument("--frames", type=int, default=None)
    gen.add_argument("--fps", type=int, default=10)
    gen.add_argument("--seed", type=int, default=42)
    gen.add_argument("--no-gym", action="store_true")
    gen.add_argument("--checkpoint-interval", type=int, default=50)

    ev = sub.add_parser("eval")
    ev.add_argument("video_dir")
    ev.add_argument("--table-dir", required=True)
    ev.add_argument("--output-json", default=None)
    ev.add_argument("--gt-json", default=None)
    ev.add_argument("--no-gym", action="store_true")

    ver = sub.add_parser("verify")
    ver.add_argument("json_file")
    ver.add_argument("--table-dir", required=True)
    ver.add_argument("--no-gym", action="store_true")

    return p.parse_args(argv)
| |
|
| |
|
if __name__ == "__main__":
    args = parse_args()
    if args.command == "generate":
        # Forward every CLI option except the subcommand marker; translate
        # the --no-gym flag into the use_gym keyword expected by the API.
        gen_kwargs = {
            k: v for k, v in vars(args).items()
            if k not in ("command", "no_gym")
        }
        gen_kwargs["use_gym"] = not args.no_gym
        generate_dataset(**gen_kwargs)
    elif args.command == "eval":
        eval_videos(
            args.video_dir, args.table_dir, args.output_json,
            args.gt_json, not args.no_gym,
        )
    elif args.command == "verify":
        verify_results(args.json_file, args.table_dir, not args.no_gym)
    else:
        print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")