"""Generic prompt loading and text-to-image JSON validation.""" from __future__ import annotations import csv import json import re from dataclasses import dataclass, field from pathlib import Path from typing import Any @dataclass(frozen=True, slots=True) class PromptItem: """One text-to-image prompt to process.""" prompt_id: str row_number: int prompt: str metadata: dict[str, Any] = field(default_factory=dict) REQUIRED_T2I_KEYS = { "subjects", "subject_details", "background_setting", "lighting", "text_and_signage_elements", "resolution", "aspect_ratio", "comprehensive_t2i_caption", } PROMPT_COLUMNS = ("prompt", "Prompt") ID_COLUMNS = ("id", "ID", "prompt_id", "Prompt ID") _SAFE_ID_RE = re.compile(r"[^A-Za-z0-9_.-]+") def prompt_dir_name(item: PromptItem) -> str: """Return the deterministic output directory name for a prompt.""" raw_id = item.prompt_id.strip() if raw_id.isdigit(): return f"{int(raw_id):04d}" cleaned = _SAFE_ID_RE.sub("_", raw_id).strip("._-") return cleaned or f"row_{item.row_number + 1:04d}" def load_prompt_items( *, prompt: str | None = None, prompts_path: Path | None = None, limit: int | None = None, ) -> list[PromptItem]: """Load prompts from a literal prompt or a txt/jsonl/csv file.""" if bool(prompt) == bool(prompts_path): raise ValueError("Provide exactly one of --prompt or --prompts.") if prompt: items = [PromptItem(prompt_id="1", row_number=0, prompt=prompt.strip())] elif prompts_path is not None: items = _load_prompts_path(prompts_path) else: items = [] items = [item for item in items if item.prompt.strip()] if limit is not None and limit >= 0: items = items[:limit] _validate_unique_output_dirs(items) return items def _load_prompts_path(path: Path) -> list[PromptItem]: suffix = path.suffix.lower() if suffix == ".txt": return _load_txt_prompts(path) if suffix == ".jsonl": return _load_jsonl_prompts(path) if suffix == ".csv": return _load_csv_prompts(path) raise ValueError(f"Unsupported prompt file extension {suffix!r}. Use .txt, .jsonl, or .csv.") def _load_txt_prompts(path: Path) -> list[PromptItem]: items: list[PromptItem] = [] for row_number, line in enumerate(path.read_text(encoding="utf-8").splitlines()): prompt = line.strip() if not prompt: continue items.append(PromptItem(prompt_id=str(len(items) + 1), row_number=row_number, prompt=prompt)) return items def _load_jsonl_prompts(path: Path) -> list[PromptItem]: items: list[PromptItem] = [] with path.open(encoding="utf-8") as f: for row_number, line in enumerate(f): raw = line.strip() if not raw: continue parsed = json.loads(raw) if isinstance(parsed, str): prompt = parsed.strip() prompt_id = str(len(items) + 1) metadata: dict[str, Any] = {} elif isinstance(parsed, dict): prompt = str(parsed.get("prompt") or parsed.get("Prompt") or "").strip() prompt_id = str(parsed.get("id") or parsed.get("prompt_id") or len(items) + 1) metadata = {key: value for key, value in parsed.items() if key not in {"prompt", "Prompt"}} else: raise ValueError(f"JSONL row {row_number + 1} must be a string or object.") if prompt: items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=metadata)) return items def _load_csv_prompts(path: Path) -> list[PromptItem]: items: list[PromptItem] = [] with path.open(newline="", encoding="utf-8") as f: reader = csv.DictReader(f) for row_number, row in enumerate(reader): prompt_key = _first_present_key(row, PROMPT_COLUMNS) if prompt_key is None: raise ValueError(f"CSV must include one of these prompt columns: {', '.join(PROMPT_COLUMNS)}.") prompt = str(row.get(prompt_key) or "").strip() if not prompt: continue id_key = _first_present_key(row, ID_COLUMNS) prompt_id = str(row.get(id_key) or len(items) + 1) if id_key is not None else str(len(items) + 1) items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=dict(row))) return items def _first_present_key(row: dict[str, Any], keys: tuple[str, ...]) -> str | None: for key in keys: if key in row: return key return None def _validate_unique_output_dirs(items: list[PromptItem]) -> None: seen: dict[str, str] = {} for item in items: dirname = prompt_dir_name(item) previous = seen.get(dirname) if previous is not None: raise ValueError(f"Prompt ids {previous!r} and {item.prompt_id!r} map to the same output dir {dirname!r}.") seen[dirname] = item.prompt_id def validate_t2i_json(data: dict[str, Any], prompt_id: str) -> None: """Validate the minimum structured T2I JSON shape expected by Cosmos3.""" missing = sorted(REQUIRED_T2I_KEYS - set(data)) if missing: raise ValueError(f"Prompt JSON for {prompt_id} is missing required keys: {missing}") if not isinstance(data.get("subjects"), list): raise ValueError(f"Prompt JSON for {prompt_id}: subjects must be a list.") if not isinstance(data.get("text_and_signage_elements"), list): raise ValueError(f"Prompt JSON for {prompt_id}: text_and_signage_elements must be a list.") caption = data.get("comprehensive_t2i_caption") if not isinstance(caption, str) or not caption.strip(): raise ValueError(f"Prompt JSON for {prompt_id}: comprehensive_t2i_caption is empty.") resolution = data.get("resolution") if not isinstance(resolution, dict) or not {"H", "W"}.issubset(resolution): raise ValueError(f"Prompt JSON for {prompt_id}: resolution must contain H and W.") aspect_ratio = data.get("aspect_ratio") if not isinstance(aspect_ratio, str) or not aspect_ratio.strip(): raise ValueError(f"Prompt JSON for {prompt_id}: aspect_ratio must be a non-empty string.")