"""
Hybrid FastAPI/Gradio application for image generation with SDXL-Turbo

Provides both API endpoints and a Gradio UI
"""

import logging
import math
import os
import re
import time
import uuid
from typing import Optional

import gradio as gr
import torch
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image

# FastAPI imports
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles

# PIL for image processing
from PIL import Image

# Try to import Intel extensions if available. This is best-effort: any import
# failure simply leaves the extension disabled, but unlike a bare ``except:``
# it does not swallow KeyboardInterrupt/SystemExit.
try:
    import intel_extension_for_pytorch as ipex  # noqa: F401
except Exception:
    ipex = None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hybrid-app")

# Get environment variables
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Filesystem layout and model used by the application
INPUT_DIR = "/app/input"
OUTPUT_DIR = "/app/output"
CACHE_DIR = "/app/.cache"
MODEL_ID = "stabilityai/sdxl-turbo"

# Initialize FastAPI app
app = FastAPI(
    title="Avatar Generator API",
    description="Generate avatar images based on pose, shirt, and face inputs using SDXL-Turbo",
    version="0.1.0"
)

# Add CORS middleware to allow the Gradio UI to communicate with the API
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure device settings - copied directly from reference implementation
# check if MPS is available OSX only M1/M2/M3 chips
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
torch_dtype = torch.float16

logger.info(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
logger.info(f"TORCH_COMPILE: {TORCH_COMPILE}")
logger.info(f"device: {device}")

if mps_available:
    # On Apple Silicon the pipelines are first placed on CPU in float32 and
    # then moved to the MPS device (see load_pipelines).
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32
    logger.info("MPS available, using MPS device for Apple Silicon")

# Global variables to store loaded pipelines
i2i_pipe = None
t2i_pipe = None


def _safe_filename_part(text):
    """Return *text* reduced to characters that are safe inside a file name.

    Runs of anything other than ASCII letters, digits, ``.``, ``_`` and ``-``
    are collapsed to a single underscore, and leading/trailing dots and
    underscores are stripped, so user input such as ``"../../x"`` cannot
    escape the output directory. Falls back to ``"unnamed"`` when nothing
    usable remains.
    """
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", str(text)).strip("._")
    return cleaned or "unnamed"


def setup_directories():
    """Create necessary directories if they don't exist"""
    for dir_path in (INPUT_DIR, OUTPUT_DIR, CACHE_DIR):
        os.makedirs(dir_path, exist_ok=True)
        if not os.access(dir_path, os.W_OK):
            logger.warning(f"Directory {dir_path} is not writable!")


def load_pipelines():
    """Load (once) and return the image-to-image and text-to-image pipelines.

    Returns:
        Tuple ``(i2i_pipe, t2i_pipe)``. Subsequent calls return the cached
        pipelines stored in the module-level globals.

    Raises:
        RuntimeError: if loading or device placement fails. The globals are
            only assigned after both pipelines are fully set up, so a failed
            attempt leaves them unset and a later call retries from scratch.
    """
    global i2i_pipe, t2i_pipe

    # Only load if not already loaded
    if i2i_pipe is not None and t2i_pipe is not None:
        return i2i_pipe, t2i_pipe

    logger.info("Setting up model pipelines...")
    setup_directories()

    try:
        # Set model cache directory. The environment variables are kept for
        # any code that reads them later, but the Hugging Face libraries are
        # already imported at this point, so the cache location is also passed
        # explicitly via ``cache_dir`` below.
        os.environ["HF_HUB_CACHE"] = CACHE_DIR
        os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
        os.environ["HF_HOME"] = CACHE_DIR

        load_kwargs = {
            "torch_dtype": torch_dtype,
            "variant": "fp16" if torch_dtype == torch.float16 else "fp32",
            "cache_dir": CACHE_DIR,
        }

        # Load pipelines based on SAFETY_CHECKER setting - just like reference implementation
        if SAFETY_CHECKER == "True":
            logger.info("Loading pipelines with safety checker")
        else:
            logger.info("Loading pipelines without safety checker")
            load_kwargs["safety_checker"] = None

        new_i2i = AutoPipelineForImage2Image.from_pretrained(MODEL_ID, **load_kwargs)
        new_t2i = AutoPipelineForText2Image.from_pretrained(MODEL_ID, **load_kwargs)

        # Move to appropriate device
        for pipe in (new_t2i, new_i2i):
            pipe.to(device=torch_device, dtype=torch_dtype).to(device)
            pipe.set_progress_bar_config(disable=True)

        # Apply memory optimizations if on CUDA
        if device.type == "cuda":
            try:
                # Try to use xformers for memory efficiency
                new_i2i.enable_xformers_memory_efficient_attention()
                new_t2i.enable_xformers_memory_efficient_attention()
                logger.info("Enabled xformers memory efficient attention")
            except Exception as e:
                logger.info(f"Could not enable xformers: {str(e)}")
                new_i2i.enable_attention_slicing()
                new_t2i.enable_attention_slicing()
                logger.info("Using attention slicing instead")

        # Publish only fully initialized pipelines
        i2i_pipe, t2i_pipe = new_i2i, new_t2i
        logger.info(f"Pipelines loaded successfully on {device}")
        return i2i_pipe, t2i_pipe

    except Exception as e:
        # logger.exception records the message together with the traceback
        logger.exception(f"Error loading pipelines: {str(e)}")
        raise RuntimeError(f"Failed to load generation pipelines: {str(e)}") from e


def resize_crop(image, width=512, height=512):
    """Resize and center-crop *image* to ``width`` x ``height``.

    The image is converted to RGB, scaled (bicubic) so that it covers the
    target size while keeping its aspect ratio, then cropped around the
    center.

    Args:
        image: PIL image to process.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        A new RGB PIL image of exactly ``(width, height)``.
    """
    image = image.convert("RGB")

    # Get original aspect ratio
    orig_width, orig_height = image.size
    orig_aspect = orig_width / orig_height
    target_aspect = width / height

    # Determine dimensions for resizing before crop. ``max`` guards against
    # float truncation producing a size one pixel smaller than the target.
    if orig_aspect > target_aspect:
        # Image is wider than target, resize to match height
        new_height = height
        new_width = max(width, int(orig_aspect * new_height))
    else:
        # Image is taller than target, resize to match width
        new_width = width
        new_height = max(height, int(new_width / orig_aspect))

    # Resize with proper filtering
    image = image.resize((new_width, new_height), Image.BICUBIC)

    # Center crop to target dimensions
    left = (new_width - width) // 2
    top = (new_height - height) // 2
    return image.crop((left, top, left + width, top + height))


def cleanup_files(files_to_clean):
    """Remove the given temporary files, logging (not raising) on failure."""
    for file_path in files_to_clean:
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                logger.info(f"Cleaned up file: {file_path}")
        except Exception as e:
            logger.error(f"Error cleaning up file {file_path}: {str(e)}")


@app.on_event("startup")
async def startup_event():
    """Initialize the model pipeline on startup"""
    logger.info("Initializing model pipelines...")
    try:
        # NOTE: load_pipelines() is called synchronously here, so startup
        # does not finish until loading completes or fails.
        load_pipelines()
        logger.info("Pipelines initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing pipelines: {str(e)}")
        # Continue with startup, pipeline will be loaded on first request


#
# FastAPI Endpoints
#

async def _save_upload(upload, prefix, files_to_clean):
    """Write an uploaded file to INPUT_DIR under a unique name and return its path.

    The path is registered in *files_to_clean* before writing so that a
    partially written file is still removed if the write fails.
    """
    path = os.path.join(INPUT_DIR, f"{prefix}_{uuid.uuid4()}.jpg")
    files_to_clean.append(path)
    data = await upload.read()
    with open(path, "wb") as f:
        f.write(data)
    return path


@app.post("/generate")
async def generate(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...),
    name: str = Form(...),
    role: str = Form(...),
    pose_image: UploadFile = File(...),
    shirt_image: UploadFile = File(...),
    face_image: Optional[UploadFile] = File(None),
    steps: Optional[int] = Form(2),  # Default to 2 steps like in the reference
    guidance_scale: Optional[float] = Form(0.0),  # Default to 0.0 like in the reference
    strength: Optional[float] = Form(0.7),  # Default to 0.7 like in the reference
    width: Optional[int] = Form(512),  # Default to 512 like in the reference
    height: Optional[int] = Form(512),  # Default to 512 like in the reference
    seed: Optional[int] = Form(None)
):
    """
    Generate an avatar image based on the provided inputs

    Args:
        prompt: Text prompt describing the desired image
        name: Person name (for filename)
        role: Role/job (for styling and filename)
        pose_image: Image file for pose reference
        shirt_image: Image file for shirt reference
        face_image: Optional image file for face reference
        steps: Number of inference steps
        guidance_scale: How much to weigh the prompt
        strength: How much to transform the pose image (0.0 to 1.0)
        width: Output image width
        height: Output image height
        seed: Random seed for reproducibility

    Returns:
        The generated image

    Raises:
        HTTPException: 400 for invalid inputs, 500 if generation fails.
    """
    # Validate inputs
    if not prompt or len(prompt) < 3:
        raise HTTPException(status_code=400, detail="Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise HTTPException(status_code=400, detail="Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise HTTPException(status_code=400, detail="Role must be at least 2 characters")
    if steps is None or steps < 1:
        raise HTTPException(status_code=400, detail="Steps must be at least 1")
    if strength is None or not 0.0 <= strength <= 1.0:
        raise HTTPException(status_code=400, detail="Strength must be between 0.0 and 1.0")
    if width is None or height is None or width < 1 or height < 1:
        raise HTTPException(status_code=400, detail="Width and height must be positive")

    # Paths of temporary upload copies to remove once the request is done
    files_to_clean = []

    try:
        # Save uploaded files with unique names
        os.makedirs(INPUT_DIR, exist_ok=True)
        pose_path = await _save_upload(pose_image, "pose", files_to_clean)
        shirt_path = await _save_upload(shirt_image, "shirt", files_to_clean)
        face_path = None
        if face_image:
            face_path = await _save_upload(face_image, "face", files_to_clean)

        # Generate the image using the common generation function.
        # NOTE(review): this call is synchronous, so the event loop is blocked
        # for the duration of the generation.
        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=steps,
            guidance_scale=guidance_scale,
            strength=strength,
            width=width,
            height=height,
            seed=seed
        )

        # Schedule cleanup of temporary files
        background_tasks.add_task(cleanup_files, files_to_clean)

        # Return the generated image; the download name is built from
        # sanitized user input.
        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"{_safe_filename_part(role)}_{_safe_filename_part(name)}.png"
        )

    except HTTPException:
        # Clean up files if there was a validation error
        cleanup_files(files_to_clean)
        raise
    except Exception as e:
        # Clean up files and return error
        logger.error(f"Error in generate endpoint: {str(e)}")
        cleanup_files(files_to_clean)
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}


@app.get("/")
async def root():
    """Root endpoint - returns a welcome message with the docs and UI paths"""
    return {"message": "Welcome to Avatar Generator API", "docs": "/docs", "ui": "/ui"}


#
# Internal processing functions
#

def generate_image_internal(
    prompt: str,
    name: str,
    role: str,
    pose_path: str,
    shirt_path: str,
    face_path: Optional[str] = None,
    steps: int = 2,
    guidance_scale: float = 0.0,
    strength: float = 0.7,
    width: int = 512,
    height: int = 512,
    seed: Optional[int] = None
):
    """
    Internal function to generate an image based on the provided inputs
    This is used by both the API endpoint and the Gradio interface

    Args:
        prompt: Base text prompt; it is extended with name, role and the
            average shirt color before being passed to the pipeline.
        name: Person name (used in the prompt and the output file name).
        role: Role/job (used in the prompt and the output file name).
        pose_path: Path of the pose image used as image-to-image input.
        shirt_path: Path of the shirt image; only its average color is used.
        face_path: Accepted for interface compatibility; currently not used
            by the generation.
        steps: Number of inference steps (raised if ``steps * strength``
            would yield no denoising step).
        guidance_scale: Guidance scale passed to the pipeline.
        strength: Image-to-image strength passed to the pipeline.
        width: Output width in pixels.
        height: Output height in pixels.
        seed: Random seed; defaults to the current Unix time.

    Returns:
        Path to the generated image
    """
    # Load pipelines if not already loaded
    i2i_pipe, _ = load_pipelines()

    # Slider-style callers may hand over floats; the pipeline needs an int
    steps = int(steps)

    # Process pose image for use with image-to-image
    with Image.open(pose_path) as raw_pose:
        pose_image = resize_crop(raw_pose, width, height)

    # Analyze shirt for color information to enhance prompt: shrinking the
    # image to a single pixel yields one representative RGB value
    with Image.open(shirt_path) as raw_shirt:
        shirt_colors = raw_shirt.convert("RGB").resize((1, 1)).getpixel((0, 0))
    color_text = f"wearing {role} clothes in color similar to RGB({shirt_colors[0]},{shirt_colors[1]},{shirt_colors[2]})"

    # Enhance prompt with role and color information
    enhanced_prompt = f"{prompt}. {name} as {role} style, highly detailed, {color_text}"
    logger.info(f"Enhanced prompt: {enhanced_prompt}")

    # Ensure valid strength and steps
    if int(steps * strength) < 1:
        steps = math.ceil(1 / max(0.10, strength))
        logger.info(f"Adjusted steps to {steps} to ensure at least one denoising step")

    # Set seed for reproducibility
    seed = int(time.time()) if seed is None else int(seed)
    generator = torch.Generator(device=i2i_pipe.device).manual_seed(seed)

    logger.info(f"Starting generation with seed {seed}")
    start_time = time.time()

    # Use image-to-image pipeline with the pose image as input
    results = i2i_pipe(
        prompt=enhanced_prompt,
        image=pose_image,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        strength=strength,
        width=width,
        height=height,
        output_type="pil",
    )

    generation_time = time.time() - start_time
    logger.info(f"Image generated in {generation_time:.2f} seconds")

    # Check for NSFW content; the attribute may be missing or None
    nsfw_flags = getattr(results, "nsfw_content_detected", None)
    nsfw_content_detected = bool(nsfw_flags) and bool(nsfw_flags[0])

    if nsfw_content_detected:
        logger.warning("NSFW content detected, returning placeholder image")
        image = Image.new("RGB", (width, height), color=(100, 100, 100))
    else:
        image = results.images[0]

    # Create output file name and path. name/role come from the user, so they
    # are sanitized to keep the file inside OUTPUT_DIR.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    output_filename = f"{_safe_filename_part(role)}_{_safe_filename_part(name)}_{seed}.png"
    output_path = os.path.join(OUTPUT_DIR, output_filename)

    # Save the generated image
    image.save(output_path)
    logger.info(f"Saved output image to {output_path}")

    return output_path


#
# Gradio UI
#

# Function for Gradio's prediction interface
def gradio_predict(
    pose_image,
    prompt,
    name,
    role,
    shirt_image,
    face_image=None,
    strength=0.7,
    steps=2,
    seed=None
):
    """
    Generate an image for the Gradio UI

    Args match the generate_image_internal function but adapted for Gradio's
    interface: images arrive as PIL images, and a seed of ``None`` or ``0``
    means "pick a seed from the current time".

    Returns:
        The generated PIL image, or ``None`` when the pose or shirt image is
        missing.

    Raises:
        gr.Error: on invalid text inputs or when generation fails.
    """
    if pose_image is None or shirt_image is None:
        return None

    if not prompt or len(prompt) < 3:
        raise gr.Error("Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise gr.Error("Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise gr.Error("Role must be at least 2 characters")

    # Save images to temporary files
    os.makedirs(INPUT_DIR, exist_ok=True)

    # Create file paths
    pose_path = os.path.join(INPUT_DIR, f"pose_gradio_{uuid.uuid4()}.jpg")
    shirt_path = os.path.join(INPUT_DIR, f"shirt_gradio_{uuid.uuid4()}.jpg")
    face_path = None
    files_to_clean = [pose_path, shirt_path]

    try:
        # Save images; JPEG cannot store alpha/palette modes, so force RGB
        pose_image.convert("RGB").save(pose_path)
        shirt_image.convert("RGB").save(shirt_path)
        if face_image is not None:
            face_path = os.path.join(INPUT_DIR, f"face_gradio_{uuid.uuid4()}.jpg")
            files_to_clean.append(face_path)
            face_image.convert("RGB").save(face_path)

        # Generate image (slider values may arrive as floats)
        if seed is None or seed == 0:
            seed = int(time.time())

        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=int(steps),
            guidance_scale=0.0,  # Fixed at 0.0 for SDXL Turbo as in reference
            strength=strength,
            width=512,  # Fixed at 512 for Gradio UI as in reference
            height=512,  # Fixed at 512 for Gradio UI as in reference
            seed=int(seed)
        )

        # Load the image fully into memory so the file handle can be closed
        with Image.open(output_path) as saved_image:
            return saved_image.copy()

    except Exception as e:
        # Propagate the error to the UI
        raise gr.Error(f"Error generating image: {str(e)}")
    finally:
        # Clean up temporary files on both success and failure
        cleanup_files(files_to_clean)


# Create the Gradio interface
def create_gradio_interface():
    """Create and return the Gradio interface"""
    css = """
    #container{
        margin: 0 auto;
        max-width: 80rem;
    }
    #intro{
        max-width: 100%;
        text-align: center;
        margin: 0 auto;
    }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """# Avatar Generator
            ## Generate customized avatars based on pose, shirt, and prompts
            Upload a pose image and a shirt reference, then describe the avatar you want to generate.
            """,
            elem_id="intro",
        )

        with gr.Row():
            with gr.Column():
                # Input elements
                pose_image = gr.Image(label="Pose Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                shirt_image = gr.Image(label="Shirt Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                face_image = gr.Image(label="Face Reference (Optional)", type="pil", sources=["upload", "webcam", "clipboard"])

                prompt = gr.Textbox(label="Prompt", placeholder="Describe the avatar you want to generate...")
                name = gr.Textbox(label="Name", placeholder="Enter name...")
                role = gr.Textbox(label="Role/Profession", placeholder="Enter role (e.g., doctor, engineer)...")

                with gr.Accordion("Advanced Options", open=False):
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        info="How much to transform the pose image (higher = more creative)"
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=10,
                        value=2,
                        step=1,
                        info="Number of denoising steps (higher = more detail but slower)"
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=9999999999,
                        value=0,
                        step=1,
                        info="Random seed for reproducibility (0 = random)"
                    )

                generate_button = gr.Button("Generate Avatar", variant="primary")

            with gr.Column():
                # Output image
                output_image = gr.Image(label="Generated Avatar", type="pil")

        # Event handlers
        generate_button.click(
            fn=gradio_predict,
            inputs=[pose_image, prompt, name, role, shirt_image, face_image, strength, steps, seed],
            outputs=output_image
        )

        # NOTE: gr.Examples is intentionally not wired up; examples don't work
        # well with image inputs.

    return demo


# Create and mount the Gradio app at the /ui route
gradio_app = create_gradio_interface()
app = gr.mount_gradio_app(app, gradio_app, path="/ui")


if __name__ == "__main__":
    import uvicorn

    # Pass the app object directly: with reload disabled this avoids
    # re-importing this module under a hard-coded name.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)