"""
FrozenLake Video Dataset Generator — generate, eval, verify.
Uses generate_auto(), which automatically picks a random (small grids) or
guided (large grids) generation strategy.
Usage:
    python data_process.py generate --output-dir frozenlake \
        --sizes 8 16 32 --num-per-size 100 500 1000 --p 0.8
    python data_process.py eval result_videos/ --table-dir frozenlake/tables
    python data_process.py verify results.json --table-dir frozenlake/tables
"""
import json
import csv
import hashlib
import random
import re
import argparse
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional
import cv2
import numpy as np
from tqdm import tqdm
from frozenlake_processor import FrozenLakeProcessor
# ==================== Checkpoint ====================
@dataclass
class GenerationState:
params_hash: str
size_progress: Dict[int, int]
seen_fingerprints: List[str]
all_samples: List[Dict]
completed: bool = False
def to_dict(self) -> Dict:
return asdict(self)
@classmethod
def from_dict(cls, d: Dict) -> "GenerationState":
return cls(**d)
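# Round-trip property relied on by the checkpoint code below:
#   GenerationState.from_dict(state.to_dict()) reconstructs an equal state.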
def _params_hash(params: Dict) -> str:
key = {k: v for k, v in params.items() if k != "output_dir"}
return hashlib.md5(json.dumps(key, sort_keys=True).encode()).hexdigest()[:12]
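# Illustrative property (values hypothetical): output_dir is excluded, so
# relocating the dataset does not invalidate a checkpoint, while any other
# parameter change does:
#   _params_hash({"sizes": [8], "output_dir": "a"}) == _params_hash({"sizes": [8], "output_dir": "b"})
#   _params_hash({"sizes": [8]}) != _params_hash({"sizes": [16]})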
def load_checkpoint(output_dir: Path, params: Dict) -> Optional[GenerationState]:
meta = output_dir / "metadata.json"
if not meta.exists():
return None
with open(meta) as f:
data = json.load(f)
state = GenerationState.from_dict(data["state"])
expected = _params_hash(params)
if state.params_hash != expected:
print(f"⚠️ Params changed ({state.params_hash}{expected}), starting fresh")
return None
if state.completed:
print("✓ Already completed")
return state
print(f"✓ Resuming: {sum(state.size_progress.values())} done")
return state
def save_checkpoint(output_dir: Path, state: GenerationState, params: Dict):
meta = output_dir / "metadata.json"
tmp = meta.with_suffix(".tmp")
with open(tmp, "w") as f:
json.dump({"params": params, "state": state.to_dict()}, f, indent=2)
tmp.rename(meta)
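# Shape of the checkpoint written above (values illustrative):
#   {"params": {...generation args...},
#    "state": {"params_hash": "1a2b3c4d5e6f", "size_progress": {"8": 37, ...},
#              "seen_fingerprints": [...], "all_samples": [...], "completed": false}}
# Writing to a .tmp file and renaming keeps metadata.json valid if interrupted.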
# ==================== Video I/O ====================
def save_video_cv2(frames: list, path: str, fps: int = 10):
first = np.array(frames[0])
h, w = first.shape[:2]
writer = cv2.VideoWriter(str(path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
for frame in frames:
writer.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
writer.release()
def extract_last_frame(video_path: str) -> Optional[np.ndarray]:
cap = cv2.VideoCapture(str(video_path))
if not cap.isOpened():
return None
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total > 0:  # some containers report a 0 frame count; then fall back to the first readable frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, total - 1)
ret, frame = cap.read()
cap.release()
if not ret or frame is None:
return None
return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
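# Minimal round-trip sketch (assumes the mp4v codec is available in the local
# OpenCV build; the file name is illustrative):
#   frames = [np.zeros((64, 64, 3), dtype=np.uint8)] * 5
#   save_video_cv2(frames, "demo.mp4", fps=5)
#   assert extract_last_frame("demo.mp4") is not None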
def _normalise_list(val, sizes, name="parameter"):
if isinstance(val, int):
return [val] * len(sizes)
if len(val) != len(sizes):
raise ValueError(f"{name} length ({len(val)}) != sizes ({len(sizes)})")
return list(val)
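# Examples:
#   _normalise_list(100, [8, 16, 32], "num_per_size")     -> [100, 100, 100]
#   _normalise_list([1, 2], [8, 16, 32], "num_per_size")  -> ValueError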
# ==================== Generate ====================
def generate_dataset(
output_dir: str = "frozenlake",
sizes: List[int] = [8, 16, 32],
    num_per_size: List[int] = [100, 500, 1000],
p: float = 0.8,
min_path_ratio: float = 0.1,
img_size: int = 512,
prompt: str = "Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.",
train_ratio: float = 0.9,
n_start: int = 2,
m_end: int = 3,
frames: Optional[int] = None,
fps: int = 10,
seed: int = 42,
use_gym: bool = True,
checkpoint_interval: int = 50,
):
params = {
"sizes": sizes, "num_per_size": num_per_size,
"p": p, "min_path_ratio": min_path_ratio, "img_size": img_size,
"prompt": prompt, "train_ratio": train_ratio,
"n_start": n_start, "m_end": m_end, "frames": frames,
"fps": fps, "seed": seed, "use_gym": use_gym,
}
out = Path(output_dir)
img_dir, vid_dir, tbl_dir = out / "images", out / "videos", out / "tables"
for d in (img_dir, vid_dir, tbl_dir):
d.mkdir(parents=True, exist_ok=True)
state = load_checkpoint(out, params)
if state and state.completed:
return
    # A single count broadcasts across all sizes; otherwise lengths must match.
    num_list = _normalise_list(
        num_per_size[0] if len(num_per_size) == 1 else num_per_size,
        sizes, "num_per_size",
    )
num_w = len(str(max(num_list)))
proc = FrozenLakeProcessor(img_size=img_size)
if state is None:
random.seed(seed)
state = GenerationState(
params_hash=_params_hash(params),
size_progress={sz: 0 for sz in sizes},
seen_fingerprints=[], all_samples=[],
)
print(f"Fresh generation: sizes={sizes}, counts={num_list}, p={p}")
    else:
        # Best-effort RNG fast-forward on resume: burn a fixed number of draws
        # per completed puzzle. This only approximates the stream a fresh run
        # would see, since generate_auto consumes a variable number of draws.
        random.seed(seed)
        for _ in range(sum(state.size_progress.values()) * 10):
            random.random()
seen = set(state.seen_fingerprints)
all_samples = list(state.all_samples)
    progress = {int(k): v for k, v in state.size_progress.items()}  # JSON round-trips dict keys as strings
since_ckpt = 0
with tqdm(total=sum(num_list), initial=sum(progress.values()),
desc="Total", unit="puzzle") as pbar:
for grid_size, target in zip(sizes, num_list):
generated = progress.get(grid_size, 0)
if generated >= target:
continue
min_len = max(1, int(grid_size * grid_size * min_path_ratio))
with tqdm(total=target, initial=generated,
desc=f"Size {grid_size:3d}", unit="puzzle", leave=False) as pbar_sz:
                for _ in range((target - generated) * 5):  # cap attempts at 5x the remainder so infeasible settings can't spin forever
if generated >= target:
break
try:
desc, path = proc.generate_auto(
grid_size, p=p, min_path_len=min_len
)
except RuntimeError:
continue
fp = proc.fingerprint(desc)
if fp in seen:
continue
seen.add(fp)
base = f"size{grid_size}_{generated:0{num_w}d}"
proc.render(desc, use_gym=use_gym).save(str(img_dir / f"{base}.png"))
vid_frames = proc.generate_video_frames(
desc, path, n_start=n_start, m_end=m_end,
frames=frames, use_gym=use_gym,
)
save_video_cv2(vid_frames, str(vid_dir / f"{base}.mp4"), fps=fps)
proc.save_table(str(tbl_dir / f"{base}.txt"), desc)
udrl = proc.path_to_udrl(path)
all_samples.append({
"prompt": prompt, "image": f"{base}.png",
"video": f"{base}.mp4", "table": f"{base}.txt",
"grid_size": grid_size, "grid_desc": desc,
"start": list(proc.find_start(desc)),
"path_udrl": udrl, "path_length": len(path),
"frame_count": len(vid_frames),
})
generated += 1
progress[grid_size] = generated
since_ckpt += 1
pbar_sz.update(1)
pbar.update(1)
if since_ckpt >= checkpoint_interval:
state.size_progress = progress
state.seen_fingerprints = list(seen)
state.all_samples = all_samples
save_checkpoint(out, state, params)
since_ckpt = 0
tqdm.write(f"Size {grid_size}: {generated} puzzles")
with open(out / "path.json", "w") as f:
json.dump(dict(sorted((s["image"], s["path_udrl"]) for s in all_samples)), f, indent=4)
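    # path.json maps each rendered image to its ground-truth move sequence,
    # e.g. {"size8_000.png": <UDRL moves>, ...}; pass it to eval's --gt-json
    # flag to get the length-binned accuracy breakdown.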
# Stratified split: ensure each size is proportionally represented in test set
random.seed(seed + 1)
by_size: Dict[int, List[Dict]] = {}
for s in all_samples:
by_size.setdefault(s["maze_size"], []).append(s)
train_samples, test_samples = [], []
for sz in sorted(by_size):
group = by_size[sz]
random.shuffle(group)
sz_split = int(len(group) * train_ratio)
train_samples.extend(group[:sz_split])
test_samples.extend(group[sz_split:])
random.shuffle(train_samples)
random.shuffle(test_samples)
split = len(train_samples)
def _write_jsonl(samples, path):
with open(path, "w") as f:
for s in samples:
f.write(json.dumps(s) + "\n")
_write_jsonl(train_samples, out / "train.jsonl")
_write_jsonl(test_samples, out / "test.jsonl")
for name, samples in [("train", train_samples), ("test", test_samples)]:
with open(out / f"{name}.csv", "w", newline="", encoding="utf-8") as f:
w = csv.writer(f)
w.writerow(["input_image", "video", "prompt"])
for s in samples:
w.writerow([f"images/{s['image']}", f"videos/{s['video']}", s["prompt"]])
state.size_progress = progress
state.seen_fingerprints = list(seen)
state.all_samples = all_samples
state.completed = True
save_checkpoint(out, state, params)
lengths = [s["path_length"] for s in all_samples]
fcounts = [s["frame_count"] for s in all_samples]
print(f"\n✓ Complete: {out}/ | {len(all_samples)} puzzles "
f"(train={split}, test={len(all_samples)-split})")
print(f" Paths: avg={np.mean(lengths):.1f} min={min(lengths)} max={max(lengths)}")
# ==================== Eval ====================
def eval_videos(
video_dir: str, table_dir: str,
output_json: Optional[str] = None, gt_json: Optional[str] = None,
use_gym: bool = True,
):
proc = FrozenLakeProcessor()
vid_root, tbl_root = Path(video_dir), Path(table_dir)
if output_json is None:
output_json = str(vid_root / "0_result.json")
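    # Natural sort on digit runs so e.g. "size8_2.mp4" precedes "size8_10.mp4"
    # (plain lexicographic order would reverse them).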
videos = sorted(
vid_root.glob("*.mp4"),
key=lambda p: [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", p.stem)],
)
if not videos:
print(f"No .mp4 in {vid_root}"); return
extracted: Dict[str, str] = {}
missing_tbl = missing_frame = 0
for vp in tqdm(videos, desc="Extracting"):
desc = proc.load_table(str(tbl_root / f"{vp.stem}.txt"))
if desc is None:
missing_tbl += 1; continue
start = proc.find_start(desc)
if start is None:
missing_tbl += 1; continue
lf = extract_last_frame(str(vp))
if lf is None:
missing_frame += 1; continue
extracted[f"{vp.stem}.png"] = proc.extract_path_from_pixels(
lf, len(desc), len(desc[0]), start, desc)
with open(output_json, "w") as f:
json.dump(extracted, f, indent=4)
verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
correct = total = 0
size_stats: Dict[int, Dict[str, int]] = {}
top: List[Dict] = []
for name, udrl in extracted.items():
desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
if desc is None: continue
total += 1
sz = len(desc)
size_stats.setdefault(sz, {"total": 0, "correct": 0})
size_stats[sz]["total"] += 1
if verify_fn(desc, udrl):
correct += 1
size_stats[sz]["correct"] += 1
top.append({"name": name, "length": len(udrl)})
acc = correct / total * 100 if total else 0
print(f"\n{'='*50}\nEval: {correct}/{total} ({acc:.2f}%) | "
f"missing_tbl={missing_tbl} bad_frame={missing_frame}")
for sz in sorted(size_stats):
s = size_stats[sz]
print(f" Size {sz:3d}: {s['correct']}/{s['total']} "
f"({s['correct']/s['total']*100:.1f}%)")
top.sort(key=lambda x: x["length"], reverse=True)
for i, item in enumerate(top[:3]):
print(f" Top {i+1}: {item['name']} (len={item['length']})")
if gt_json:
try:
with open(gt_json) as f:
gt = json.load(f)
bins: Dict[str, Dict[str, int]] = {}
for name, pred in extracted.items():
if name not in gt: continue
lo = (len(gt[name]) // 10) * 10
label = f"{lo:3d}-{lo+9:3d}"
bins.setdefault(label, {"total": 0, "correct": 0})
bins[label]["total"] += 1
desc = proc.load_table(str(tbl_root / f"{name.replace('.png','')}.txt"))
if desc and verify_fn(desc, pred):
bins[label]["correct"] += 1
if bins:
print("\nBy GT path length:")
for label in sorted(bins):
b = bins[label]
print(f" {label}: {b['correct']}/{b['total']} "
f"({b['correct']/b['total']*100:.1f}%)")
        except Exception as e:
            print(f"Skipping GT length breakdown: {e}")
print(f"{'='*50}")
def verify_results(json_file: str, table_dir: str, use_gym: bool = True):
proc = FrozenLakeProcessor()
with open(json_file) as f:
solutions = json.load(f)
verify_fn = proc.verify_path_gym if use_gym else proc.verify_path_sim
correct = skipped = valid = 0
for name, udrl in solutions.items():
desc = proc.load_table(str(Path(table_dir) / f"{name.replace('.png','')}.txt"))
if desc is None:
skipped += 1; continue
valid += 1
if verify_fn(desc, udrl):
correct += 1
acc = correct / valid * 100 if valid else 0
print(f"\nVerification: {correct}/{valid} ({acc:.2f}%)")
# ==================== CLI ====================
def parse_args():
p = argparse.ArgumentParser(description="FrozenLake video dataset")
sub = p.add_subparsers(dest="command")
gen = sub.add_parser("generate")
gen.add_argument("--output-dir", default="frozenlake")
gen.add_argument("--sizes", type=int, nargs="+", default=[8, 12, 16, 32])
gen.add_argument("--num-per-size", type=int, nargs="+", default=[1000, 2000, 5000, 10000])
gen.add_argument("--p", type=float, default=0.5)
gen.add_argument("--min-path-ratio", type=float, default=0.1)
gen.add_argument("--img-size", type=int, default=1024)
gen.add_argument("--prompt", default="Draw a continuous red line connecting the Start point to the Goal point, avoiding all holes.")
gen.add_argument("--train-ratio", type=float, default=0.9)
gen.add_argument("--n-start", type=int, default=2)
gen.add_argument("--m-end", type=int, default=3)
gen.add_argument("--frames", type=int, default=None)
gen.add_argument("--fps", type=int, default=10)
gen.add_argument("--seed", type=int, default=42)
gen.add_argument("--no-gym", action="store_true")
gen.add_argument("--checkpoint-interval", type=int, default=50)
ev = sub.add_parser("eval")
ev.add_argument("video_dir")
ev.add_argument("--table-dir", required=True)
ev.add_argument("--output-json", default=None)
ev.add_argument("--gt-json", default=None)
ev.add_argument("--no-gym", action="store_true")
ver = sub.add_parser("verify")
ver.add_argument("json_file")
ver.add_argument("--table-dir", required=True)
ver.add_argument("--no-gym", action="store_true")
return p.parse_args()
if __name__ == "__main__":
args = parse_args()
if args.command == "generate":
kw = {k: v for k, v in vars(args).items() if k not in ("command", "no_gym")}
kw["use_gym"] = not args.no_gym
generate_dataset(**kw)
elif args.command == "eval":
eval_videos(args.video_dir, args.table_dir, args.output_json,
args.gt_json, not args.no_gym)
elif args.command == "verify":
verify_results(args.json_file, args.table_dir, not args.no_gym)
else:
print("Usage: python frozenlake_video_gen.py {generate|eval|verify} ...")