Spaces:
Runtime error
Runtime error
| """ | |
| Hybrid FastAPI/Gradio application for image generation with SDXL-Turbo | |
| Provides both API endpoints and a Gradio UI | |
| """ | |
| from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image | |
| import torch | |
| import os | |
| import time | |
| import uuid | |
| import logging | |
| import math | |
| from typing import Optional | |
| import gradio as gr | |
| # FastAPI imports | |
| from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks | |
| from fastapi.responses import FileResponse, JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| # PIL for image processing | |
| from PIL import Image | |
# Try to import Intel extensions if available
try:
    # Optional accelerator: Intel Extension for PyTorch (used on XPU devices).
    import intel_extension_for_pytorch as ipex
except Exception:
    # The extension is optional; continue without it when it isn't installed
    # or fails to load. (Narrowed from a bare `except:` so SystemExit and
    # KeyboardInterrupt are no longer swallowed.)
    pass
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hybrid-app")
# Get environment variables (all optional; unset means None)
# SAFETY_CHECKER: the literal string "True" enables the diffusers safety checker.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
# TORCH_COMPILE: read for logging below; not otherwise used in this file.
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
# HF_TOKEN: Hugging Face auth token — read here but not passed to
# from_pretrained in this file; presumably picked up from the env by the hub
# client. NOTE(review): confirm the token is actually honored.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Initialize FastAPI app
app = FastAPI(
    title="Avatar Generator API",
    description="Generate avatar images based on pose, shirt, and face inputs using SDXL-Turbo",
    version="0.1.0"
)
# Add CORS middleware to allow the Gradio UI to communicate with the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): wide-open CORS; acceptable for a demo Space, tighten for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configure device settings - copied directly from reference implementation
# check if MPS is available OSX only M1/M2/M3 chips
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
# Device priority: CUDA > Intel XPU > CPU (MPS handled separately below).
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# NOTE(review): float16 is selected even when the device ends up being plain
# CPU, where half-precision inference is slow/unsupported for some ops —
# confirm this matches the reference implementation's intent.
torch_dtype = torch.float16
logger.info(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
logger.info(f"TORCH_COMPILE: {TORCH_COMPILE}")
logger.info(f"device: {device}")
if mps_available:
    # Apple Silicon: generate on MPS but load weights on CPU in float32
    # (fp16 on MPS is unreliable for these pipelines).
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32
    logger.info("MPS available, using MPS device for Apple Silicon")
# Global variables to store loaded pipelines (lazily populated by load_pipelines)
i2i_pipe = None
t2i_pipe = None
def setup_directories():
    """Ensure the input/output/cache directories exist, warning if not writable."""
    for path in ("/app/input", "/app/output", "/app/.cache"):
        os.makedirs(path, exist_ok=True)
        # makedirs can succeed while the mount is read-only; surface that now
        # so later file writes are easier to diagnose.
        if not os.access(path, os.W_OK):
            logger.warning(f"Directory {path} is not writable!")
def load_pipelines():
    """Load (once) and return the SDXL-Turbo pipelines.

    Returns:
        tuple: (i2i_pipe, t2i_pipe) — image-to-image and text-to-image
        pipelines, cached in module globals after the first call.

    Raises:
        RuntimeError: if either pipeline fails to load.
    """
    global i2i_pipe, t2i_pipe
    # Only load if not already loaded
    if i2i_pipe is not None and t2i_pipe is not None:
        return i2i_pipe, t2i_pipe
    logger.info("Setting up model pipelines...")
    setup_directories()
    try:
        # Route all Hugging Face downloads into the writable container cache
        os.environ["HF_HUB_CACHE"] = "/app/.cache"
        os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
        os.environ["HF_HOME"] = "/app/.cache"
        # The two branches of the original differed only in passing
        # safety_checker=None, so build the shared kwargs once.
        load_kwargs = dict(
            torch_dtype=torch_dtype,
            variant="fp16" if torch_dtype == torch.float16 else "fp32",
        )
        if SAFETY_CHECKER == "True":
            logger.info("Loading pipelines with safety checker")
        else:
            logger.info("Loading pipelines without safety checker")
            load_kwargs["safety_checker"] = None
        i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo", **load_kwargs
        )
        t2i_pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", **load_kwargs
        )
        # Move to appropriate device. On Apple Silicon torch_device/torch_dtype
        # differ from `device` (weights staged on CPU/fp32, then moved to MPS).
        t2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
        t2i_pipe.set_progress_bar_config(disable=True)
        i2i_pipe.to(device=torch_device, dtype=torch_dtype).to(device)
        i2i_pipe.set_progress_bar_config(disable=True)
        # Apply memory optimizations if on CUDA
        if device.type == "cuda":
            try:
                # Try to use xformers for memory efficiency
                i2i_pipe.enable_xformers_memory_efficient_attention()
                t2i_pipe.enable_xformers_memory_efficient_attention()
                logger.info("Enabled xformers memory efficient attention")
            except Exception as e:
                # xformers is optional; attention slicing is the fallback
                logger.info(f"Could not enable xformers: {str(e)}")
                i2i_pipe.enable_attention_slicing()
                t2i_pipe.enable_attention_slicing()
                logger.info("Using attention slicing instead")
        logger.info(f"Pipelines loaded successfully on {device}")
        return i2i_pipe, t2i_pipe
    except Exception as e:
        logger.error(f"Error loading pipelines: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        raise RuntimeError(f"Failed to load generation pipelines: {str(e)}")
def resize_crop(image, width=512, height=512):
    """Resize and center-crop an image to (width, height), preserving aspect ratio.

    Args:
        image: PIL image to process (converted to RGB).
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A new RGB PIL image of exactly (width, height).
    """
    image = image.convert("RGB")
    # Get original aspect ratio
    orig_width, orig_height = image.size
    orig_aspect = orig_width / orig_height
    target_aspect = width / height
    # Determine dimensions for resizing before crop: scale so the image fully
    # covers the target box, then crop the overhang.
    if orig_aspect > target_aspect:
        # Image is wider than target, resize to match height
        new_height = height
        new_width = int(orig_aspect * new_height)
    else:
        # Image is taller than target, resize to match width
        new_width = width
        new_height = int(new_width / orig_aspect)
    # Image.BICUBIC was deprecated in Pillow 9.1 and removed in Pillow 10;
    # prefer Image.Resampling.BICUBIC with a fallback for older versions.
    resample = getattr(Image, "Resampling", Image).BICUBIC
    image = image.resize((new_width, new_height), resample)
    # Center crop to target dimensions
    left = (new_width - width) // 2
    top = (new_height - height) // 2
    right = left + width
    bottom = top + height
    # Crop and return
    return image.crop((left, top, right, bottom))
def cleanup_files(files_to_clean):
    """Remove the given temporary files, logging (not raising) per-file failures.

    Used both directly and as a FastAPI background task.
    """
    for path in files_to_clean:
        try:
            # Missing files are silently skipped — cleanup is best-effort.
            if os.path.exists(path):
                os.remove(path)
                logger.info(f"Cleaned up file: {path}")
        except Exception as e:
            logger.error(f"Error cleaning up file {path}: {str(e)}")
@app.on_event("startup")  # was defined but never registered — routes it to FastAPI startup
async def startup_event():
    """Initialize the model pipelines when the server starts.

    Failure here is logged but non-fatal: the pipelines will be (re)loaded
    lazily on the first generation request instead.
    """
    logger.info("Initializing model pipelines...")
    try:
        # Load pipelines in background - will be ready for first request
        # Note: First request might be slow if pipeline is still loading
        load_pipelines()
        logger.info("Pipelines initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing pipelines: {str(e)}")
        # Continue with startup, pipeline will be loaded on first request
| # | |
| # FastAPI Endpoints | |
| # | |
@app.post("/generate")  # was defined but never registered with the app
async def generate(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...),
    name: str = Form(...),
    role: str = Form(...),
    pose_image: UploadFile = File(...),
    shirt_image: UploadFile = File(...),
    face_image: Optional[UploadFile] = File(None),
    steps: Optional[int] = Form(2),  # Default to 2 steps like in the reference
    guidance_scale: Optional[float] = Form(0.0),  # Default to 0.0 like in the reference
    strength: Optional[float] = Form(0.7),  # Default to 0.7 like in the reference
    width: Optional[int] = Form(512),  # Default to 512 like in the reference
    height: Optional[int] = Form(512),  # Default to 512 like in the reference
    seed: Optional[int] = Form(None)
):
    """
    Generate an avatar image based on the provided inputs

    Args:
        prompt: Text prompt describing the desired image
        name: Person name (for filename)
        role: Role/job (for styling and filename)
        pose_image: Image file for pose reference
        shirt_image: Image file for shirt reference
        face_image: Optional image file for face reference
        steps: Number of inference steps
        guidance_scale: How much to weigh the prompt
        strength: How much to transform the pose image (0.0 to 1.0)
        width: Output image width
        height: Output image height
        seed: Random seed for reproducibility

    Returns:
        The generated image as a PNG FileResponse

    Raises:
        HTTPException: 400 on invalid inputs, 500 on generation failure
    """
    # Validate inputs
    if not prompt or len(prompt) < 3:
        raise HTTPException(status_code=400, detail="Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise HTTPException(status_code=400, detail="Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise HTTPException(status_code=400, detail="Role must be at least 2 characters")
    # Track every temp file written so it can be cleaned up in all paths
    files_to_clean = []
    try:
        # Save uploaded files with unique names to avoid collisions between
        # concurrent requests
        input_dir = "/app/input"
        pose_path = os.path.join(input_dir, f"pose_{uuid.uuid4()}.jpg")
        with open(pose_path, "wb") as f:
            f.write(await pose_image.read())
        files_to_clean.append(pose_path)
        shirt_path = os.path.join(input_dir, f"shirt_{uuid.uuid4()}.jpg")
        with open(shirt_path, "wb") as f:
            f.write(await shirt_image.read())
        files_to_clean.append(shirt_path)
        face_path = None
        if face_image:
            face_path = os.path.join(input_dir, f"face_{uuid.uuid4()}.jpg")
            with open(face_path, "wb") as f:
                f.write(await face_image.read())
            files_to_clean.append(face_path)
        # Generate the image using the common generation function
        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=steps,
            guidance_scale=guidance_scale,
            strength=strength,
            width=width,
            height=height,
            seed=seed
        )
        # Schedule cleanup of temporary files after the response is sent
        background_tasks.add_task(cleanup_files, files_to_clean)
        # Return the generated image
        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"{role}_{name}.png"
        )
    except HTTPException:
        # Clean up files if there was a validation error
        cleanup_files(files_to_clean)
        raise
    except Exception as e:
        # Clean up files and return error
        logger.error(f"Error in generate endpoint: {str(e)}")
        cleanup_files(files_to_clean)
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
@app.get("/health")  # was defined but never registered with the app
async def health_check():
    """Health check endpoint for monitoring / container orchestration."""
    return {"status": "healthy"}
@app.get("/")  # was defined but never registered with the app
async def root():
    """Root endpoint - points callers at the API docs and the Gradio UI."""
    return {"message": "Welcome to Avatar Generator API", "docs": "/docs", "ui": "/ui"}
| # | |
| # Internal processing functions | |
| # | |
def generate_image_internal(
    prompt: str,
    name: str,
    role: str,
    pose_path: str,
    shirt_path: str,
    face_path: Optional[str] = None,
    steps: int = 2,
    guidance_scale: float = 0.0,
    strength: float = 0.7,
    width: int = 512,
    height: int = 512,
    seed: Optional[int] = None
):
    """
    Internal function to generate an image based on the provided inputs.
    This is used by both the API endpoint and the Gradio interface.

    Args:
        prompt: Base text prompt; enhanced below with name/role/shirt color.
        name, role: User-supplied labels, also used in the output filename.
        pose_path: Path to the pose reference image (drives image-to-image).
        shirt_path: Path to the shirt reference (only its average color is used).
        face_path: Accepted for interface compatibility but currently unused.
        steps, guidance_scale, strength, width, height, seed: diffusion params.

    Returns:
        Path to the generated PNG image.
    """
    # Load pipelines if not already loaded
    i2i_pipe, _ = load_pipelines()
    # Process pose image for use with image-to-image
    pose_image = Image.open(pose_path).convert("RGB")
    pose_image = resize_crop(pose_image, width, height)
    # Analyze shirt for color information: resizing to 1x1 averages all pixels
    shirt_image = Image.open(shirt_path).convert("RGB")
    shirt_colors = shirt_image.resize((1, 1)).getpixel((0, 0))
    color_text = f"wearing {role} clothes in color similar to RGB({shirt_colors[0]},{shirt_colors[1]},{shirt_colors[2]})"
    # Enhance prompt with role and color information
    enhanced_prompt = f"{prompt}. {name} as {role} style, highly detailed, {color_text}"
    logger.info(f"Enhanced prompt: {enhanced_prompt}")
    # Ensure valid strength and steps: img2img runs int(steps * strength)
    # denoising steps, which must be >= 1
    if int(steps * strength) < 1:
        steps = math.ceil(1 / max(0.10, strength))
        logger.info(f"Adjusted steps to {steps} to ensure at least one denoising step")
    # Set seed for reproducibility
    if seed is None:
        seed = int(time.time())
    generator = torch.Generator(device=i2i_pipe.device).manual_seed(seed)
    logger.info(f"Starting generation with seed {seed}")
    start_time = time.time()
    # Use image-to-image pipeline with the pose image as input
    results = i2i_pipe(
        prompt=enhanced_prompt,
        image=pose_image,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        strength=strength,
        width=width,
        height=height,
        output_type="pil",
    )
    generation_time = time.time() - start_time
    logger.info(f"Image generated in {generation_time:.2f} seconds")
    # Check for NSFW content
    # NOTE(review): assumes `results` supports key membership (diffusers
    # BaseOutput is dict-like) — confirm against the installed diffusers version
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        logger.warning("NSFW content detected, returning placeholder image")
        image = Image.new("RGB", (width, height), color=(100, 100, 100))
    else:
        image = results.images[0]
    # Sanitize user-supplied name/role before embedding them in a filesystem
    # path, so inputs like "../x" or "a/b" cannot escape the output directory.
    safe_name = "".join(c for c in name if c.isalnum() or c in " -_").strip() or "avatar"
    safe_role = "".join(c for c in role if c.isalnum() or c in " -_").strip() or "role"
    output_filename = f"{safe_role}_{safe_name}_{seed}.png"
    output_path = os.path.join("/app/output", output_filename)
    # Save the generated image
    image.save(output_path)
    logger.info(f"Saved output image to {output_path}")
    return output_path
| # | |
| # Gradio UI | |
| # | |
| # Function for Gradio's prediction interface | |
def gradio_predict(
    pose_image,
    prompt,
    name,
    role,
    shirt_image,
    face_image=None,
    strength=0.7,
    steps=2,
    seed=None
):
    """
    Generate an image for the Gradio UI.

    Mirrors generate_image_internal's arguments, adapted to Gradio's calling
    convention: validates inputs, stages the uploaded PIL images on disk,
    runs generation, and returns the result as a PIL image.
    """
    # Both reference images are required; quietly produce no output otherwise.
    if pose_image is None or shirt_image is None:
        return None
    # Same validation rules as the API endpoint.
    if not prompt or len(prompt) < 3:
        raise gr.Error("Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise gr.Error("Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise gr.Error("Role must be at least 2 characters")
    input_dir = "/app/input"
    os.makedirs(input_dir, exist_ok=True)
    files_to_clean = []

    def _stage(img, prefix):
        # Persist one PIL image under a unique name; register it for cleanup
        # before writing so a partial write is still removed on failure.
        path = os.path.join(input_dir, f"{prefix}_gradio_{uuid.uuid4()}.jpg")
        files_to_clean.append(path)
        img.save(path)
        return path

    pose_path = os.path.join(input_dir, f"pose_gradio_{uuid.uuid4()}.jpg")
    shirt_path = os.path.join(input_dir, f"shirt_gradio_{uuid.uuid4()}.jpg")
    face_path = None
    files_to_clean.extend([pose_path, shirt_path])
    try:
        pose_image.save(pose_path)
        shirt_image.save(shirt_path)
        if face_image is not None:
            face_path = _stage(face_image, "face")
        # Treat 0 (the slider default) as "pick a random seed".
        if seed is None or seed == 0:
            seed = int(time.time())
        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=steps,
            guidance_scale=0.0,  # Fixed at 0.0 for SDXL Turbo as in reference
            strength=strength,
            width=512,  # Fixed at 512 for Gradio UI as in reference
            height=512,  # Fixed at 512 for Gradio UI as in reference
            seed=seed,
        )
        # Load the image to hand back to Gradio, then drop the staged inputs.
        result_image = Image.open(output_path)
        cleanup_files(files_to_clean)
        return result_image
    except Exception as e:
        # Clean up staged files, then surface the error in the UI.
        cleanup_files(files_to_clean)
        raise gr.Error(f"Error generating image: {str(e)}")
| # Create the Gradio interface | |
def create_gradio_interface():
    """Create and return the Gradio Blocks interface for the avatar generator.

    The interface calls gradio_predict on button click; it is mounted onto the
    FastAPI app at /ui at module import time.
    """
    css = """
    #container{
        margin: 0 auto;
        max-width: 80rem;
    }
    #intro{
        max-width: 100%;
        text-align: center;
        margin: 0 auto;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """# Avatar Generator
            ## Generate customized avatars based on pose, shirt, and prompts
            Upload a pose image and a shirt reference, then describe the avatar you want to generate.
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Column():
                # Input elements
                pose_image = gr.Image(label="Pose Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                shirt_image = gr.Image(label="Shirt Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                face_image = gr.Image(label="Face Reference (Optional)", type="pil", sources=["upload", "webcam", "clipboard"])
                prompt = gr.Textbox(label="Prompt", placeholder="Describe the avatar you want to generate...")
                name = gr.Textbox(label="Name", placeholder="Enter name...")
                role = gr.Textbox(label="Role/Profession", placeholder="Enter role (e.g., doctor, engineer)...")
                with gr.Accordion("Advanced Options", open=False):
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        info="How much to transform the pose image (higher = more creative)"
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=10,
                        value=2,
                        step=1,
                        info="Number of denoising steps (higher = more detail but slower)"
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=9999999999,
                        value=0,
                        step=1,
                        info="Random seed for reproducibility (0 = random)"
                    )
                generate_button = gr.Button("Generate Avatar", variant="primary")
            with gr.Column():
                # Output image
                output_image = gr.Image(label="Generated Avatar", type="pil")
        # NOTE: gr.Examples is intentionally omitted — examples do not work
        # well with multiple image inputs here. (Removed an unused `examples`
        # list and the commented-out gr.Examples block.)
        # Event handlers
        generate_button.click(
            fn=gradio_predict,
            inputs=[pose_image, prompt, name, role, shirt_image, face_image, strength, steps, seed],
            outputs=output_image
        )
    return demo
# Create and mount the Gradio app at the /ui route
gradio_app = create_gradio_interface()
# mount_gradio_app returns the combined application; rebind `app` so uvicorn
# serves both the API routes and the Gradio UI.
app = gr.mount_gradio_app(app, gradio_app, path="/ui")
if __name__ == "__main__":
    import uvicorn
    # Pass the app object directly instead of the "hybrid_app:app" import
    # string: the module's actual filename is not guaranteed to be
    # hybrid_app.py, and reload=False means a string path is not required.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)