"""Gradio app that serves a single-file SD / SDXL checkpoint with optional LoRAs.

The checkpoint repository is downloaded with ``snapshot_download``, loaded via
``from_single_file`` (SD first, then SDXL), and exposed through a Gradio UI
whose generation function is decorated with ``spaces.GPU``.
"""
import json
import os
from typing import List, Dict, Any, Optional

# NOTE(review): the third-party import order below is kept exactly as in the
# original file; `spaces` may be sensitive to import order relative to torch --
# confirm before regrouping.
from PIL import Image
import torch
import gradio as gr
import spaces
from huggingface_hub import snapshot_download
from diffusers import (
    StableDiffusionPipeline,      # SD 1.x/2.x single-file loader
    StableDiffusionXLPipeline,    # SDXL single-file loader
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    DDIMScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

# -------- Config --------
MODEL_REPO_ID = os.getenv("MODEL_REPO_ID", "DB2169/mixy").strip()
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "realismIllustriousBy_v50FP16.safetensors").strip()
HF_TOKEN = os.getenv("HF_TOKEN", None)
DO_WARMUP = os.getenv("WARMUP", "1") == "1"
LORAS_JSON = os.getenv("LORAS_JSON", "").strip()
REPO_DIR = "/home/user/model"

# UI name -> scheduler class; None means "use the scheduler the checkpoint loaded with".
SCHEDULERS = {
    "default": None,
    "euler_a": EulerAncestralDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "ddim": DDIMScheduler,
    "lms": LMSDiscreteScheduler,
    "pndm": PNDMScheduler,
    "dpmpp_2m": DPMSolverMultistepScheduler,
}

# -------- Globals --------
pipe = None
IS_SDXL = False
LORA_MANIFEST: Dict[str, Dict[str, str]] = {}
INIT_ERROR: Optional[str] = None

# Scheduler instance the pipeline was loaded with; lets "default" undo a switch.
_DEFAULT_SCHEDULER = None
# Adapter names already loaded into `pipe` by this process (avoids re-loading).
_LOADED_LORAS: set = set()
# True once the warmup generation has been attempted for the current pipeline.
_WARMED_UP = False


# -------- Helpers --------
def _read_manifest_file(path: str, label: str) -> Optional[Dict[str, Dict[str, str]]]:
    """Return the dict stored in the JSON file at *path*, or None if unusable.

    *label* ("repo" / "local") only appears in the warning message.
    """
    if not os.path.exists(path):
        return None
    try:
        with open(path, "r", encoding="utf-8") as f:
            parsed = json.load(f)
    except (OSError, ValueError) as e:  # JSON and decoding errors are ValueErrors
        print(f"[WARN] Failed to parse {label} loras.json: {e}")
        return None
    return parsed if isinstance(parsed, dict) else None


def load_lora_manifest(repo_dir: str) -> Dict[str, Dict[str, str]]:
    """Resolve the LoRA manifest from the first source that yields a dict.

    Sources, in priority order: the ``LORAS_JSON`` environment variable,
    ``<repo_dir>/loras.json``, ``<cwd>/loras.json``, and finally a built-in
    single-entry fallback. Unparseable sources are logged and skipped.
    """
    if LORAS_JSON:
        try:
            parsed = json.loads(LORAS_JSON)
        except ValueError as e:
            print(f"[WARN] Failed to parse LORAS_JSON: {e}")
        else:
            if isinstance(parsed, dict):
                return parsed

    candidates = (
        (os.path.join(repo_dir, "loras.json"), "repo"),
        (os.path.join(os.getcwd(), "loras.json"), "local"),
    )
    for path, label in candidates:
        manifest = _read_manifest_file(path, label)
        if manifest is not None:
            return manifest

    print("[INFO] Using built-in LoRA fallback manifest.")
    return {
        "MoriiMee_Gothic": {
            "repo": "LyliaEngine/MoriiMee_Gothic_Niji_Style_Illustrious_r1",
            "weight_name": "MoriiMee_Gothic_Niji_Style_Illustrious_r1.safetensors",
        }
    }


# -------- Bootstrap (CPU) --------
def bootstrap_model(force: bool = False):
    """Download the checkpoint repo and build the global pipeline.

    Tries the SD (1.x/2.x) single-file loader first, then the SDXL one, to
    maximize compatibility with older diffusers that don't expose
    DiffusionPipeline.from_single_file.

    The call is idempotent: if a pipeline is already loaded it returns
    immediately unless *force* is true. This matters because the UI invokes it
    from a page-load event, which fires for every visitor.

    On failure ``INIT_ERROR`` is set to a message and ``pipe`` is left as it
    was; nothing is raised.
    """
    global pipe, IS_SDXL, LORA_MANIFEST, INIT_ERROR, _DEFAULT_SCHEDULER, _WARMED_UP

    INIT_ERROR = None
    if pipe is not None and not force:
        return

    if not MODEL_REPO_ID or not CHECKPOINT_FILENAME:
        INIT_ERROR = "Missing MODEL_REPO_ID or CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    try:
        local_dir = snapshot_download(
            repo_id=MODEL_REPO_ID,
            token=HF_TOKEN,
            local_dir=REPO_DIR,
            ignore_patterns=["*.md"],
        )
    except Exception as e:
        INIT_ERROR = f"Failed to download repo {MODEL_REPO_ID}: {e}"
        print(f"[ERROR] {INIT_ERROR}")
        return

    ckpt_path = os.path.join(local_dir, CHECKPOINT_FILENAME)
    if not os.path.exists(ckpt_path):
        INIT_ERROR = f"Checkpoint not found at {ckpt_path}. Check CHECKPOINT_FILENAME."
        print(f"[ERROR] {INIT_ERROR}")
        return

    _pipe = None
    _is_sdxl = False

    # 1) SD 1.x/2.x first (most single-file merges are SD), then SDXL
    try:
        _pipe = StableDiffusionPipeline.from_single_file(
            ckpt_path, torch_dtype=torch.float16, use_safetensors=True
        )
        _is_sdxl = False
    except Exception as e_sd:
        print(f"[INFO] SD load failed or not SD: {e_sd}")
        try:
            _pipe = StableDiffusionXLPipeline.from_single_file(
                ckpt_path, torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False
            )
            _is_sdxl = True
        except Exception as e_sdxl:
            INIT_ERROR = f"Failed to load pipeline (SD and SDXL): SD={e_sd} | SDXL={e_sdxl}"
            print(f"[ERROR] {INIT_ERROR}")
            return

    # Optional features are probed with hasattr so older pipelines still load.
    if hasattr(_pipe, "enable_attention_slicing"):
        _pipe.enable_attention_slicing("max")
    if hasattr(_pipe, "enable_vae_slicing"):
        _pipe.enable_vae_slicing()
    if hasattr(_pipe, "set_progress_bar_config"):
        _pipe.set_progress_bar_config(disable=True)

    manifest = load_lora_manifest(local_dir)
    print(f"[INFO] LoRAs available: {list(manifest.keys())}")

    # Publish the new pipeline and reset all state tied to the previous one.
    pipe = _pipe
    IS_SDXL = _is_sdxl
    LORA_MANIFEST = manifest
    _DEFAULT_SCHEDULER = getattr(_pipe, "scheduler", None)
    _LOADED_LORAS.clear()
    _WARMED_UP = False


def _deactivate_loras() -> None:
    """Neutralize every adapter loaded so far by setting its weight to 0."""
    if pipe is None or not _LOADED_LORAS:
        return
    names = sorted(_LOADED_LORAS)
    try:
        pipe.set_adapters(names, adapter_weights=[0.0] * len(names))
    except Exception as e:
        print(f"[WARN] Could not deactivate LoRAs: {e}")


def apply_loras(selected: List[str], scale: float, repo_dir: str):
    """Load (if needed) and activate the requested LoRAs on the global pipeline.

    Args:
        selected: Manifest keys of the LoRAs to activate.
        scale: Adapter weight applied to every activated LoRA.
        repo_dir: Base directory for manifest entries that use a ``path``.

    Returns:
        The list of adapter names that were actually activated (empty when
        nothing was requested, nothing loaded, or activation failed).

    Names missing from the manifest or failing to load are logged and left
    out, so one bad entry no longer makes ``set_adapters`` fail for the whole
    selection. Adapters loaded by an earlier call are reused rather than
    loaded again under the same name. When nothing is to be applied, adapters
    left over from earlier calls are neutralized so they do not leak into
    this generation.
    """
    if not selected or scale <= 0:
        _deactivate_loras()
        return []

    active: List[str] = []
    for name in selected:
        if name in active:
            continue
        meta = LORA_MANIFEST.get(name)
        if not meta:
            print(f"[WARN] Requested LoRA '{name}' not in manifest.")
            continue
        if name not in _LOADED_LORAS:
            try:
                if "path" in meta:
                    pipe.load_lora_weights(os.path.join(repo_dir, meta["path"]), adapter_name=name)
                else:
                    pipe.load_lora_weights(
                        meta.get("repo", ""), weight_name=meta.get("weight_name"), adapter_name=name
                    )
            except Exception as e:
                print(f"[WARN] LoRA load failed for {name}: {e}")
                continue
            _LOADED_LORAS.add(name)
            print(f"[INFO] Loaded LoRA: {name}")
        active.append(name)

    if not active:
        _deactivate_loras()
        return []

    try:
        pipe.set_adapters(active, adapter_weights=[float(scale)] * len(active))
        print(f"[INFO] Activated LoRAs: {active} at scale {scale}")
    except Exception as e:
        print(f"[WARN] set_adapters failed: {e}")
        return []
    return active


# -------- Generation (ZeroGPU) --------
@spaces.GPU
def txt2img(
    prompt: str,
    negative: str,
    width: int,
    height: int,
    steps: int,
    guidance: float,
    images: int,
    seed: Optional[int],
    scheduler: str,
    loras: List[str],
    lora_scale: float,
    fuse_lora: bool,
):
    """Generate images from a text prompt with the global pipeline.

    Args:
        prompt: Positive prompt; None/empty is sent as "".
        negative: Negative prompt; empty is sent as None.
        width, height: Output size in pixels (coerced to int).
        steps: Number of inference steps.
        guidance: Guidance scale.
        images: Number of images per prompt.
        seed: Seed for the generator; None or "" leaves generation unseeded.
        scheduler: Key of ``SCHEDULERS``; "default" or an unknown key selects
            the scheduler the checkpoint was loaded with.
        loras: Manifest keys of LoRAs to apply.
        lora_scale: Weight for the applied LoRAs.
        fuse_lora: Fuse the activated LoRAs for this generation.

    Returns:
        The ``images`` attribute of the pipeline output.

    Raises:
        RuntimeError: If the pipeline has not been initialized.
    """
    if pipe is None:
        raise RuntimeError(f"Model not initialized. {INIT_ERROR or 'Check Space secrets and logs.'}")

    local_device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe.to(local_device)

    scheduler_cls = SCHEDULERS.get(scheduler)
    try:
        if scheduler_cls is not None:
            # Derive from the checkpoint's own scheduler config so repeated
            # switches do not accumulate settings from earlier selections.
            base = _DEFAULT_SCHEDULER if _DEFAULT_SCHEDULER is not None else pipe.scheduler
            pipe.scheduler = scheduler_cls.from_config(base.config)
        elif _DEFAULT_SCHEDULER is not None:
            pipe.scheduler = _DEFAULT_SCHEDULER
    except Exception as e:
        print(f"[WARN] Scheduler switch failed: {e}")

    active_loras = apply_loras(loras, lora_scale, REPO_DIR)

    fused = False
    if fuse_lora and active_loras:
        try:
            pipe.fuse_lora(lora_scale=float(lora_scale))
            fused = True
        except Exception as e:
            print(f"[WARN] fuse_lora failed: {e}")

    generator = None
    if seed not in (None, ""):
        generator = torch.Generator(device=local_device).manual_seed(int(seed))

    kwargs: Dict[str, Any] = dict(
        prompt=prompt or "",
        negative_prompt=negative or None,
        width=int(width),
        height=int(height),
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        num_images_per_prompt=int(images),
        generator=generator,
    )

    try:
        with torch.inference_mode():
            out = pipe(**kwargs)
    finally:
        # Fusing bakes the LoRA into the base weights; undo it so the next
        # request (possibly with different or no LoRAs) starts clean.
        if fused:
            try:
                pipe.unfuse_lora()
            except Exception as e:
                print(f"[WARN] unfuse_lora failed: {e}")
    return out.images


# -------- UI --------
with gr.Blocks(title="SDXL/SD single-file (ZeroGPU, LoRA-ready)") as demo:
    status = gr.Markdown("")
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative = gr.Textbox(label="Negative Prompt", lines=3)
    with gr.Row():
        width = gr.Slider(256, 1536, 1024, step=64, label="Width")
        height = gr.Slider(256, 1536, 1024, step=64, label="Height")
    with gr.Row():
        steps = gr.Slider(5, 80, 30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 20.0, 6.5, step=0.1, label="Guidance")
        images = gr.Slider(1, 4, 1, step=1, label="Images")
    with gr.Row():
        seed = gr.Number(value=None, precision=0, label="Seed (blank=random)")
        scheduler = gr.Dropdown(list(SCHEDULERS.keys()), value="dpmpp_2m", label="Scheduler")
    lora_names = gr.CheckboxGroup(choices=[], label="LoRAs (from loras.json; select any)")
    lora_scale = gr.Slider(0.0, 1.5, 0.7, step=0.05, label="LoRA scale")
    fuse = gr.Checkbox(label="Fuse LoRA (faster after load)")
    btn = gr.Button("Generate", variant="primary", interactive=False)
    gallery = gr.Gallery(columns=4, height=420)

    def _size_update(value: int):
        """Build the update shared by the width and height sliders."""
        return gr.update(value=value, minimum=256, maximum=1536, step=64)

    def _startup():
        """Page-load handler: ensure the model is ready and refresh the UI.

        Returns updates for (status, lora_names, width, height, btn). The
        model load and the warmup each happen at most once per pipeline even
        though this handler runs on every page load.
        """
        global _WARMED_UP

        bootstrap_model()
        if INIT_ERROR:
            return (
                gr.update(value=f"❌ Init failed: {INIT_ERROR}"),
                gr.update(choices=[]),
                _size_update(1024),
                _size_update(1024),
                gr.update(interactive=False),
            )

        default_wh = 1024 if IS_SDXL else 512
        msg = f"✅ Model loaded from {MODEL_REPO_ID} ({'SDXL' if IS_SDXL else 'SD'})"

        # Warm up only after model is ready (avoids race)
        if DO_WARMUP and not _WARMED_UP:
            # Mark first so a failing warmup is not retried on every page load.
            _WARMED_UP = True
            try:
                _ = txt2img("warmup", "", default_wh, default_wh, 4, 4.0, 1, 1234, "default", [], 0.0, False)
            except Exception as e:
                print(f"[WARN] Warmup failed: {e}")

        return (
            gr.update(value=msg),
            gr.update(choices=list(LORA_MANIFEST.keys())),
            _size_update(default_wh),
            _size_update(default_wh),
            gr.update(interactive=True),
        )

    demo.load(_startup, outputs=[status, lora_names, width, height, btn])

    btn.click(
        txt2img,
        inputs=[prompt, negative, width, height, steps, guidance, images, seed, scheduler, lora_names, lora_scale, fuse],
        outputs=[gallery],
        api_name="txt2img",
        concurrency_limit=1,
        concurrency_id="gpu_queue",
    )

demo.queue(max_size=32, default_concurrency_limit=1).launch()