Instructions to use nvidia/Cosmos3-Super-Text2Image with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Cosmos
How to use nvidia/Cosmos3-Super-Text2Image with Cosmos:
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
- Diffusers
How to use nvidia/Cosmos3-Super-Text2Image with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("nvidia/Cosmos3-Super-Text2Image", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| import base64 | |
| import io | |
| import json | |
| import threading | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| from PIL import Image | |
| from agentic_upsampling.clients import ImageGenerationClient, PromptRewriterClient | |
| from agentic_upsampling.constants import ( | |
| DEFAULT_CRITIC_ENDPOINT_URL, | |
| DEFAULT_CRITIC_MODEL, | |
| DEFAULT_FLOW_SHIFT, | |
| DEFAULT_GENERATION_EXTRA_ARGS, | |
| DEFAULT_GENERATION_MODEL, | |
| DEFAULT_LLM_EXTRA_BODY, | |
| DEFAULT_REWRITER_MODEL, | |
| ) | |
| from agentic_upsampling.data import PromptItem, load_prompt_items, prompt_dir_name | |
| from agentic_upsampling.extract_best import extract_best_images | |
| from agentic_upsampling.prompt_upsampler import ( | |
| Text2ImagePromptUpsampler, | |
| apply_t2i_output_parameters, | |
| normalize_openai_base_url, | |
| ) | |
| from agentic_upsampling.rubric import parse_analysis_response | |
| from agentic_upsampling.runner import AgenticUpsamplerRunner, RunnerConfig | |
def _item(prompt_id: str = "1", prompt: str = "a red cube") -> PromptItem:
    """Build a ``PromptItem`` fixture with the given id and prompt; the row number is always 0."""
    item_fields = {"prompt_id": prompt_id, "row_number": 0, "prompt": prompt}
    return PromptItem(**item_fields)
| def _valid_t2i_prompt(caption: str) -> dict[str, Any]: | |
| return { | |
| "subjects": [], | |
| "subject_details": {}, | |
| "background_setting": "plain studio", | |
| "lighting": {"conditions": "soft", "direction": "front", "shadows": "soft", "illumination_effect": "clear"}, | |
| "aesthetics": { | |
| "composition": "centered", | |
| "color_scheme": "balanced", | |
| "mood_atmosphere": "precise", | |
| "patterns": "", | |
| }, | |
| "cinematography": { | |
| "framing": "centered", | |
| "camera_angle": "eye-level", | |
| "depth_of_field": "deep", | |
| "focus": "sharp", | |
| "lens_focal_length": "standard", | |
| }, | |
| "style_medium": "digital render", | |
| "artistic_style": "clean realistic render", | |
| "context": "test prompt", | |
| "text_and_signage_elements": [], | |
| "quadrant_scan": { | |
| "top_left": "", | |
| "top_right": "", | |
| "bottom_left": "", | |
| "bottom_right": "", | |
| "absolute_center": "", | |
| }, | |
| "comprehensive_t2i_caption": caption, | |
| "resolution": {"H": 960, "W": 960}, | |
| "aspect_ratio": "1,1", | |
| } | |
class FakeChatClient:
    """Chat-client stand-in that records its most recent call and replies with canned JSON."""

    # Messages passed to the most recent ``complete`` call.
    messages: list[dict[str, Any]]
    # Value of ``response_format_json`` on the most recent ``complete`` call.
    response_format_json: bool

    def __init__(self, response: dict[str, Any]) -> None:
        """Store the object that every ``complete`` call will serialize and return."""
        self.response = response
        self.messages = []
        self.response_format_json = False

    def complete(self, messages: list[dict[str, Any]], *, response_format_json: bool = False) -> str:
        """Remember the call arguments and return the canned response as a JSON string."""
        self.messages, self.response_format_json = messages, response_format_json
        canned_reply = json.dumps(self.response)
        return canned_reply
def test_defaults_are_public_provider_defaults() -> None:
    """Pin the default rewriter model, LLM extra body, critic model, and critic endpoint URL."""
    rewriter_defaults = (DEFAULT_REWRITER_MODEL, DEFAULT_LLM_EXTRA_BODY)
    assert rewriter_defaults == ("gpt-5.5", {"reasoning_effort": "low"})

    critic_defaults = (DEFAULT_CRITIC_MODEL, DEFAULT_CRITIC_ENDPOINT_URL)
    assert critic_defaults == (
        "gemini-3.1-pro-preview",
        "https://generativelanguage.googleapis.com/v1beta/openai/",
    )
def test_gemini_openai_compatible_base_url_is_not_modified() -> None:
    """Both spellings of the Gemini OpenAI-compatible URL normalize to the same base URL.

    The inputs are the base URL with a trailing slash and the full
    ``/chat/completions`` URL; each must come back as the bare base URL.
    """
    expected_base = "https://generativelanguage.googleapis.com/v1beta/openai"
    raw_urls = (
        "https://generativelanguage.googleapis.com/v1beta/openai/",
        "https://generativelanguage.googleapis.com/v1beta/openai/chat/completions",
    )
    for raw_url in raw_urls:
        assert normalize_openai_base_url(raw_url) == expected_base
def test_prompt_loaders_support_text_jsonl_and_csv(tmp_path: Path) -> None:
    """Load prompts from ``.txt``, ``.jsonl``, and ``.csv`` files and check the parsed items."""
    # Plain text: the fixture has an empty line between the two prompts.
    text_file = tmp_path / "prompts.txt"
    text_file.write_text("one\n\ntwo\n", encoding="utf-8")
    text_prompts = [loaded.prompt for loaded in load_prompt_items(prompts_path=text_file)]
    assert text_prompts == ["one", "two"]

    # JSONL: one object row carrying an explicit id, then one bare JSON string row.
    jsonl_file = tmp_path / "prompts.jsonl"
    jsonl_file.write_text('{"id":"custom id","prompt":"three"}\n"four"\n', encoding="utf-8")
    jsonl_items = load_prompt_items(prompts_path=jsonl_file)
    assert [loaded.prompt for loaded in jsonl_items] == ["three", "four"]
    assert prompt_dir_name(jsonl_items[0]) == "custom_id"

    # CSV: a header row followed by a single data row.
    csv_file = tmp_path / "prompts.csv"
    csv_file.write_text("id,prompt\nfive_id,five\n", encoding="utf-8")
    first_csv_item = load_prompt_items(prompts_path=csv_file)[0]
    assert first_csv_item.prompt_id == "five_id"
    assert first_csv_item.prompt == "five"
def test_prompt_upsampler_applies_resolution_and_requests_json() -> None:
    """Upsampling overrides the canvas fields of the model reply and asks the client for JSON output."""
    chat = FakeChatClient(_valid_t2i_prompt("initial cube prompt"))
    upsampler = Text2ImagePromptUpsampler(chat)  # type: ignore[arg-type]

    upsampled = upsampler.upsample("a cube", prompt_id="cube", resolution="720", aspect_ratio="16,9")

    # The fixture reply says 960x960 / "1,1"; the requested canvas must win.
    assert upsampled["resolution"] == {"H": 720, "W": 1280}
    assert upsampled["aspect_ratio"] == "16,9"
    assert chat.response_format_json is True
def test_apply_t2i_output_parameters_rejects_bad_canvas() -> None:
    """A resolution of "999" must raise ``ValueError`` whose message mentions "Unsupported resolution"."""
    caught: ValueError | None = None
    try:
        apply_t2i_output_parameters(_valid_t2i_prompt("x"), resolution="999", aspect_ratio="1,1")
    except ValueError as exc:
        caught = exc

    # No exception at all is a test failure, not a pass.
    if caught is None:
        raise AssertionError("Expected unsupported resolution error.")
    assert "Unsupported resolution" in str(caught)
def test_prompt_rewriter_joint_rewrite_uses_vlm_feedback() -> None:
    """The joint rewrite returns the model's prompt pair and forwards critic feedback to the model.

    The user message sent to the chat client must contain the issue
    description and the improvement directive, and must not contain the
    ``raw_response`` key.
    """
    geometry_issue = {
        "category": "geometry",
        "description": "Generated a 4x4 grid instead of a 3x3 cube.",
        "severity": "severe",
    }
    analysis = {
        "overall_score": 2.0,
        "prompt_adherence_score": 3.0,
        "category_score": 3.0,
        "issues": [geometry_issue],
        "improvement_directives": ["Strictly enforce 3x3x3 geometry."],
        "raw_response": "large omitted blob",
    }
    old_prompt = _valid_t2i_prompt("old cube prompt")
    new_prompt = _valid_t2i_prompt("new cube prompt with no 4x4 grid")

    rewriter = PromptRewriterClient(api_token="unused")
    chat = FakeChatClient({"positive_prompt": new_prompt, "negative_prompt": "4x4 grid"})
    # Swap in the fake so no real chat request is made.
    rewriter.rewrite_client = chat  # type: ignore[assignment]

    source_item = _item("39", "A Rubik's cube mid twist with the top layer rotated exactly 45 degrees")
    history = [{"iteration": 0, "analysis": analysis}]
    positive_prompt, negative_prompt = rewriter.rewrite_prompt_pair(
        source_item,
        old_prompt,
        "",
        analysis,
        history,
    )

    assert positive_prompt["comprehensive_t2i_caption"] == "new cube prompt with no 4x4 grid"
    assert negative_prompt == "4x4 grid"
    assert chat.response_format_json is True

    user_message = str(chat.messages[1]["content"])
    for expected_fragment in ("Generated a 4x4 grid", "Strictly enforce 3x3x3 geometry"):
        assert expected_fragment in user_message
    assert "raw_response" not in user_message
def test_generation_payload_uses_vllm_omni_images_api() -> None:
    """Check the request payload built for image generation, field by field.

    Also checks that the ``/v1`` suffix is removed from the stored endpoint
    and that ``model_mode`` / ``prompt_upsampling`` are absent from the payload.
    """
    client = ImageGenerationClient(endpoint="https://example.test/v1", model="test/model")
    payload = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3", seed=100, negative_prompt="blur")

    assert client.endpoint == "https://example.test"

    expected_fields = {
        "model": "test/model",
        "prompt": '{"comprehensive_t2i_caption":"x"}',
        "size": "1024x1024",
        "n": 1,
        "response_format": "b64_json",
        "negative_prompt": "blur",
        "num_inference_steps": 50,
        "guidance_scale": 4.0,
        "flow_shift": DEFAULT_FLOW_SHIFT,
        "extra_args": DEFAULT_GENERATION_EXTRA_ARGS,
        "seed": 100,
    }
    for field_name, expected_value in expected_fields.items():
        assert payload[field_name] == expected_value

    for absent_field in ("model_mode", "prompt_upsampling"):
        assert absent_field not in payload
def test_generation_payload_allows_custom_extra_args() -> None:
    """``extra_args`` given to the client constructor appear unchanged in the built payload."""
    custom_extra_args = {"guardrails": True}
    client = ImageGenerationClient(endpoint="https://example.test", extra_args=custom_extra_args)

    built = client.build_payload({"comprehensive_t2i_caption": "x"}, prompt_id="3")

    assert built["extra_args"] == {"guardrails": True}
class FakeImageResponse:
    """Canned HTTP response: always ``ok`` with status 200, wrapping a fixed JSON payload."""

    # Fixed success markers shared by every instance.
    ok: bool = True
    status_code: int = 200
    text: str = "ok"

    def __init__(self, payload: dict[str, Any]) -> None:
        """Keep the object that ``json()`` will hand back."""
        self.payload = payload

    def json(self) -> dict[str, Any]:
        """Return the stored payload object itself (no copy is made)."""
        stored_payload = self.payload
        return stored_payload
class FakeImageSession:
    """HTTP-session stand-in that logs each request and answers with a canned response."""

    # One entry per ``request`` call, holding the method, URL, and keyword arguments.
    calls: list[dict[str, Any]]

    def __init__(self, response_payload: dict[str, Any]) -> None:
        """Store the payload that every response will wrap and start with an empty call log."""
        self.response_payload = response_payload
        self.calls = []

    def request(self, method: str, url: str, **kwargs: Any) -> FakeImageResponse:
        """Append the call to ``calls`` and return a ``FakeImageResponse`` with the stored payload."""
        call_record = dict(method=method, url=url, kwargs=kwargs)
        self.calls.append(call_record)
        return FakeImageResponse(self.response_payload)
def _tiny_png_b64() -> str:
    """Return a 4x4 solid-green RGB PNG, base64-encoded as an ASCII string."""
    with io.BytesIO() as png_buffer:
        Image.new("RGB", (4, 4), (0, 255, 0)).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("ascii")
def test_generation_client_decodes_vllm_omni_b64_response(tmp_path: Path) -> None:
    """``generate`` posts to the images endpoint, writes the image, and redacts base64 in the metadata."""
    response_body = {"created": 1, "data": [{"b64_json": _tiny_png_b64(), "revised_prompt": None}]}
    session = FakeImageSession(response_body)
    client = ImageGenerationClient(endpoint="example.test", auth_key="secret-token", session=session)  # type: ignore[arg-type]

    result = client.generate(prompt_json=_valid_t2i_prompt("x"), prompt_id="3", output_dir=tmp_path, seed=5)

    assert result.image_path.exists()

    # Inspect the single request captured by the fake session.
    sent = session.calls[0]
    assert sent["method"] == "POST"
    assert sent["url"] == "https://example.test/v1/images/generations"
    sent_kwargs = sent["kwargs"]
    assert sent_kwargs["headers"] == {"Authorization": "Bearer secret-token"}
    assert sent_kwargs["json"]["model"] == DEFAULT_GENERATION_MODEL

    # The metadata file must hold a placeholder instead of the raw base64 image.
    meta = json.loads(result.meta_path.read_text(encoding="utf-8"))
    assert meta["status"] == "completed"
    redacted_image = meta["response"]["data"][0]["b64_json"]
    assert redacted_image.startswith("<base64 image omitted:")
def test_parse_analysis_response_sets_threshold_flag() -> None:
    """A critic reply with an overall score of 9.1 must parse with ``threshold_cleared`` set to ``True``."""
    # NOTE(review): whitespace inside this literal is reconstructed; the source view had its indentation stripped.
    critic_reply = """
        {
            "prompt_adherence_score": 9,
            "visual_quality_score": 9,
            "aesthetics_score": 8.5,
            "physical_plausibility_score": 8,
            "category_score": 9,
            "text_rendering_score": 9,
            "photorealism_score": null,
            "overall_score": 9.1,
            "issues": [],
            "category_findings": {},
            "improvement_directives": [],
            "rationale": "Strong."
        }
        """
    parsed = parse_analysis_response(critic_reply)
    assert parsed["threshold_cleared"] is True
class FakeRewriter:
    """Rewriter stand-in that counts its calls and returns deterministic fixture prompts."""

    # Number of ``initial_prompt`` calls so far.
    initial_calls: int
    # Number of ``rewrite_prompt_pair`` calls so far.
    joint_rewrite_calls: int
    # ``overall_score`` of the analysis passed to each rewrite, in call order.
    previous_scores: list[float]

    def __init__(self) -> None:
        """Start with zeroed counters and an empty score log."""
        self.initial_calls = self.joint_rewrite_calls = 0
        self.previous_scores = []

    def initial_prompt(self, item: PromptItem) -> dict[str, Any]:
        """Count the call and return a fixture prompt captioned ``initial <prompt_id>``."""
        self.initial_calls += 1
        caption = f"initial {item.prompt_id}"
        return _valid_t2i_prompt(caption)

    def rewrite_prompt_pair(
        self,
        item: PromptItem,
        previous_prompt: dict[str, Any],
        previous_negative_prompt: str,
        previous_analysis: dict[str, Any],
        history: list[dict[str, Any]],
    ) -> tuple[dict[str, Any], str]:
        """Count the call, log the previous overall score, and return a prompt pair.

        Both the caption (``rewrite N``) and the negative prompt
        (``negative N``) embed ``N = len(history)``.
        """
        self.joint_rewrite_calls += 1
        self.previous_scores.append(float(previous_analysis["overall_score"]))
        round_number = len(history)
        positive_prompt = _valid_t2i_prompt(f"rewrite {round_number}")
        negative_prompt = f"negative {round_number}"
        return positive_prompt, negative_prompt
@dataclass
class FakeGeneration:
    """Result record returned by the fake generators in this module.

    The ``@dataclass`` decorator is required: the fake generators build this
    class with keyword arguments, which a class holding only annotations
    would reject with ``TypeError``.
    """

    # Path of the image file the generator wrote.
    image_path: Path
    # Path of the metadata JSON file the generator wrote.
    meta_path: Path
    # Metadata dict returned alongside the two paths.
    meta: dict[str, Any]
class FakeGenerator:
    """Image-generator stand-in that records its inputs and writes a tiny placeholder image."""

    # Seed passed to each ``generate`` call, in call order.
    seeds: list[int | None]
    # Negative prompt passed to each ``generate`` call, in call order.
    negative_prompts: list[str]

    def __init__(self) -> None:
        """Start with empty call logs."""
        self.seeds, self.negative_prompts = [], []

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        """Log the seed and negative prompt, then write placeholder outputs into *output_dir*.

        Creates *output_dir* if needed, saves an 8x8 solid-red ``image.jpg``
        and a ``generation_meta.json`` with a completed status, and returns
        both paths in a ``FakeGeneration``. ``prompt_json``, ``prompt_id``,
        and ``jpeg_quality`` are accepted but not used.
        """
        self.seeds.append(seed)
        self.negative_prompts.append(negative_prompt)

        output_dir.mkdir(parents=True, exist_ok=True)
        image_file = output_dir / "image.jpg"
        meta_file = output_dir / "generation_meta.json"

        placeholder = Image.new("RGB", (8, 8), (255, 0, 0))
        placeholder.save(image_file)
        meta_file.write_text('{"status":"completed"}\n', encoding="utf-8")

        return FakeGeneration(image_path=image_file, meta_path=meta_file, meta={"status": "completed"})
class BarrierGenerator(FakeGenerator):
    """``FakeGenerator`` variant whose ``generate`` calls must all reach a barrier before any finishes."""

    # Barrier that every ``generate`` call waits on.
    barrier: threading.Barrier
    # Guards the shared call logs while they are appended to.
    lock: threading.Lock

    def __init__(self, parties: int) -> None:
        """Create a barrier for *parties* callers and a lock for the call logs."""
        super().__init__()
        self.lock = threading.Lock()
        self.barrier = threading.Barrier(parties)

    def generate(
        self,
        *,
        prompt_json: dict[str, Any],
        prompt_id: str,
        output_dir: Path,
        seed: int | None = None,
        negative_prompt: str = "",
        jpeg_quality: int = 95,
    ) -> FakeGeneration:
        """Log the call, wait at the barrier, then write the same placeholder outputs as the base class.

        The barrier wait uses a 2.0 second timeout, so the call fails rather
        than hanging if fewer callers than expected arrive.
        """
        with self.lock:
            self.seeds.append(seed)
            self.negative_prompts.append(negative_prompt)

        # Wait outside the lock so the other callers can log their call and reach the barrier too.
        self.barrier.wait(timeout=2.0)

        output_dir.mkdir(parents=True, exist_ok=True)
        image_file = output_dir / "image.jpg"
        meta_file = output_dir / "generation_meta.json"

        placeholder = Image.new("RGB", (8, 8), (255, 0, 0))
        placeholder.save(image_file)
        meta_file.write_text('{"status":"completed"}\n', encoding="utf-8")

        return FakeGeneration(image_path=image_file, meta_path=meta_file, meta={"status": "completed"})
class FakeJudge:
    """Critic stand-in that hands out a scripted sequence of scores, one per call."""

    # Number of ``score_image`` calls served so far; also the index of the next score.
    calls: int
    # Scripted scores, consumed in order.
    scores: list[float]

    def __init__(self, scores: list[float]) -> None:
        """Store the scripted scores and reset the call counter."""
        self.scores = scores
        self.calls = 0

    def score_image(
        self,
        *,
        item: PromptItem,
        image_path: Path,
    ) -> dict[str, Any]:
        """Return an analysis dict in which every score field equals the next scripted score.

        ``threshold_cleared`` is true when that score is at least 9. Neither
        argument is inspected. Raises ``IndexError`` once the scripted scores
        run out.
        """
        # Index first so an exhausted script raises before the counter moves.
        score = self.scores[self.calls]
        self.calls += 1

        score_fields = (
            "overall_score",
            "prompt_adherence_score",
            "visual_quality_score",
            "aesthetics_score",
            "physical_plausibility_score",
            "category_score",
        )
        analysis: dict[str, Any] = dict.fromkeys(score_fields, score)
        analysis["issues"] = []
        analysis["improvement_directives"] = []
        analysis["threshold_cleared"] = score >= 9
        return analysis
def test_runner_early_stops_by_default(tmp_path: Path) -> None:
    """With default settings, a first-iteration score of 9.1 ends the run with no rewrite."""
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    config = RunnerConfig(output_dir=tmp_path, max_iterations=3, samples_per_iteration=1)
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([9.1, 8.0]),
        config=config,
    )

    outcome = runner.run_item(_item())

    assert outcome["best_iteration"] == 0
    call_counts = (rewriter.initial_calls, rewriter.joint_rewrite_calls)
    assert call_counts == (1, 0)
    # No seed base was configured, so the single sample is generated with seed None.
    assert generator.seeds == [None]
def test_runner_can_disable_early_stop_and_select_best_sample(tmp_path: Path) -> None:
    """With early stop off, both iterations run and the best-scoring sample is selected in each.

    Scripted scores are 5/9/7 for iteration 0 and 6/10/8 for iteration 1, so
    sample index 1 wins both times and iteration 1 is the overall best.
    """
    rewriter = FakeRewriter()
    generator = FakeGenerator()
    scripted_scores = [5.0, 9.0, 7.0, 6.0, 10.0, 8.0]
    config = RunnerConfig(
        output_dir=tmp_path,
        max_iterations=2,
        samples_per_iteration=3,
        seed_base=1000,
        early_stop=False,
    )
    runner = AgenticUpsamplerRunner(
        rewriter=rewriter,
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge(scripted_scores),
        config=config,
    )

    outcome = runner.run_item(_item("8", "exactly 12 balloons with exact color counts"))

    # The same three seeds are used again in the second iteration.
    assert generator.seeds == [1000, 1001, 1002, 1000, 1001, 1002]
    # The rewrite was driven by the best analysis of iteration 0.
    assert rewriter.previous_scores == [9.0]
    assert outcome["best_iteration"] == 1
    assert outcome["best"]["selected_sample_index"] == 1
    first_iteration = outcome["iterations"][0]
    assert first_iteration["selected_sample_index"] == 1
def test_runner_generates_seed_samples_in_parallel(tmp_path: Path) -> None:
    """All three samples of an iteration must be in flight at the same time.

    ``BarrierGenerator`` holds each ``generate`` call until three callers
    have arrived, so the run can only finish if the samples run concurrently.
    """
    generator = BarrierGenerator(parties=3)
    config = RunnerConfig(
        output_dir=tmp_path,
        max_iterations=1,
        samples_per_iteration=3,
        seed_base=2000,
        early_stop=False,
    )
    runner = AgenticUpsamplerRunner(
        rewriter=FakeRewriter(),
        generator=generator,  # type: ignore[arg-type]
        judge=FakeJudge([5.0, 6.0, 7.0]),
        config=config,
    )

    outcome = runner.run_item(_item("parallel", "a parallel seed test"))

    # Completion order is not deterministic, so compare the sorted seeds.
    assert sorted(generator.seeds) == [2000, 2001, 2002]
    assert outcome["best"]["selected_sample_index"] == 2
    assert outcome["iterations"][0]["sample_count"] == 3
def test_extract_best_images_copies_images_and_writes_manifests(tmp_path: Path) -> None:
    """Extraction copies the best image as ``<prompt_id>.jpg`` and writes JSONL and CSV manifests."""
    run_dir = tmp_path / "run"
    export_dir = tmp_path / "export"

    # Lay out one finished prompt directory: run/0001/iter_00/image.jpg plus run/0001/best.json.
    prompt_dir = run_dir / "0001"
    iteration_dir = prompt_dir / "iter_00"
    iteration_dir.mkdir(parents=True)
    source_image = iteration_dir / "image.jpg"
    Image.new("RGB", (8, 8), (255, 0, 0)).save(source_image)

    best_entry = {
        "selected_sample_index": 0,
        "image_path": str(source_image),
        "analysis_path": str(iteration_dir / "analysis.json"),
    }
    best_json = {
        "prompt_id": "1",
        "prompt": "a red square",
        "best_iteration": 0,
        "best_score": 9.25,
        "threshold_cleared_any": True,
        "best": best_entry,
        "iterations": [],
    }
    (prompt_dir / "best.json").write_text(json.dumps(best_json), encoding="utf-8")

    records = extract_best_images(run_dir, export_dir)

    assert len(records) == 1
    copied_image = Path(records[0]["copied_image_path"])
    assert copied_image.exists()
    assert copied_image.name == "1.jpg"
    for manifest_name in ("best_generations.jsonl", "best_generations.csv"):
        assert (export_dir / manifest_name).exists()