mingyuliutw's picture
Super-squash branch 'main' using huggingface_hub
fdafd05
raw
history blame
6.34 kB
"""Generic prompt loading and text-to-image JSON validation."""
from __future__ import annotations
import csv
import json
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass(frozen=True, slots=True)
class PromptItem:
"""One text-to-image prompt to process."""
prompt_id: str
row_number: int
prompt: str
metadata: dict[str, Any] = field(default_factory=dict)
# Top-level keys that every structured T2I JSON document must provide.
REQUIRED_T2I_KEYS = {
    "aspect_ratio",
    "background_setting",
    "comprehensive_t2i_caption",
    "lighting",
    "resolution",
    "subject_details",
    "subjects",
    "text_and_signage_elements",
}

# Accepted names for the prompt-text column, in lookup priority order.
PROMPT_COLUMNS = ("prompt", "Prompt")

# Accepted names for the prompt-id column, in lookup priority order.
ID_COLUMNS = ("id", "ID", "prompt_id", "Prompt ID")

# Matches each run of characters that may not appear in an output dir name.
_SAFE_ID_RE = re.compile(r"[^A-Za-z0-9_.-]+")
def prompt_dir_name(item: PromptItem) -> str:
    """Return the deterministic output directory name for a prompt.

    Purely numeric ids are zero-padded to four digits (``"7"`` -> ``"0007"``).
    For any other id, each run of characters matched by ``_SAFE_ID_RE`` is
    replaced by ``_`` and leading/trailing ``.``, ``_`` and ``-`` are removed.
    If nothing is left, the name falls back to ``row_NNNN`` built from the
    one-based row number.

    Args:
        item: The prompt whose ``prompt_id`` and ``row_number`` are used.

    Returns:
        The directory name for ``item``.
    """
    raw_id = item.prompt_id.strip()
    # ``isdecimal`` accepts only characters that ``int`` can parse.
    # ``isdigit`` is also true for digits such as "²", for which ``int``
    # raises ValueError; those ids now take the sanitising path below.
    if raw_id.isdecimal():
        return f"{int(raw_id):04d}"
    cleaned = _SAFE_ID_RE.sub("_", raw_id).strip("._-")
    return cleaned or f"row_{item.row_number + 1:04d}"
def load_prompt_items(
    *,
    prompt: str | None = None,
    prompts_path: Path | None = None,
    limit: int | None = None,
) -> list[PromptItem]:
    """Build the list of prompts from a literal prompt or a prompt file.

    Args:
        prompt: A single literal prompt. It becomes the only item, with
            id ``"1"`` and row number ``0``.
        prompts_path: Path to a ``.txt``, ``.jsonl`` or ``.csv`` prompt file.
        limit: When not ``None`` and non-negative, keep at most this many
            items. Negative values are ignored.

    Returns:
        The loaded items, without entries whose prompt is blank.

    Raises:
        ValueError: If both or neither of ``prompt`` and ``prompts_path`` are
            given, or if two items map to the same output directory.
    """
    has_literal = bool(prompt)
    if has_literal == bool(prompts_path):
        raise ValueError("Provide exactly one of --prompt or --prompts.")

    if has_literal:
        loaded = [PromptItem(prompt_id="1", row_number=0, prompt=prompt.strip())]
    else:
        loaded = _load_prompts_path(prompts_path) if prompts_path is not None else []

    # Drop blank prompts before applying the limit so the limit counts
    # usable prompts only.
    kept = [entry for entry in loaded if entry.prompt.strip()]
    if limit is not None and limit >= 0:
        kept = kept[:limit]

    _validate_unique_output_dirs(kept)
    return kept
def _load_prompts_path(path: Path) -> list[PromptItem]:
    """Pick the loader matching the file extension and run it on ``path``.

    The extension is compared case-insensitively.

    Raises:
        ValueError: If the extension is not ``.txt``, ``.jsonl`` or ``.csv``.
    """
    extension = path.suffix.lower()
    loaders = {
        ".txt": _load_txt_prompts,
        ".jsonl": _load_jsonl_prompts,
        ".csv": _load_csv_prompts,
    }
    loader = loaders.get(extension)
    if loader is None:
        raise ValueError(f"Unsupported prompt file extension {extension!r}. Use .txt, .jsonl, or .csv.")
    return loader(path)
def _load_txt_prompts(path: Path) -> list[PromptItem]:
    """Load one prompt per non-blank line of a UTF-8 text file.

    Ids are assigned sequentially, starting at ``"1"``, over the kept lines.
    ``row_number`` is the zero-based line index in the file, blank lines
    included.

    Args:
        path: The ``.txt`` file to read.

    Returns:
        One item per non-blank line, in file order.
    """
    items: list[PromptItem] = []
    # "utf-8-sig" drops a leading byte-order mark if present. With plain
    # "utf-8" the BOM survives ``strip()`` (it is not whitespace) and ends
    # up inside the first prompt.
    text = path.read_text(encoding="utf-8-sig")
    for row_number, line in enumerate(text.splitlines()):
        prompt = line.strip()
        if not prompt:
            continue
        items.append(PromptItem(prompt_id=str(len(items) + 1), row_number=row_number, prompt=prompt))
    return items
def _load_jsonl_prompts(path: Path) -> list[PromptItem]:
    """Load prompts from a JSON Lines file.

    Each non-blank line must be either a JSON string (the prompt itself) or
    a JSON object. For objects, the prompt is read from ``"prompt"`` or
    ``"Prompt"`` and the id from ``"id"`` or ``"prompt_id"``; every key
    other than the two prompt keys is kept as metadata. Rows without an id
    get the next sequential id, and rows with an empty prompt are skipped.
    ``row_number`` is the zero-based line index in the file.

    Args:
        path: The ``.jsonl`` file to read.

    Returns:
        One item per row that has a non-empty prompt, in file order.

    Raises:
        json.JSONDecodeError: If a row is not valid JSON. The message names
            the one-based row.
        ValueError: If a row is valid JSON but neither a string nor an object.
    """
    items: list[PromptItem] = []
    # "utf-8-sig" drops a leading byte-order mark; ``json.loads`` rejects a
    # document that starts with one.
    with path.open(encoding="utf-8-sig") as f:
        for row_number, line in enumerate(f):
            raw = line.strip()
            if not raw:
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError as err:
                # Same exception type, with the row added so the bad line
                # can be found in a large file.
                raise json.JSONDecodeError(f"JSONL row {row_number + 1}: {err.msg}", err.doc, err.pos) from err
            if isinstance(parsed, str):
                prompt = parsed.strip()
                prompt_id = str(len(items) + 1)
                metadata: dict[str, Any] = {}
            elif isinstance(parsed, dict):
                prompt = str(parsed.get("prompt") or parsed.get("Prompt") or "").strip()
                prompt_id = str(len(items) + 1)
                for id_key in ("id", "prompt_id"):
                    candidate = parsed.get(id_key)
                    # A numeric 0 is a real id. A plain truthiness test would
                    # discard it, and the sequential fallback "1" would then
                    # collide with the row whose id is 1.
                    is_zero = (
                        isinstance(candidate, (int, float))
                        and not isinstance(candidate, bool)
                        and candidate == 0
                    )
                    if candidate or is_zero:
                        prompt_id = str(candidate)
                        break
                metadata = {key: value for key, value in parsed.items() if key not in {"prompt", "Prompt"}}
            else:
                raise ValueError(f"JSONL row {row_number + 1} must be a string or object.")
            if prompt:
                items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=metadata))
    return items
def _load_csv_prompts(path: Path) -> list[PromptItem]:
    """Load prompts from a CSV file with a header row.

    The prompt is read from the first column of ``PROMPT_COLUMNS`` present
    in the row, and the id from the first column of ``ID_COLUMNS`` present.
    Rows with an empty prompt are skipped; rows without an id value get the
    next sequential id. The whole row is copied into ``metadata``.
    ``row_number`` is the zero-based index among data rows (the header is
    not counted).

    Args:
        path: The ``.csv`` file to read.

    Returns:
        One item per data row that has a non-empty prompt, in file order.

    Raises:
        ValueError: If a data row has none of the accepted prompt columns.
    """
    items: list[PromptItem] = []
    # "utf-8-sig" drops a leading byte-order mark. With plain "utf-8" the
    # BOM becomes part of the first header name (e.g. "\ufeffprompt"), so
    # the prompt column is not recognised.
    with path.open(newline="", encoding="utf-8-sig") as f:
        reader = csv.DictReader(f)
        for row_number, row in enumerate(reader):
            prompt_key = _first_present_key(row, PROMPT_COLUMNS)
            if prompt_key is None:
                raise ValueError(f"CSV must include one of these prompt columns: {', '.join(PROMPT_COLUMNS)}.")
            prompt = str(row.get(prompt_key) or "").strip()
            if not prompt:
                continue
            id_key = _first_present_key(row, ID_COLUMNS)
            raw_id = row.get(id_key) if id_key is not None else None
            # An absent column or an empty cell falls back to the sequential id.
            prompt_id = str(raw_id or len(items) + 1)
            items.append(PromptItem(prompt_id=prompt_id, row_number=row_number, prompt=prompt, metadata=dict(row)))
    return items
def _first_present_key(row: dict[str, Any], keys: tuple[str, ...]) -> str | None:
for key in keys:
if key in row:
return key
return None
def _validate_unique_output_dirs(items: list[PromptItem]) -> None:
    """Ensure no two prompts resolve to the same output directory name.

    Raises:
        ValueError: On the first clash, naming both prompt ids and the
            shared directory name.
    """
    # Maps each directory name to the prompt id that claimed it first.
    owners: dict[str, str] = {}
    for entry in items:
        target = prompt_dir_name(entry)
        if (earlier := owners.get(target)) is not None:
            raise ValueError(f"Prompt ids {earlier!r} and {entry.prompt_id!r} map to the same output dir {target!r}.")
        owners[target] = entry.prompt_id
def validate_t2i_json(data: dict[str, Any], prompt_id: str) -> None:
    """Validate the minimum structured T2I JSON shape expected by Cosmos3.

    Checks, in order: ``data`` is a dict; every key in ``REQUIRED_T2I_KEYS``
    is present; ``subjects`` and ``text_and_signage_elements`` are lists;
    ``comprehensive_t2i_caption`` is a non-blank string; ``resolution`` is a
    dict containing ``H`` and ``W``; ``aspect_ratio`` is a non-blank string.

    Args:
        data: The parsed prompt JSON.
        prompt_id: Identifier used in error messages.

    Raises:
        ValueError: On the first check that fails.
    """
    # Parsed JSON can be a list, string, number or null. Reject those here
    # so the caller gets a ValueError like every other failure, not an
    # AttributeError or TypeError from the lookups below.
    if not isinstance(data, dict):
        raise ValueError(f"Prompt JSON for {prompt_id} must be an object, got {type(data).__name__}.")

    missing = sorted(REQUIRED_T2I_KEYS - set(data))
    if missing:
        raise ValueError(f"Prompt JSON for {prompt_id} is missing required keys: {missing}")

    for list_key in ("subjects", "text_and_signage_elements"):
        if not isinstance(data.get(list_key), list):
            raise ValueError(f"Prompt JSON for {prompt_id}: {list_key} must be a list.")

    caption = data.get("comprehensive_t2i_caption")
    if not isinstance(caption, str) or not caption.strip():
        raise ValueError(f"Prompt JSON for {prompt_id}: comprehensive_t2i_caption is empty.")

    resolution = data.get("resolution")
    if not isinstance(resolution, dict) or not {"H", "W"}.issubset(resolution):
        raise ValueError(f"Prompt JSON for {prompt_id}: resolution must contain H and W.")

    aspect_ratio = data.get("aspect_ratio")
    if not isinstance(aspect_ratio, str) or not aspect_ratio.strip():
        raise ValueError(f"Prompt JSON for {prompt_id}: aspect_ratio must be a non-empty string.")