"""CLI for standalone agentic Cosmos3 text-to-image prompt upsampling."""

from __future__ import annotations

import argparse
import json
from pathlib import Path

from agentic_upsampling.clients import (
    ImageGenerationClient,
    PromptRewriterClient,
    VLMQualityJudge,
    read_api_token,
    read_optional_generation_auth_key,
)
from agentic_upsampling.constants import (
    DEFAULT_ASPECT_RATIO,
    DEFAULT_CRITIC_ENDPOINT_URL,
    DEFAULT_CRITIC_MODEL,
    DEFAULT_FLOW_SHIFT,
    DEFAULT_GENERATION_AUTH_KEY_ENV,
    DEFAULT_GENERATION_EXTRA_ARGS,
    DEFAULT_GENERATION_MODEL,
    DEFAULT_GEMINI_API_KEY_ENV,
    DEFAULT_GUIDANCE,
    DEFAULT_IMAGE_SIZE,
    DEFAULT_LLM_EXTRA_BODY,
    DEFAULT_MAX_ITERATIONS,
    DEFAULT_NUM_STEPS,
    DEFAULT_OPENAI_API_KEY_ENV,
    DEFAULT_RESOLUTION,
    DEFAULT_REWRITER_ENDPOINT_URL,
    DEFAULT_REWRITER_MODEL,
    DEFAULT_SAMPLES_PER_ITERATION,
    DEFAULT_UPSAMPLER_ENDPOINT_URL,
    DEFAULT_UPSAMPLER_MODEL,
)
from agentic_upsampling.data import load_prompt_items
from agentic_upsampling.extract_best import extract_best_images
from agentic_upsampling.io_utils import write_json_atomic
from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig, write_run_manifest


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for an upsampling run.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated namespace. Exactly one of ``--prompt`` / ``--prompts`` is
        required, as are ``--output-dir`` and ``--generation-endpoint``.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # Prompt source: a single inline prompt or a file of prompts, never both.
    input_group = parser.add_mutually_exclusive_group(required=True)
    input_group.add_argument("--prompt", default=None, help="Single text prompt to run.")
    input_group.add_argument(
        "--prompts", type=Path, default=None, help="Path to .txt, .jsonl, or .csv prompts."
    )

    # Run control.
    parser.add_argument("--limit", type=int, default=None, help="Optional maximum number of prompts to run.")
    parser.add_argument("--output-dir", type=Path, required=True)
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--max-iterations", type=int, default=DEFAULT_MAX_ITERATIONS)
    parser.add_argument("--samples-per-iteration", type=int, default=DEFAULT_SAMPLES_PER_ITERATION)
    parser.add_argument("--seed-base", type=int, default=None)
    parser.add_argument("--disable-early-stop", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--extract-best", action="store_true", help="Copy best images after the run finishes.")

    # Image generation backend.
    parser.add_argument("--generation-endpoint", required=True)
    parser.add_argument("--generation-model", default=DEFAULT_GENERATION_MODEL)
    parser.add_argument("--size", default=DEFAULT_IMAGE_SIZE, help="vLLM-Omni image size in WIDTHxHEIGHT format.")
    parser.add_argument("--generation-auth-key", default="")
    parser.add_argument("--generation-auth-key-env", default=DEFAULT_GENERATION_AUTH_KEY_ENV)
    parser.add_argument("--resolution", default=DEFAULT_RESOLUTION)
    parser.add_argument("--aspect-ratio", default=DEFAULT_ASPECT_RATIO)
    parser.add_argument("--num-steps", type=int, default=DEFAULT_NUM_STEPS)
    parser.add_argument("--guidance", type=float, default=DEFAULT_GUIDANCE)
    parser.add_argument("--flow-shift", type=float, default=DEFAULT_FLOW_SHIFT)
    # JSON-valued option: the command-line string is decoded with json.loads.
    parser.add_argument("--generation-extra-args", type=json.loads, default=DEFAULT_GENERATION_EXTRA_ARGS)

    # Prompt upsampler / rewriter LLM.
    parser.add_argument("--upsampler-endpoint-url", default=DEFAULT_UPSAMPLER_ENDPOINT_URL)
    parser.add_argument("--upsampler-model", default=DEFAULT_UPSAMPLER_MODEL)
    parser.add_argument("--rewriter-endpoint-url", default=DEFAULT_REWRITER_ENDPOINT_URL)
    parser.add_argument("--rewriter-model", default=DEFAULT_REWRITER_MODEL)
    parser.add_argument("--openai-api-key-env", default=DEFAULT_OPENAI_API_KEY_ENV)
    parser.add_argument("--openai-api-key-file", type=Path, default=None)
    parser.add_argument("--llm-extra-body", type=json.loads, default=DEFAULT_LLM_EXTRA_BODY)
    parser.add_argument("--initial-negative-prompt", default="")

    # Critic (quality judge) model.
    parser.add_argument("--critic-endpoint-url", default=DEFAULT_CRITIC_ENDPOINT_URL)
    parser.add_argument("--critic-model", default=DEFAULT_CRITIC_MODEL)
    parser.add_argument("--gemini-api-key-env", default=DEFAULT_GEMINI_API_KEY_ENV)
    parser.add_argument("--gemini-api-key-file", type=Path, default=None)

    return parser.parse_args(argv)


def _validate_args(args: argparse.Namespace) -> None:
    """Reject option values that argparse's type conversion alone cannot rule out.

    Raises:
        ValueError: If ``--samples-per-iteration`` is below 1, or if a
            JSON-valued option did not decode to a JSON object.
    """
    if args.samples_per_iteration < 1:
        raise ValueError("--samples-per-iteration must be >= 1.")
    if not isinstance(args.generation_extra_args, dict):
        raise ValueError("--generation-extra-args must decode to a JSON object.")
    # Same shape requirement as --generation-extra-args. None is tolerated so a
    # project default of None (no extra body) keeps working.
    if args.llm_extra_body is not None and not isinstance(args.llm_extra_body, dict):
        raise ValueError("--llm-extra-body must decode to a JSON object.")


def _build_run_config(args: argparse.Namespace, selected_prompts: int) -> dict:
    """Return the settings snapshot that is written to ``run_config.json``.

    API tokens and the generation auth key are deliberately not part of the
    snapshot.
    """
    return {
        "selected_prompts": selected_prompts,
        "max_iterations": args.max_iterations,
        "samples_per_iteration": args.samples_per_iteration,
        "early_stop": not args.disable_early_stop,
        "generation_endpoint": args.generation_endpoint,
        "generation_model": args.generation_model,
        "size": args.size,
        "resolution": args.resolution,
        "aspect_ratio": args.aspect_ratio,
        "num_steps": args.num_steps,
        "guidance": args.guidance,
        "flow_shift": args.flow_shift,
        "generation_extra_args": args.generation_extra_args,
        "upsampler_endpoint_url": args.upsampler_endpoint_url,
        "upsampler_model": args.upsampler_model,
        "rewriter_endpoint_url": args.rewriter_endpoint_url,
        "rewriter_model": args.rewriter_model,
        "llm_extra_body": args.llm_extra_body,
        "critic_endpoint_url": args.critic_endpoint_url,
        "critic_model": args.critic_model,
        "initial_negative_prompt": args.initial_negative_prompt,
    }


def main(argv: list[str] | None = None) -> int:
    """Run the agentic upsampling loop over the selected prompts.

    Writes ``run_config.json``, the run manifest and ``summary.json`` into the
    output directory, prints the summary as JSON, and optionally exports the
    best images to ``best_generations`` when no prompt failed.

    Args:
        argv: Argument strings to parse; ``None`` reads ``sys.argv[1:]``.

    Returns:
        ``1`` if any prompt result carries an ``"error"`` entry, otherwise ``0``.

    Raises:
        ValueError: If an option value fails validation.
        RuntimeError: If no prompts were selected.
    """
    args = parse_args(argv)

    # Validate everything before touching the filesystem so that a bad
    # invocation does not leave an empty output directory behind.
    _validate_args(args)
    items = load_prompt_items(prompt=args.prompt, prompts_path=args.prompts, limit=args.limit)
    if not items:
        raise RuntimeError("No prompts selected.")

    openai_token = read_api_token(args.openai_api_key_env, args.openai_api_key_file)
    gemini_token = read_api_token(args.gemini_api_key_env, args.gemini_api_key_file)
    generation_auth_key = read_optional_generation_auth_key(
        args.generation_auth_key, args.generation_auth_key_env
    )

    args.output_dir.mkdir(parents=True, exist_ok=True)
    write_json_atomic(args.output_dir / "run_config.json", _build_run_config(args, len(items)))

    rewriter = PromptRewriterClient(
        api_token=openai_token,
        upsampler_endpoint_url=args.upsampler_endpoint_url,
        upsampler_model=args.upsampler_model,
        rewriter_endpoint_url=args.rewriter_endpoint_url,
        rewriter_model=args.rewriter_model,
        extra_body=args.llm_extra_body,
        resolution=args.resolution,
        aspect_ratio=args.aspect_ratio,
    )
    generator = ImageGenerationClient(
        endpoint=args.generation_endpoint,
        auth_key=generation_auth_key,
        model=args.generation_model,
        size=args.size,
        num_steps=args.num_steps,
        guidance=args.guidance,
        flow_shift=args.flow_shift,
        extra_args=args.generation_extra_args,
    )
    judge = VLMQualityJudge(
        api_token=gemini_token,
        endpoint_url=args.critic_endpoint_url,
        model=args.critic_model,
    )
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,
        judge=judge,
        config=RunnerConfig(
            output_dir=args.output_dir,
            max_iterations=args.max_iterations,
            samples_per_iteration=args.samples_per_iteration,
            overwrite=args.overwrite,
            seed_base=args.seed_base,
            initial_negative_prompt=args.initial_negative_prompt,
            early_stop=not args.disable_early_stop,
            verbose=not args.quiet,
        ),
    )

    results = [runner.run_item_safely(item) for item in items]
    write_run_manifest(args.output_dir, results)

    # A result counts as failed when it has a truthy "error" entry.
    failures = sum(1 for result in results if result.get("error"))
    summary = {
        "selected_prompts": len(items),
        "completed": len(items) - failures,
        "failures": failures,
    }
    write_json_atomic(args.output_dir / "summary.json", summary)
    print(json.dumps(summary, indent=2), flush=True)

    # The export is skipped as soon as any prompt failed.
    if args.extract_best and not failures:
        export_dir = args.output_dir / "best_generations"
        extract_best_images(args.output_dir, export_dir, overwrite=args.overwrite)
        print(f"Exported best images to {export_dir}", flush=True)

    return 1 if failures else 0


if __name__ == "__main__":
    raise SystemExit(main())