"""Tests for the agentic upsampling pipeline.

Every external service (chat LLM, image endpoint, judge) is replaced by an
in-memory fake defined in this module, so the tests need no network access.
"""

from __future__ import annotations

import base64
import io
import json
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from PIL import Image

from agentic_upsampling.clients import ImageGenerationClient, PromptRewriterClient
from agentic_upsampling.constants import (
    DEFAULT_CRITIC_ENDPOINT_URL,
    DEFAULT_CRITIC_MODEL,
    DEFAULT_FLOW_SHIFT,
    DEFAULT_GENERATION_EXTRA_ARGS,
    DEFAULT_GENERATION_MODEL,
    DEFAULT_LLM_EXTRA_BODY,
    DEFAULT_REWRITER_MODEL,
)
from agentic_upsampling.data import PromptItem, load_prompt_items, prompt_dir_name
from agentic_upsampling.extract_best import extract_best_images
from agentic_upsampling.prompt_upsampler import (
    Text2ImagePromptUpsampler,
    apply_t2i_output_parameters,
    normalize_openai_base_url,
)
from agentic_upsampling.rubric import parse_analysis_response
from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig


def _item(prompt_id: str = "1", prompt: str = "a red cube") -> PromptItem:
    """Build a ``PromptItem`` with ``row_number=0`` for use in tests."""
    return PromptItem(prompt_id=prompt_id, row_number=0, prompt=prompt)


def _valid_t2i_prompt(caption: str) -> dict[str, Any]:
    """Return a fully populated structured prompt whose caption is *caption*.

    All other fields carry fixed placeholder values; the canvas is 960x960
    with aspect ratio ``"1,1"``.
    """
    return {
        "subjects": [],
        "subject_details": {},
        "background_setting": "plain studio",
        "lighting": {
            "conditions": "soft",
            "direction": "front",
            "shadows": "soft",
            "illumination_effect": "clear",
        },
        "aesthetics": {
            "composition": "centered",
            "color_scheme": "balanced",
            "mood_atmosphere": "precise",
            "patterns": "",
        },
        "cinematography": {
            "framing": "centered",
            "camera_angle": "eye-level",
            "depth_of_field": "deep",
            "focus": "sharp",
            "lens_focal_length": "standard",
        },
        "style_medium": "digital render",
        "artistic_style": "clean realistic render",
        "context": "test prompt",
        "text_and_signage_elements": [],
        "quadrant_scan": {
            "top_left": "",
            "top_right": "",
            "bottom_left": "",
            "bottom_right": "",
            "absolute_center": "",
        },
        "comprehensive_t2i_caption": caption,
        "resolution": {"H": 960, "W": 960},
        "aspect_ratio": "1,1",
    }


class FakeChatClient:
    """Chat-client stand-in that records its last call and replies with canned JSON."""

    messages: list[dict[str, Any]]
    response_format_json: bool

    def __init__(self, response: dict[str, Any]) -> None:
        self.response = response
        self.messages = []
        self.response_format_json = False

    def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
        """Record the arguments of this call and return the canned response as JSON text."""
        self.messages = messages
        self.response_format_json = response_format_json
        return json.dumps(self.response)


def test_defaults_are_public_provider_defaults() -> None:
    """Pin the default rewriter/critic model names, extra body and critic endpoint."""
    assert DEFAULT_REWRITER_MODEL == "gpt-5.5"
    assert DEFAULT_LLM_EXTRA_BODY == {"reasoning_effort": "low"}
    assert DEFAULT_CRITIC_MODEL == "gemini-3.1-pro-preview"
    assert DEFAULT_CRITIC_ENDPOINT_URL == "https://generativelanguage.googleapis.com/v1beta/openai/"


def test_gemini_openai_compatible_base_url_is_not_modified() -> None:
    """The Gemini base URL keeps its path; only a trailing ``/`` or ``/chat/completions`` is dropped."""
    assert (
        normalize_openai_base_url("https://generativelanguage.googleapis.com/v1beta/openai/")
        == "https://generativelanguage.googleapis.com/v1beta/openai"
    )
    assert (
        normalize_openai_base_url("https://generativelanguage.googleapis.com/v1beta/openai/chat/completions")
        == "https://generativelanguage.googleapis.com/v1beta/openai"
    )


def test_prompt_loaders_support_text_jsonl_and_csv(tmp_path: Path) -> None:
    """``load_prompt_items`` reads prompts from ``.txt``, ``.jsonl`` and ``.csv`` files."""
    # Plain text: one prompt per line; the blank line yields no item.
    txt_path = tmp_path / "prompts.txt"
    txt_path.write_text("one\n\ntwo\n", encoding="utf-8")
    assert [item.prompt for item in load_prompt_items(prompts_path=txt_path)] == ["one", "two"]

    # JSONL: a row may be an object with id/prompt or a bare JSON string.
    jsonl_path = tmp_path / "prompts.jsonl"
    jsonl_path.write_text('{"id":"custom id","prompt":"three"}\n"four"\n', encoding="utf-8")
    jsonl_items = load_prompt_items(prompts_path=jsonl_path)
    assert [item.prompt for item in jsonl_items] == ["three", "four"]
    assert prompt_dir_name(jsonl_items[0]) == "custom_id"

    # CSV: header row with id and prompt columns.
    csv_path = tmp_path / "prompts.csv"
    csv_path.write_text("id,prompt\nfive_id,five\n", encoding="utf-8")
    csv_items = load_prompt_items(prompts_path=csv_path)
    assert csv_items[0].prompt_id == "five_id"
    assert csv_items[0].prompt == "five"


def test_prompt_upsampler_applies_resolution_and_requests_json() -> None:
    """The upsampler overrides the canvas fields and asks the chat client for JSON output."""
    prompt_json = _valid_t2i_prompt("initial cube prompt")
    fake_client = FakeChatClient(prompt_json)
    upsampler = Text2ImagePromptUpsampler(fake_client)  # type: ignore[arg-type]

    result = upsampler.upsample("a cube", prompt_id="cube", resolution="720", aspect_ratio="16,9")

    assert result["resolution"] == {"H": 720, "W": 1280}
    assert result["aspect_ratio"] == "16,9"
    assert fake_client.response_format_json is True


def test_apply_t2i_output_parameters_rejects_bad_canvas() -> None:
    """An unknown resolution raises ``ValueError`` mentioning "Unsupported resolution"."""
    try:
        apply_t2i_output_parameters(_valid_t2i_prompt("x"), resolution="999", aspect_ratio="1,1")
    except ValueError as exc:
        assert "Unsupported resolution" in str(exc)
    else:
        raise AssertionError("Expected unsupported resolution error.")


def test_prompt_rewriter_joint_rewrite_uses_vlm_feedback() -> None:
    """The rewrite request carries the judge's issues and directives but not ``raw_response``."""
    previous_prompt = _valid_t2i_prompt("old cube prompt")
    rewritten_prompt = _valid_t2i_prompt("new cube prompt with no 4x4 grid")
    analysis = {
        "overall_score": 2.0,
        "prompt_adherence_score": 3.0,
        "category_score": 3.0,
        "issues": [
            {
                "category": "geometry",
                "description": "Generated a 4x4 grid instead of a 3x3 cube.",
                "severity": "severe",
            }
        ],
        "improvement_directives": ["Strictly enforce 3x3x3 geometry."],
        "raw_response": "large omitted blob",
    }
    rewriter = PromptRewriterClient(api_token="unused")
    fake_client = FakeChatClient({"positive_prompt": rewritten_prompt, "negative_prompt": "4x4 grid"})
    rewriter.rewrite_client = fake_client  # type: ignore[assignment]

    positive_prompt, negative_prompt = rewriter.rewrite_prompt_pair(
        _item("39", "A Rubik's cube mid twist with the top layer rotated exactly 45 degrees"),
        previous_prompt,
        "",
        analysis,
        [{"iteration": 0, "analysis": analysis}],
    )

    assert positive_prompt["comprehensive_t2i_caption"] == "new cube prompt with no 4x4 grid"
    assert negative_prompt == "4x4 grid"
    assert fake_client.response_format_json is True
    # The second recorded message is inspected as the user turn.
    user_message = str(fake_client.messages[1]["content"])
    assert "Generated a 4x4 grid" in user_message
    assert "Strictly enforce 3x3x3 geometry" in user_message
    assert "raw_response" not in user_message


def test_generation_payload_uses_vllm_omni_images_api() -> None:
    """``build_payload`` emits the expected images-API fields and defaults."""
    client = ImageGenerationClient(endpoint="https://example.test/v1", model="test/model")

    payload = client.build_payload(
        {"comprehensive_t2i_caption": "x"},
        prompt_id="3",
        seed=100,
        negative_prompt="blur",
    )

    # The "/v1" suffix passed to the constructor is not kept on the stored endpoint.
    assert client.endpoint == "https://example.test"
    assert payload["model"] == "test/model"
    assert payload["prompt"] == '{"comprehensive_t2i_caption":"x"}'
    assert payload["size"] == "1024x1024"
    assert payload["n"] == 1
    assert payload["response_format"] == "b64_json"
    assert payload["negative_prompt"] == "blur"
    assert payload["num_inference_steps"] == 50
    assert payload["guidance_scale"] == 4.0
    assert payload["flow_shift"] == DEFAULT_FLOW_SHIFT
    assert payload["extra_args"] == DEFAULT_GENERATION_EXTRA_ARGS
    assert payload["seed"] == 100
    assert "model_mode" not in payload
    assert "prompt_upsampling" not in payload


def test_generation_payload_allows_custom_extra_args() -> None:
    """Caller-supplied ``extra_args`` replace the default ones in the payload."""
    client = ImageGenerationClient(endpoint="https://example.test", extra_args={"guardrails": True})
    payload = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3")
    assert payload["extra_args"] == {"guardrails": True}


class FakeImageResponse:
    """Successful HTTP-response stand-in exposing ``ok``, ``status_code``, ``text`` and ``json()``."""

    ok: bool = True
    status_code: int = 200
    text: str = "ok"

    def __init__(self, payload: dict[str, Any]) -> None:
        self.payload = payload

    def json(self) -> dict[str, Any]:
        """Return the payload given at construction."""
        return self.payload


class FakeImageSession:
    """HTTP-session stand-in that logs every request and always returns the same payload."""

    calls: list[dict[str, Any]]

    def __init__(self, response_payload: dict[str, Any]) -> None:
        self.response_payload = response_payload
        self.calls = []

    def request(self, method: str, url: str, **kwargs: Any) -> FakeImageResponse:
        """Append the call to ``calls`` and return a ``FakeImageResponse`` with the canned payload."""
        self.calls.append({"method": method, "url": url, "kwargs": kwargs})
        return FakeImageResponse(self.response_payload)


def _tiny_png_b64() -> str:
    """Return a 4x4 solid-green PNG encoded as an ASCII base64 string."""
    buf = io.BytesIO()
    Image.new("RGB", (4, 4), (0, 255, 0)).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("ascii")


def test_generation_client_decodes_vllm_omni_b64_response(tmp_path: Path) -> None:
    """``generate`` posts to the images endpoint, writes the image and records metadata."""
    session = FakeImageSession(
        {"created": 1, "data": [{"b64_json": _tiny_png_b64(), "revised_prompt": None}]}
    )
    client = ImageGenerationClient(
        endpoint="example.test",
        auth_key="secret-token",
        session=session,  # type: ignore[arg-type]
    )

    result = client.generate(
        prompt_json=_valid_t2i_prompt("x"),
        prompt_id="3",
        output_dir=tmp_path,
        seed=5,
    )

    assert result.image_path.exists()
    assert session.calls[0]["method"] == "POST"
    assert session.calls[0]["url"] == "https://example.test/v1/images/generations"
    assert session.calls[0]["kwargs"]["headers"] == {"Authorization": "Bearer secret-token"}
    assert session.calls[0]["kwargs"]["json"]["model"] == DEFAULT_GENERATION_MODEL
    meta = json.loads(result.meta_path.read_text(encoding="utf-8"))
    assert meta["status"] == "completed"
    # NOTE(review): the expected prefix literal was lost to extraction garbling
    # (the text from a "<" up to the next "->" appears to have been stripped as
    # if it were a markup tag). Only the leading "<" is assumed here, which is
    # weaker than the original check; restore the exact literal (and any other
    # assertions that followed it) from version control.
    assert meta["response"]["data"][0]["b64_json"].startswith("<")


# NOTE(review): this test's original name was lost in the same garbled span;
# the name below is a descriptive replacement, not the original.
def test_parse_analysis_response_sets_threshold_cleared() -> None:
    """A high-scoring judge reply parses to ``threshold_cleared is True``."""
    analysis = parse_analysis_response(
        """
        {
          "prompt_adherence_score": 9,
          "visual_quality_score": 9,
          "aesthetics_score": 8.5,
          "physical_plausibility_score": 8,
          "category_score": 9,
          "text_rendering_score": 9,
          "photorealism_score": null,
          "overall_score": 9.1,
          "issues": [],
          "category_findings": {},
          "improvement_directives": [],
          "rationale": "Strong."
        }
        """,
    )
    assert analysis["threshold_cleared"] is True


class FakeRewriter:
    """Rewriter stand-in that counts calls and records the scores it was asked to improve on."""

    initial_calls: int
    joint_rewrite_calls: int
    previous_scores: list[float]

    def __init__(self) -> None:
        self.initial_calls = 0
        self.joint_rewrite_calls = 0
        self.previous_scores = []

    def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
        """Count the call and return a valid prompt captioned ``"initial <prompt_id>"``."""
        self.initial_calls += 1
        return _valid_t2i_prompt(f"initial {item.prompt_id}")

    def rewrite_prompt_pair(
        self,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
    ) -> tuple[dict[str, Any], str]:
        """Count the call, record the previous overall score, and return a canned pair.

        Both the caption and the negative prompt embed ``len(history)``.
        """
        self.joint_rewrite_calls += 1
        self.previous_scores.append(float(previous_analysis["overall_score"]))
        return _valid_t2i_prompt(f"rewrite {len(history)}"), f"negative {len(history)}"


@dataclass(frozen=True, slots=True)
class FakeGeneration:
    """Result object returned by the fake generators."""

    image_path: Path
    meta_path: Path
    meta: dict[str, Any]


def _write_fake_generation(output_dir: Path) -> FakeGeneration:
    """Create *output_dir*, write an 8x8 red ``image.jpg`` plus a metadata file, and describe them."""
    output_dir.mkdir(parents=True, exist_ok=True)
    image_path = output_dir / "image.jpg"
    Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
    meta_path = output_dir / "generation_meta.json"
    meta_path.write_text('{"status":"completed"}\n', encoding="utf-8")
    return FakeGeneration(image_path=image_path, meta_path=meta_path, meta={"status": "completed"})


class FakeGenerator:
    """Generator stand-in that records each seed and negative prompt and writes a dummy image."""

    seeds: list[int | None]
    negative_prompts: list[str]

    def __init__(self) -> None:
        self.seeds = []
        self.negative_prompts = []

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        """Record ``seed`` and ``negative_prompt``, then write the dummy outputs into *output_dir*."""
        self.seeds.append(seed)
        self.negative_prompts.append(negative_prompt)
        return _write_fake_generation(output_dir)


class BarrierGenerator(FakeGenerator):
    """``FakeGenerator`` whose calls rendezvous on a barrier before producing output.

    ``generate`` waits (up to 2 seconds) until ``parties`` callers have arrived,
    so a caller that invokes it strictly one at a time breaks the barrier.
    """

    barrier: threading.Barrier
    lock: threading.Lock

    def __init__(self, parties: int) -> None:
        super().__init__()
        self.barrier = threading.Barrier(parties)
        self.lock = threading.Lock()

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        """Record the call under the lock, wait at the barrier, then write the dummy outputs."""
        with self.lock:
            self.seeds.append(seed)
            self.negative_prompts.append(negative_prompt)
        # Raises threading.BrokenBarrierError if the other parties do not arrive in time.
        self.barrier.wait(timeout=2.0)
        return _write_fake_generation(output_dir)


class FakeJudge:
    """Judge stand-in that hands out a preset list of scores, one per call, in order."""

    calls: int
    scores: list[float]

    def __init__(self, scores: list[float]) -> None:
        self.calls = 0
        self.scores = scores

    def score_image(
        self,
        *,
        item: PromptItem,
        image_path: Path,
    ) -> dict[str, Any]:
        """Return an analysis dict whose numeric fields all equal the next preset score.

        ``threshold_cleared`` is true when that score is at least 9.
        """
        score = self.scores[self.calls]
        self.calls += 1
        return {
            "overall_score": score,
            "prompt_adherence_score": score,
            "visual_quality_score": score,
            "aesthetics_score": score,
            "physical_plausibility_score": score,
            "category_score": score,
            "issues": [],
            "improvement_directives": [],
            "threshold_cleared": score >= 9,
        }


def test_runner_early_stops_by_default(tmp_path: Path) -> None:
    """With default settings a first score of 9.1 ends the run without any rewrite."""
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([9.1, 8.0]),
        config=RunnerConfig(output_dir=tmp_path, max_iterations=3, samples_per_iteration=1),
    )

    result = runner.run_item(_item())

    assert result["best_iteration"] == 0
    assert rewriter.initial_calls == 1
    assert rewriter.joint_rewrite_calls == 0
    assert generator.seeds == [None]


def test_runner_can_disable_early_stop_and_select_best_sample(tmp_path: Path) -> None:
    """With ``early_stop=False`` both iterations run and the top-scoring sample is selected."""
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([5.0, 9.0, 7.0, 6.0, 10.0, 8.0]),
        config=RunnerConfig(
            output_dir=tmp_path,
            max_iterations=2,
            samples_per_iteration=3,
            seed_base=1000,
            early_stop=False,
        ),
    )

    result = runner.run_item(_item("8", "exactly 12 balloons with exact color counts"))

    # The same three seeds are used in each iteration.
    assert generator.seeds == [1000, 1001, 1002, 1000, 1001, 1002]
    # The single rewrite was driven by iteration 0's best score.
    assert rewriter.previous_scores == [9.0]
    assert result["best_iteration"] == 1
    assert result["best"]["selected_sample_index"] == 1
    assert result["iterations"][0]["selected_sample_index"] == 1


def test_runner_generates_seed_samples_in_parallel(tmp_path: Path) -> None:
    """All three samples of an iteration reach the generator's barrier together."""
    rewriter = FakeRewriter()
    generator = BarrierGenerator(parties=3)
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([5.0, 6.0, 7.0]),
        config=RunnerConfig(
            output_dir=tmp_path,
            max_iterations=1,
            samples_per_iteration=3,
            seed_base=2000,
            early_stop=False,
        ),
    )

    result = runner.run_item(_item("parallel", "a parallel seed test"))

    # Arrival order at the generator is not fixed, hence the sort.
    assert sorted(generator.seeds) == [2000, 2001, 2002]
    assert result["best"]["selected_sample_index"] == 2
    assert result["iterations"][0]["sample_count"] == 3


def test_extract_best_images_copies_images_and_writes_manifests(tmp_path: Path) -> None:
    """``extract_best_images`` copies each best image and writes JSONL and CSV manifests."""
    output_dir = tmp_path / "run"
    image_dir = output_dir / "0001" / "iter_00"
    image_dir.mkdir(parents=True)
    image_path = image_dir / "image.jpg"
    Image.new("RGB", (8, 8), (255, 0, 0)).save(image_path)
    best_json = {
        "prompt_id": "1",
        "prompt": "a red square",
        "best_iteration": 0,
        "best_score": 9.25,
        "threshold_cleared_any": True,
        "best": {
            "selected_sample_index": 0,
            "image_path": str(image_path),
            "analysis_path": str(image_dir / "analysis.json"),
        },
        "iterations": [],
    }
    (output_dir / "0001" / "best.json").write_text(json.dumps(best_json), encoding="utf-8")

    records = extract_best_images(output_dir, tmp_path / "export")

    assert len(records) == 1
    copied_path = Path(records[0]["copied_image_path"])
    assert copied_path.exists()
    assert copied_path.name == "1.jpg"
    assert (tmp_path / "export" / "best_generations.jsonl").exists()
    assert (tmp_path / "export" / "best_generations.csv").exists()