Cosmos3-Super-Text2Image / tests /test_agentic_upsampling.py
mingyuliutw's picture
Super-squash branch 'main' using huggingface_hub
fdafd05
from __future__ import annotations
import base64
import io
import json
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from PIL import Image
from agentic_upsampling.clients import ImageGenerationClient, PromptRewriterClient
from agentic_upsampling.constants import (
DEFAULT_CRITIC_ENDPOINT_URL,
DEFAULT_CRITIC_MODEL,
DEFAULT_FLOW_SHIFT,
DEFAULT_GENERATION_EXTRA_ARGS,
DEFAULT_GENERATION_MODEL,
DEFAULT_LLM_EXTRA_BODY,
DEFAULT_REWRITER_MODEL,
)
from agentic_upsampling.data import PromptItem, load_prompt_items, prompt_dir_name
from agentic_upsampling.extract_best import extract_best_images
from agentic_upsampling.prompt_upsampler import (
Text2ImagePromptUpsampler,
apply_t2i_output_parameters,
normalize_openai_base_url,
)
from agentic_upsampling.rubric import parse_analysis_response
from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig
def _item(prompt_id: str = "1", prompt: str = "a red cube") -> PromptItem:
    """Return a ``PromptItem`` for tests; ``row_number`` is always 0."""
    item = PromptItem(prompt_id=prompt_id, row_number=0, prompt=prompt)
    return item
def _valid_t2i_prompt(caption: str) -> dict[str, Any]:
return {
"subjects": [],
"subject_details": {},
"background_setting": "plain studio",
"lighting": {"conditions": "soft", "direction": "front", "shadows": "soft", "illumination_effect": "clear"},
"aesthetics": {
"composition": "centered",
"color_scheme": "balanced",
"mood_atmosphere": "precise",
"patterns": "",
},
"cinematography": {
"framing": "centered",
"camera_angle": "eye-level",
"depth_of_field": "deep",
"focus": "sharp",
"lens_focal_length": "standard",
},
"style_medium": "digital render",
"artistic_style": "clean realistic render",
"context": "test prompt",
"text_and_signage_elements": [],
"quadrant_scan": {
"top_left": "",
"top_right": "",
"bottom_left": "",
"bottom_right": "",
"absolute_center": "",
},
"comprehensive_t2i_caption": caption,
"resolution": {"H": 960, "W": 960},
"aspect_ratio": "1,1",
}
class FakeChatClient:
    """In-memory stand-in for a chat-completion client.

    ``complete`` remembers the most recent ``messages`` list and JSON-mode
    flag. It always answers with the canned ``response`` serialised through
    ``json.dumps``.
    """

    messages: list[dict[str, Any]]
    response_format_json: bool

    def __init__(self, response: dict[str, Any]) -> None:
        self.response = response
        self.messages, self.response_format_json = [], False

    def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
        # Keep a reference to exactly what the caller sent so tests can inspect it.
        self.messages = messages
        self.response_format_json = response_format_json
        reply = json.dumps(self.response)
        return reply
def test_defaults_are_public_provider_defaults() -> None:
    """Pin the rewriter/critic defaults so that any change to them is deliberate."""
    pinned = [
        (DEFAULT_REWRITER_MODEL, "gpt-5.5"),
        (DEFAULT_LLM_EXTRA_BODY, {"reasoning_effort": "low"}),
        (DEFAULT_CRITIC_MODEL, "gemini-3.1-pro-preview"),
        (DEFAULT_CRITIC_ENDPOINT_URL, "https://generativelanguage.googleapis.com/v1beta/openai/"),
    ]
    for actual, expected in pinned:
        assert actual == expected
def test_gemini_openai_compatible_base_url_is_not_modified() -> None:
    """The Gemini ``/v1beta/openai`` base path survives normalisation.

    Despite the test name, ``normalize_openai_base_url`` does rewrite these
    inputs. It drops a trailing slash and a ``/chat/completions`` suffix, but
    the ``/v1beta/openai`` path itself is kept intact.
    """
    canonical = "https://generativelanguage.googleapis.com/v1beta/openai"
    for raw in (canonical + "/", canonical + "/chat/completions"):
        assert normalize_openai_base_url(raw) == canonical
def test_prompt_loaders_support_text_jsonl_and_csv(tmp_path: Path) -> None:
    """Plain-text, JSONL and CSV prompt files all load into ``PromptItem``s.

    - Blank lines in the text file are skipped.
    - JSONL accepts both objects carrying an ``id`` and bare JSON strings.
    - A CSV ``id`` column becomes the prompt id.
    """

    def write(name: str, body: str) -> Path:
        path = tmp_path / name
        path.write_text(body, encoding="utf-8")
        return path

    txt_items = load_prompt_items(prompts_path=write("prompts.txt", "one\n\ntwo\n"))
    assert [entry.prompt for entry in txt_items] == ["one", "two"]

    jsonl_items = load_prompt_items(
        prompts_path=write("prompts.jsonl", '{"id":"custom id","prompt":"three"}\n"four"\n')
    )
    assert [entry.prompt for entry in jsonl_items] == ["three", "four"]
    # The space in the custom id must not survive into the directory name.
    assert prompt_dir_name(jsonl_items[0]) == "custom_id"

    csv_items = load_prompt_items(prompts_path=write("prompts.csv", "id,prompt\nfive_id,five\n"))
    first_csv = csv_items[0]
    assert first_csv.prompt_id == "five_id"
    assert first_csv.prompt == "five"
def test_prompt_upsampler_applies_resolution_and_requests_json() -> None:
    """The upsampler requests JSON output and stamps the requested canvas on the result.

    The fake model answers with a 960x960 ``"1,1"`` prompt. The requested
    ``"720"`` / ``"16,9"`` settings must override it, giving 1280x720.
    """
    fake_client = FakeChatClient(_valid_t2i_prompt("initial cube prompt"))
    upsampler = Text2ImagePromptUpsampler(fake_client)  # type: ignore[arg-type]

    result = upsampler.upsample("a cube", prompt_id="cube", resolution="720", aspect_ratio="16,9")

    assert result["resolution"] == {"H": 720, "W": 1280}
    assert result["aspect_ratio"] == "16,9"
    assert fake_client.response_format_json is True
def test_apply_t2i_output_parameters_rejects_bad_canvas() -> None:
    """An unsupported resolution string ("999") raises ``ValueError`` naming the problem."""
    caught: ValueError | None = None
    try:
        apply_t2i_output_parameters(_valid_t2i_prompt("x"), resolution="999", aspect_ratio="1,1")
    except ValueError as exc:
        caught = exc
    if caught is None:
        raise AssertionError("Expected unsupported resolution error.")
    assert "Unsupported resolution" in str(caught)
def test_prompt_rewriter_joint_rewrite_uses_vlm_feedback() -> None:
    """The joint rewrite forwards critic feedback to the LLM but not the raw critic blob.

    The fake LLM reply supplies both the new positive prompt and the negative
    prompt. The user message sent to it must quote the critic's issue text and
    improvement directive, and must omit the ``raw_response`` field.
    """
    issue = {
        "category": "geometry",
        "description": "Generated a 4x4 grid instead of a 3x3 cube.",
        "severity": "severe",
    }
    analysis = {
        "overall_score": 2.0,
        "prompt_adherence_score": 3.0,
        "category_score": 3.0,
        "issues": [issue],
        "improvement_directives": ["Strictly enforce 3x3x3 geometry."],
        "raw_response": "large omitted blob",
    }
    previous_prompt = _valid_t2i_prompt("old cube prompt")
    rewritten_prompt = _valid_t2i_prompt("new cube prompt with no 4x4 grid")

    rewriter = PromptRewriterClient(api_token="unused")
    fake_client = FakeChatClient({"positive_prompt": rewritten_prompt, "negative_prompt": "4x4 grid"})
    # Route the rewrite through the fake client instead of a real endpoint.
    rewriter.rewrite_client = fake_client  # type: ignore[assignment]

    positive_prompt, negative_prompt = rewriter.rewrite_prompt_pair(
        _item("39", "A Rubik's cube mid twist with the top layer rotated exactly 45 degrees"),
        previous_prompt,
        "",
        analysis,
        [{"iteration": 0, "analysis": analysis}],
    )

    assert positive_prompt["comprehensive_t2i_caption"] == "new cube prompt with no 4x4 grid"
    assert negative_prompt == "4x4 grid"
    assert fake_client.response_format_json is True
    # messages[1] is treated as the user turn; index 0 is presumably a system message.
    user_message = str(fake_client.messages[1]["content"])
    for expected_fragment in ("Generated a 4x4 grid", "Strictly enforce 3x3x3 geometry"):
        assert expected_fragment in user_message
    assert "raw_response" not in user_message
def test_generation_payload_uses_vllm_omni_images_api() -> None:
    """``build_payload`` emits an images-generation request body with the expected fields.

    This test also checks that:
    - a trailing ``/v1`` is stripped from the endpoint;
    - the prompt JSON is serialised compactly;
    - the ``model_mode`` and ``prompt_upsampling`` keys are absent.
    """
    client = ImageGenerationClient(endpoint="https://example.test/v1", model="test/model")
    payload = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3", seed=100, negative_prompt="blur")

    assert client.endpoint == "https://example.test"
    expected_fields = {
        "model": "test/model",
        "prompt": '{"comprehensive_t2i_caption":"x"}',
        "size": "1024x1024",
        "n": 1,
        "response_format": "b64_json",
        "negative_prompt": "blur",
        "num_inference_steps": 50,
        "guidance_scale": 4.0,
        "flow_shift": DEFAULT_FLOW_SHIFT,
        "extra_args": DEFAULT_GENERATION_EXTRA_ARGS,
        "seed": 100,
    }
    for key, value in expected_fields.items():
        assert payload[key] == value, key
    for absent_key in ("model_mode", "prompt_upsampling"):
        assert absent_key not in payload
def test_generation_payload_allows_custom_extra_args() -> None:
    """``extra_args`` passed to the client replace the default extra args in the payload."""
    custom_extra_args = {"guardrails": True}
    client = ImageGenerationClient(endpoint="https://example.test", extra_args=custom_extra_args)
    built = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3")
    assert built["extra_args"] == {"guardrails": True}
class FakeImageResponse:
    """Canned response object that always reports success and returns a fixed payload.

    The attribute names (``ok``, ``status_code``, ``text``, ``json()``) presumably
    mirror what the generation client reads from a real HTTP response. TODO confirm.
    """

    # Class-level defaults shared by every instance: a successful response.
    ok: bool = True
    status_code: int = 200
    text: str = "ok"

    def __init__(self, payload: dict[str, Any]) -> None:
        self.payload = payload

    def json(self) -> dict[str, Any]:
        """Return the stored payload object itself (not a copy)."""
        body = self.payload
        return body
class FakeImageSession:
    """Session double whose ``request`` logs every call and returns a canned response."""

    calls: list[dict[str, Any]]

    def __init__(self, response_payload: dict[str, Any]) -> None:
        self.response_payload = response_payload
        self.calls = []

    def request(self, method: str, url: str, **kwargs: Any) -> FakeImageResponse:
        record = dict(method=method, url=url, kwargs=kwargs)
        self.calls.append(record)
        # A fresh response object per call; all of them share the same payload dict.
        return FakeImageResponse(self.response_payload)
def _tiny_png_b64() -> str:
    """Return a 4x4 solid-green PNG encoded as an ASCII base64 string."""
    with io.BytesIO() as buffer:
        Image.new("RGB", (4, 4), (0, 255, 0)).save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    return base64.b64encode(png_bytes).decode("ascii")
def test_generation_client_decodes_vllm_omni_b64_response(tmp_path: Path) -> None:
    """``generate`` POSTs to the images endpoint, saves the image and redacts base64 in meta.

    The bare ``example.test`` endpoint must be expanded to
    ``https://example.test/v1/images/generations``. The auth key is sent as a
    bearer token, and the default generation model is used.
    """
    png_entry = {"b64_json": _tiny_png_b64(), "revised_prompt": None}
    session = FakeImageSession({"created": 1, "data": [png_entry]})
    client = ImageGenerationClient(endpoint="example.test", auth_key="secret-token", session=session)  # type: ignore[arg-type]

    result = client.generate(prompt_json=_valid_t2i_prompt("x"), prompt_id="3", output_dir=tmp_path, seed=5)

    assert result.image_path.exists()
    first_call = session.calls[0]
    assert first_call["method"] == "POST"
    assert first_call["url"] == "https://example.test/v1/images/generations"
    sent = first_call["kwargs"]
    assert sent["headers"] == {"Authorization": "Bearer secret-token"}
    assert sent["json"]["model"] == DEFAULT_GENERATION_MODEL

    meta = json.loads(result.meta_path.read_text(encoding="utf-8"))
    assert meta["status"] == "completed"
    # The persisted response must carry a placeholder instead of the raw base64 image.
    assert meta["response"]["data"][0]["b64_json"].startswith("<base64 image omitted:")
# Raw critic reply with uniformly high scores (overall 9.1); reused verbatim as parser input.
_STRONG_ANALYSIS_RESPONSE = """
{
"prompt_adherence_score": 9,
"visual_quality_score": 9,
"aesthetics_score": 8.5,
"physical_plausibility_score": 8,
"category_score": 9,
"text_rendering_score": 9,
"photorealism_score": null,
"overall_score": 9.1,
"issues": [],
"category_findings": {},
"improvement_directives": [],
"rationale": "Strong."
}
"""


def test_parse_analysis_response_sets_threshold_flag() -> None:
    """A strong critic reply is parsed with ``threshold_cleared`` set to ``True``."""
    parsed = parse_analysis_response(_STRONG_ANALYSIS_RESPONSE)
    assert parsed["threshold_cleared"] is True
class FakeRewriter:
    """Rewriter double that counts calls and records the scores it was asked to improve on.

    Returned prompts are labelled so tests can tell initial prompts from
    rewrites. Rewrites are tagged with the length of the history they received.
    """

    initial_calls: int
    joint_rewrite_calls: int
    previous_scores: list[float]

    def __init__(self) -> None:
        self.initial_calls = self.joint_rewrite_calls = 0
        self.previous_scores = []

    def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
        """Return a valid prompt captioned ``initial <prompt_id>``."""
        self.initial_calls += 1
        caption = f"initial {item.prompt_id}"
        return _valid_t2i_prompt(caption)

    def rewrite_prompt_pair(
        self,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
    ) -> tuple[dict[str, Any], str]:
        """Return ``(prompt, negative_prompt)`` tagged with ``len(history)``.

        Only ``previous_analysis["overall_score"]`` and ``history`` are read;
        the other arguments exist to match the real rewriter's call signature.
        """
        self.joint_rewrite_calls += 1
        self.previous_scores.append(float(previous_analysis["overall_score"]))
        rounds_so_far = len(history)
        return _valid_t2i_prompt(f"rewrite {rounds_so_far}"), f"negative {rounds_so_far}"
@dataclass(frozen=True, slots=True)
class FakeGeneration:
image_path: Path
meta_path: Path
meta: dict[str, Any]
class FakeGenerator:
    """Image-generator double that records its inputs and writes placeholder files.

    Each call appends the ``seed`` and ``negative_prompt`` it received. It then
    writes an 8x8 solid-red image to ``<output_dir>/image.jpg`` and a
    ``{"status":"completed"}`` stub to ``<output_dir>/generation_meta.json``.
    ``prompt_json``, ``prompt_id`` and ``jpeg_quality`` are accepted but unused.
    """

    seeds: list[int | None]
    negative_prompts: list[str]

    def __init__(self) -> None:
        self.seeds = []
        self.negative_prompts = []

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        self.seeds.append(seed)
        self.negative_prompts.append(negative_prompt)

        output_dir.mkdir(parents=True, exist_ok=True)
        image_path, meta_path = output_dir / "image.jpg", output_dir / "generation_meta.json"
        Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
        meta_path.write_text('{"status":"completed"}\n', encoding="utf-8")
        return FakeGeneration(image_path=image_path, meta_path=meta_path, meta={"status": "completed"})
class BarrierGenerator(FakeGenerator):
    """``FakeGenerator`` variant that only proceeds once ``parties`` calls are in flight.

    Every ``generate`` call waits on a shared ``threading.Barrier`` with a
    2-second timeout. A caller that does not issue ``parties`` concurrent calls
    therefore gets ``threading.BrokenBarrierError`` instead of a result.
    """

    barrier: threading.Barrier
    lock: threading.Lock

    def __init__(self, parties: int) -> None:
        super().__init__()
        self.barrier = threading.Barrier(parties)
        self.lock = threading.Lock()

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        # Append both values under one lock so seeds[i] and negative_prompts[i]
        # always come from the same call.
        with self.lock:
            self.seeds.append(seed)
            self.negative_prompts.append(negative_prompt)
        # Wait outside the lock: holding it here would stop the other parties
        # from ever reaching the barrier.
        self.barrier.wait(timeout=2.0)

        output_dir.mkdir(parents=True, exist_ok=True)
        image_path = output_dir / "image.jpg"
        meta_path = output_dir / "generation_meta.json"
        Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
        meta_path.write_text('{"status":"completed"}\n', encoding="utf-8")
        return FakeGeneration(image_path=image_path, meta_path=meta_path, meta={"status": "completed"})
class FakeJudge:
    """Critic double that hands out pre-scripted scores in call order.

    The n-th ``score_image`` call uses ``scores[n]`` for every numeric score
    field. ``threshold_cleared`` is true when that score is at least 9. Calling
    it more times than there are scores raises ``IndexError``.
    """

    calls: int
    scores: list[float]

    def __init__(self, scores: list[float]) -> None:
        self.calls = 0
        self.scores = scores

    def score_image(
        self,
        *,
        item: PromptItem,
        image_path: Path,
    ) -> dict[str, Any]:
        current = self.scores[self.calls]
        self.calls += 1
        score_fields = (
            "overall_score",
            "prompt_adherence_score",
            "visual_quality_score",
            "aesthetics_score",
            "physical_plausibility_score",
            "category_score",
        )
        analysis: dict[str, Any] = dict.fromkeys(score_fields, current)
        analysis["issues"] = []
        analysis["improvement_directives"] = []
        analysis["threshold_cleared"] = current >= 9
        return analysis
def test_runner_early_stops_by_default(tmp_path: Path) -> None:
    """A first-iteration score of 9.1 ends the run with no rewrite.

    The scripted second score (8.0) should never be needed. The single sample
    is generated with seed ``None`` when no ``seed_base`` is given.
    """
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    config = RunnerConfig(output_dir=tmp_path, max_iterations=3, samples_per_iteration=1)
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([9.1, 8.0]),
        config=config,
    )

    outcome = runner.run_item(_item())

    assert outcome["best_iteration"] == 0
    assert rewriter.initial_calls == 1
    assert rewriter.joint_rewrite_calls == 0
    assert generator.seeds == [None]
def test_runner_can_disable_early_stop_and_select_best_sample(tmp_path: Path) -> None:
    """With early stopping off, every iteration runs and the best sample overall wins.

    Judge scores are consumed in call order. Iteration 0 receives
    5.0/9.0/7.0 and iteration 1 receives 6.0/10.0/8.0, so the run's best is
    sample 1 of iteration 1. The rewrite is driven by iteration 0's best
    score, 9.0.
    """
    samples = 3
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([5.0, 9.0, 7.0, 6.0, 10.0, 8.0]),
        config=RunnerConfig(
            output_dir=tmp_path,
            max_iterations=2,
            samples_per_iteration=samples,
            seed_base=1000,
            early_stop=False,
        ),
    )
    result = runner.run_item(_item("8", "exactly 12 balloons with exact color counts"))

    # Samples within one iteration can be generated concurrently
    # (test_runner_generates_seed_samples_in_parallel depends on it), so the
    # order in which seeds are appended inside an iteration is
    # scheduling-dependent. Compare each iteration's seeds as a sorted chunk.
    assert len(generator.seeds) == 2 * samples
    seeds_per_iteration = [
        sorted(generator.seeds[start : start + samples])
        for start in range(0, len(generator.seeds), samples)
    ]
    assert seeds_per_iteration == [[1000, 1001, 1002], [1000, 1001, 1002]]
    assert rewriter.previous_scores == [9.0]
    assert result["best_iteration"] == 1
    assert result["best"]["selected_sample_index"] == 1
    assert result["iterations"][0]["selected_sample_index"] == 1
def test_runner_generates_seed_samples_in_parallel(tmp_path: Path) -> None:
    """The three samples of an iteration must be generated concurrently.

    ``BarrierGenerator(parties=3)`` only lets ``generate`` return once three
    calls are waiting at the same time; otherwise the barrier breaks after its
    timeout. Seeds are compared sorted because arrival order is
    scheduling-dependent. The highest scripted score (7.0, third call) should
    select sample 2.
    """
    sample_count = 3
    generator = BarrierGenerator(parties=sample_count)
    runner = AgenticUpsamplerRunner(
        rewriter=FakeRewriter(),
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([5.0, 6.0, 7.0]),
        config=RunnerConfig(
            output_dir=tmp_path,
            max_iterations=1,
            samples_per_iteration=sample_count,
            seed_base=2000,
            early_stop=False,
        ),
    )

    result = runner.run_item(_item("parallel", "a parallel seed test"))

    assert sorted(generator.seeds) == [2000, 2001, 2002]
    assert result["best"]["selected_sample_index"] == 2
    assert result["iterations"][0]["sample_count"] == 3
def test_extract_best_images_copies_images_and_writes_manifests(tmp_path: Path) -> None:
    """``extract_best_images`` copies each prompt's best image and writes JSONL/CSV manifests.

    A single ``best.json`` is planted under ``run/0001``. The copied image must
    be named after the prompt id (``1.jpg``).
    """
    run_dir = tmp_path / "run"
    export_dir = tmp_path / "export"
    prompt_dir = run_dir / "0001"
    iteration_dir = prompt_dir / "iter_00"
    iteration_dir.mkdir(parents=True)
    image_path = iteration_dir / "image.jpg"
    Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)

    best_entry = dict(
        selected_sample_index=0,
        image_path=str(image_path),
        analysis_path=str(iteration_dir / "analysis.json"),
    )
    best_json = {
        "prompt_id": "1",
        "prompt": "a red square",
        "best_iteration": 0,
        "best_score": 9.25,
        "threshold_cleared_any": True,
        "best": best_entry,
        "iterations": [],
    }
    (prompt_dir / "best.json").write_text(json.dumps(best_json), encoding="utf-8")

    records = extract_best_images(run_dir, export_dir)

    assert len(records) == 1
    copied_path = Path(records[0]["copied_image_path"])
    assert copied_path.exists()
    assert copied_path.name == "1.jpg"
    for manifest_name in ("best_generations.jsonl", "best_generations.csv"):
        assert (export_dir / manifest_name).exists()