# diffuser_gen/app.py — uploaded by Defter77 via huggingface_hub (rev 4e646bd, verified)
"""
Hybrid FastAPI/Gradio application for image generation with SDXL-Turbo
Provides both API endpoints and a Gradio UI
"""
from diffusers import AutoPipelineForImage2Image, AutoPipelineForText2Image
import torch
import os
import time
import uuid
import logging
import math
from typing import Optional
import gradio as gr
# FastAPI imports
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
# PIL for image processing
from PIL import Image
# Try to import Intel extensions if available
try:
    # Optional Intel CPU/XPU acceleration; the app works without it.
    import intel_extension_for_pytorch as ipex  # noqa: F401
except ImportError:
    # Narrowed from a bare `except:` (which also swallowed SystemExit /
    # KeyboardInterrupt and hid genuine errors from a broken install).
    pass
# Configure logging: INFO-level root config plus a named app logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hybrid-app")
# Get environment variables.
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None)  # "True" loads pipelines with safety checker (see load_pipelines)
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None)  # only logged below; not otherwise used in this file
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # read but not referenced elsewhere in this file
# Initialize FastAPI app (Gradio UI is mounted onto it at /ui further below).
app = FastAPI(
    title="Avatar Generator API",
    description="Generate avatar images based on pose, shirt, and face inputs using SDXL-Turbo",
    version="0.1.0"
)
# Add CORS middleware to allow the Gradio UI to communicate with the API.
# NOTE(review): wildcard origins together with allow_credentials=True is
# permissive — confirm this is acceptable for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Configure device settings - copied directly from reference implementation
# check if MPS is available OSX only M1/M2/M3 chips
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
# Preference order: CUDA > XPU > CPU (MPS handled separately below).
device = torch.device(
    "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
)
torch_device = device
# NOTE(review): float16 is selected even on plain CPU here; many CPU ops lack
# fp16 kernels — confirm CPU-only deployments work, or force float32 off-GPU.
torch_dtype = torch.float16
logger.info(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
logger.info(f"TORCH_COMPILE: {TORCH_COMPILE}")
logger.info(f"device: {device}")
if mps_available:
    # Apple Silicon: weights load via CPU in float32, inference targets MPS.
    device = torch.device("mps")
    torch_device = "cpu"
    torch_dtype = torch.float32
    logger.info("MPS available, using MPS device for Apple Silicon")
# Global variables caching the loaded pipelines (populated by load_pipelines).
i2i_pipe = None
t2i_pipe = None
def setup_directories():
    """Ensure the input/output/cache directories exist, warning when unwritable."""
    for path in ("/app/input", "/app/output", "/app/.cache"):
        os.makedirs(path, exist_ok=True)
        if not os.access(path, os.W_OK):
            logger.warning(f"Directory {path} is not writable!")
def load_pipelines():
    """Load (once) and return the SDXL-Turbo pipelines.

    Results are cached in the module-level globals ``i2i_pipe`` / ``t2i_pipe``
    so repeated calls are cheap.

    Returns:
        tuple: (image-to-image pipeline, text-to-image pipeline)

    Raises:
        RuntimeError: if either pipeline fails to load.
    """
    global i2i_pipe, t2i_pipe
    # Only load if not already loaded.
    if i2i_pipe is not None and t2i_pipe is not None:
        return i2i_pipe, t2i_pipe
    logger.info("Setting up model pipelines...")
    setup_directories()
    try:
        # Set model cache directory. NOTE(review): diffusers is imported at
        # module top, so it may have read HF_HUB_CACHE before this point.
        os.environ["HF_HUB_CACHE"] = "/app/.cache"
        os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
        os.environ["HF_HOME"] = "/app/.cache"

        # Shared kwargs for both pipelines. Only an "fp16" weight variant
        # exists for sdxl-turbo; fp32 weights are the default files, so the
        # variant must be None in that case — the previous variant="fp32"
        # would fail to resolve any checkpoint files.
        common = {
            "torch_dtype": torch_dtype,
            "variant": "fp16" if torch_dtype == torch.float16 else None,
        }
        if SAFETY_CHECKER == "True":
            logger.info("Loading pipelines with safety checker")
        else:
            logger.info("Loading pipelines without safety checker")
            common["safety_checker"] = None

        i2i_pipe = AutoPipelineForImage2Image.from_pretrained(
            "stabilityai/sdxl-turbo", **common
        )
        t2i_pipe = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/sdxl-turbo", **common
        )

        # Move to the target device and silence per-step progress bars.
        for pipe in (t2i_pipe, i2i_pipe):
            pipe.to(device=torch_device, dtype=torch_dtype).to(device)
            pipe.set_progress_bar_config(disable=True)

        # Apply memory optimizations if on CUDA.
        if device.type == "cuda":
            try:
                # Try to use xformers for memory efficiency.
                i2i_pipe.enable_xformers_memory_efficient_attention()
                t2i_pipe.enable_xformers_memory_efficient_attention()
                logger.info("Enabled xformers memory efficient attention")
            except Exception as e:
                # xformers is optional; fall back to attention slicing.
                logger.info(f"Could not enable xformers: {str(e)}")
                i2i_pipe.enable_attention_slicing()
                t2i_pipe.enable_attention_slicing()
                logger.info("Using attention slicing instead")
        logger.info(f"Pipelines loaded successfully on {device}")
        return i2i_pipe, t2i_pipe
    except Exception as e:
        logger.error(f"Error loading pipelines: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        raise RuntimeError(f"Failed to load generation pipelines: {str(e)}")
def resize_crop(image, width=512, height=512):
    """Scale *image* to cover (width, height) preserving aspect ratio, then center-crop."""
    image = image.convert("RGB")
    src_w, src_h = image.size
    src_aspect = src_w / src_h
    dst_aspect = width / height
    if src_aspect > dst_aspect:
        # Source is wider than target: match heights, let width overflow.
        scaled_h = height
        scaled_w = int(src_aspect * scaled_h)
    else:
        # Source is taller than target: match widths, let height overflow.
        scaled_w = width
        scaled_h = int(scaled_w / src_aspect)
    # Bicubic resize, then crop the centered target-sized window.
    image = image.resize((scaled_w, scaled_h), Image.BICUBIC)
    left = (scaled_w - width) // 2
    top = (scaled_h - height) // 2
    return image.crop((left, top, left + width, top + height))
def cleanup_files(files_to_clean):
    """Best-effort removal of temporary files; failures are logged, never raised."""
    for path in files_to_clean:
        try:
            if os.path.exists(path):
                os.remove(path)
                logger.info(f"Cleaned up file: {path}")
        except Exception as e:
            logger.error(f"Error cleaning up file {path}: {str(e)}")
@app.on_event("startup")
async def startup_event():
    """Initialize the model pipeline on startup.

    Errors are logged but not re-raised so the server still starts; the
    pipelines will then be loaded lazily on the first request.

    NOTE(review): ``@app.on_event`` is deprecated in newer FastAPI versions in
    favor of lifespan handlers — confirm the pinned FastAPI still supports it.
    """
    logger.info("Initializing model pipelines...")
    try:
        # Load pipelines in background - will be ready for first request
        # Note: First request might be slow if pipeline is still loading
        load_pipelines()
        logger.info("Pipelines initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing pipelines: {str(e)}")
        # Continue with startup, pipeline will be loaded on first request
#
# FastAPI Endpoints
#
@app.post("/generate")
async def generate(
    background_tasks: BackgroundTasks,
    prompt: str = Form(...),
    name: str = Form(...),
    role: str = Form(...),
    pose_image: UploadFile = File(...),
    shirt_image: UploadFile = File(...),
    face_image: Optional[UploadFile] = File(None),
    steps: Optional[int] = Form(2),  # Default to 2 steps like in the reference
    guidance_scale: Optional[float] = Form(0.0),  # Default to 0.0 like in the reference
    strength: Optional[float] = Form(0.7),  # Default to 0.7 like in the reference
    width: Optional[int] = Form(512),  # Default to 512 like in the reference
    height: Optional[int] = Form(512),  # Default to 512 like in the reference
    seed: Optional[int] = Form(None)
):
    """
    Generate an avatar image based on the provided inputs.

    Uploads are persisted to /app/input under UUID-suffixed names, the image is
    generated synchronously via generate_image_internal, and the temp files are
    removed by a background task after the response is sent.

    Args:
        prompt: Text prompt describing the desired image (min 3 chars)
        name: Person name (used for the download filename; min 2 chars)
        role: Role/job (used for styling and filename; min 2 chars)
        pose_image: Image file for pose reference
        shirt_image: Image file for shirt reference
        face_image: Optional image file for face reference
        steps: Number of inference steps
        guidance_scale: How much to weigh the prompt
        strength: How much to transform the pose image (0.0 to 1.0)
        width: Output image width
        height: Output image height
        seed: Random seed for reproducibility (None = time-based)

    Returns:
        FileResponse: the generated PNG image.

    Raises:
        HTTPException: 400 on validation failure, 500 on generation failure.
    """
    # Validate inputs (minimum-length checks only; numeric ranges are not
    # validated here).
    if not prompt or len(prompt) < 3:
        raise HTTPException(status_code=400, detail="Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise HTTPException(status_code=400, detail="Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise HTTPException(status_code=400, detail="Role must be at least 2 characters")
    # Create unique filenames for uploaded files; track them for cleanup.
    files_to_clean = []
    try:
        # Save uploaded files with unique names
        input_dir = "/app/input"
        pose_path = os.path.join(input_dir, f"pose_{uuid.uuid4()}.jpg")
        with open(pose_path, "wb") as f:
            f.write(await pose_image.read())
        files_to_clean.append(pose_path)
        shirt_path = os.path.join(input_dir, f"shirt_{uuid.uuid4()}.jpg")
        with open(shirt_path, "wb") as f:
            f.write(await shirt_image.read())
        files_to_clean.append(shirt_path)
        face_path = None
        if face_image:
            face_path = os.path.join(input_dir, f"face_{uuid.uuid4()}.jpg")
            with open(face_path, "wb") as f:
                f.write(await face_image.read())
            files_to_clean.append(face_path)
        # Generate the image using the common generation function
        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=steps,
            guidance_scale=guidance_scale,
            strength=strength,
            width=width,
            height=height,
            seed=seed
        )
        # Schedule cleanup of temporary input files *after* the response is
        # sent (the generated output file itself is kept).
        background_tasks.add_task(cleanup_files, files_to_clean)
        # Return the generated image
        return FileResponse(
            output_path,
            media_type="image/png",
            filename=f"{role}_{name}.png"
        )
    except HTTPException:
        # Clean up files if there was a validation error
        cleanup_files(files_to_clean)
        raise
    except Exception as e:
        # Clean up files and return error
        logger.error(f"Error in generate endpoint: {str(e)}")
        cleanup_files(files_to_clean)
        raise HTTPException(status_code=500, detail=f"Error generating image: {str(e)}")
@app.get("/health")
async def health_check():
    """Liveness probe; always reports the service as healthy."""
    return dict(status="healthy")
@app.get("/")
async def root():
    """Root endpoint - points clients at the API docs and the Gradio UI."""
    links = {"message": "Welcome to Avatar Generator API", "docs": "/docs", "ui": "/ui"}
    return links
#
# Internal processing functions
#
def generate_image_internal(
    prompt: str,
    name: str,
    role: str,
    pose_path: str,
    shirt_path: str,
    face_path: Optional[str] = None,
    steps: int = 2,
    guidance_scale: float = 0.0,
    strength: float = 0.7,
    width: int = 512,
    height: int = 512,
    seed: Optional[int] = None
) -> str:
    """
    Internal function to generate an image based on the provided inputs.
    This is used by both the API endpoint and the Gradio interface.

    The pose image drives an image-to-image pass; the shirt image only
    contributes its average colour to the prompt. NOTE(review): face_path is
    accepted but never used in this function — confirm whether face
    conditioning was intended.

    Returns:
        Path to the generated image (PNG under /app/output).
    """
    def _safe_component(component: str) -> str:
        # Keep only filename-safe characters so user-supplied name/role cannot
        # escape /app/output via path separators or "..".
        cleaned = "".join(c for c in component if c.isalnum() or c in "-_")
        return cleaned or "avatar"

    # Load pipelines if not already loaded
    i2i_pipe, _ = load_pipelines()
    # Process pose image for use with image-to-image
    pose_image = Image.open(pose_path).convert("RGB")
    pose_image = resize_crop(pose_image, width, height)
    # Analyze shirt for color information: a 1x1 resize yields the mean RGB.
    shirt_image = Image.open(shirt_path).convert("RGB")
    shirt_colors = shirt_image.resize((1, 1)).getpixel((0, 0))
    color_text = f"wearing {role} clothes in color similar to RGB({shirt_colors[0]},{shirt_colors[1]},{shirt_colors[2]})"
    # Enhance prompt with role and color information
    enhanced_prompt = f"{prompt}. {name} as {role} style, highly detailed, {color_text}"
    logger.info(f"Enhanced prompt: {enhanced_prompt}")
    # SDXL-Turbo img2img runs int(steps * strength) denoising steps; make sure
    # at least one actually happens.
    if int(steps * strength) < 1:
        steps = math.ceil(1 / max(0.10, strength))
        logger.info(f"Adjusted steps to {steps} to ensure at least one denoising step")
    # Set seed for reproducibility (time-based when not supplied).
    if seed is None:
        seed = int(time.time())
    generator = torch.Generator(device=i2i_pipe.device).manual_seed(seed)
    logger.info(f"Starting generation with seed {seed}")
    start_time = time.time()
    # Use image-to-image pipeline with the pose image as input
    results = i2i_pipe(
        prompt=enhanced_prompt,
        image=pose_image,
        generator=generator,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        strength=strength,
        width=width,
        height=height,
        output_type="pil",
    )
    generation_time = time.time() - start_time
    logger.info(f"Image generated in {generation_time:.2f} seconds")
    # Check for NSFW content (only present when a safety checker is loaded).
    nsfw_content_detected = (
        results.nsfw_content_detected[0]
        if "nsfw_content_detected" in results
        else False
    )
    if nsfw_content_detected:
        logger.warning("NSFW content detected, returning placeholder image")
        image = Image.new("RGB", (width, height), color=(100, 100, 100))
    else:
        image = results.images[0]
    # Create output file name and path from sanitized user input.
    output_filename = f"{_safe_component(role)}_{_safe_component(name)}_{seed}.png"
    output_path = os.path.join("/app/output", output_filename)
    # Save the generated image
    image.save(output_path)
    logger.info(f"Saved output image to {output_path}")
    return output_path
#
# Gradio UI
#
# Function for Gradio's prediction interface
# Function for Gradio's prediction interface
def gradio_predict(
    pose_image,
    prompt,
    name,
    role,
    shirt_image,
    face_image=None,
    strength=0.7,
    steps=2,
    seed=None
):
    """
    Generate an image for the Gradio UI.

    Args match the generate_image_internal function but adapted for Gradio's
    interface: images arrive as PIL objects and are written to temp files
    first. Returns the generated PIL image, or None when a required image is
    missing. NOTE(review): face_image is saved and passed along, but
    generate_image_internal never reads it — confirm intent.
    """
    # Missing required images: return None so Gradio just shows no output.
    if pose_image is None or shirt_image is None:
        return None
    if not prompt or len(prompt) < 3:
        raise gr.Error("Prompt must be at least 3 characters")
    if not name or len(name) < 2:
        raise gr.Error("Name must be at least 2 characters")
    if not role or len(role) < 2:
        raise gr.Error("Role must be at least 2 characters")
    # Save images to temporary files
    input_dir = "/app/input"
    os.makedirs(input_dir, exist_ok=True)
    # Create file paths (UUID-suffixed to avoid collisions between requests).
    pose_path = os.path.join(input_dir, f"pose_gradio_{uuid.uuid4()}.jpg")
    shirt_path = os.path.join(input_dir, f"shirt_gradio_{uuid.uuid4()}.jpg")
    face_path = None
    files_to_clean = [pose_path, shirt_path]
    try:
        # Save images
        pose_image.save(pose_path)
        shirt_image.save(shirt_path)
        if face_image is not None:
            face_path = os.path.join(input_dir, f"face_gradio_{uuid.uuid4()}.jpg")
            face_image.save(face_path)
            files_to_clean.append(face_path)
        # Generate image; the UI's seed slider uses 0 to mean "random".
        if seed is None or seed == 0:
            seed = int(time.time())
        output_path = generate_image_internal(
            prompt=prompt,
            name=name,
            role=role,
            pose_path=pose_path,
            shirt_path=shirt_path,
            face_path=face_path,
            steps=steps,
            guidance_scale=0.0,  # Fixed at 0.0 for SDXL Turbo as in reference
            strength=strength,
            width=512,  # Fixed at 512 for Gradio UI as in reference
            height=512,  # Fixed at 512 for Gradio UI as in reference
            seed=seed
        )
        # Load the image to return to Gradio
        result_image = Image.open(output_path)
        # Clean up temporary files (runs synchronously despite the name).
        cleanup_files(files_to_clean)
        return result_image
    except Exception as e:
        # Clean up files and propagate error as a Gradio-visible error.
        cleanup_files(files_to_clean)
        raise gr.Error(f"Error generating image: {str(e)}")
# Create the Gradio interface
def create_gradio_interface():
    """Create and return the Gradio Blocks interface.

    Builds the avatar-generator UI: three image inputs (pose, shirt, optional
    face), text fields, an advanced-options accordion (strength/steps/seed),
    and wires the Generate button to gradio_predict.

    Returns:
        gr.Blocks: the assembled (unlaunched) interface; mounted at /ui below.
    """
    # Layout-only CSS for the page container and intro header.
    css = """
    #container{
        margin: 0 auto;
        max-width: 80rem;
    }
    #intro{
        max-width: 100%;
        text-align: center;
        margin: 0 auto;
    }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """# Avatar Generator
            ## Generate customized avatars based on pose, shirt, and prompts
            Upload a pose image and a shirt reference, then describe the avatar you want to generate.
            """,
            elem_id="intro",
        )
        with gr.Row():
            with gr.Column():
                # Input elements (type="pil" so gradio_predict receives PIL images).
                pose_image = gr.Image(label="Pose Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                shirt_image = gr.Image(label="Shirt Reference", type="pil", sources=["upload", "webcam", "clipboard"])
                face_image = gr.Image(label="Face Reference (Optional)", type="pil", sources=["upload", "webcam", "clipboard"])
                prompt = gr.Textbox(label="Prompt", placeholder="Describe the avatar you want to generate...")
                name = gr.Textbox(label="Name", placeholder="Enter name...")
                role = gr.Textbox(label="Role/Profession", placeholder="Enter role (e.g., doctor, engineer)...")
                with gr.Accordion("Advanced Options", open=False):
                    strength = gr.Slider(
                        label="Strength",
                        minimum=0.1,
                        maximum=1.0,
                        value=0.7,
                        step=0.05,
                        info="How much to transform the pose image (higher = more creative)"
                    )
                    steps = gr.Slider(
                        label="Steps",
                        minimum=1,
                        maximum=10,
                        value=2,
                        step=1,
                        info="Number of denoising steps (higher = more detail but slower)"
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=9999999999,
                        value=0,
                        step=1,
                        info="Random seed for reproducibility (0 = random)"
                    )
                generate_button = gr.Button("Generate Avatar", variant="primary")
            with gr.Column():
                # Output image
                output_image = gr.Image(label="Generated Avatar", type="pil")
        # Example inputs (currently unused — see note below).
        examples = [
            [
                None,  # Will be filled with pose image
                "Professional portrait, high quality",
                "John",
                "Doctor",
                None,  # Will be filled with shirt image
                None,  # No face image
                0.7,  # Strength
                2,  # Steps
                42  # Seed
            ]
        ]
        # Event handlers
        generate_button.click(
            fn=gradio_predict,
            inputs=[pose_image, prompt, name, role, shirt_image, face_image, strength, steps, seed],
            outputs=output_image
        )
        # Examples don't work well with image inputs, so commenting out for now
        # gr.Examples(
        #     examples=examples,
        #     inputs=[pose_image, prompt, name, role, shirt_image, face_image, strength, steps, seed],
        #     outputs=output_image,
        #     fn=gradio_predict,
        #     cache_examples=True,
        # )
    return demo
# Create and mount the Gradio app at the /ui route.
gradio_app = create_gradio_interface()
app = gr.mount_gradio_app(app, gradio_app, path="/ui")

if __name__ == "__main__":
    import uvicorn
    # Pass the app object directly: this module is app.py, so the previous
    # import string "hybrid_app:app" would fail with ModuleNotFoundError.
    # (reload is unavailable when passing an object; it was False anyway.)
    uvicorn.run(app, host="0.0.0.0", port=7860)