arcsu1's picture
Add web interface for face generation
55ed626
from flask import Flask, jsonify, request, send_file, render_template
from flask_cors import CORS
import numpy as np
from keras.models import load_model
from PIL import Image
import io
# Location of the pretrained Keras generator; the "./" prefix makes it
# relative to the process working directory.
MODEL_PATH = "./models/face-gen-gan/generator_model_100.h5"

# Filled in by load_gan_model(); both stay None until the model is loaded.
model = latent_dim = None

app = Flask(__name__)
CORS(app)  # default CORS settings, applied to every route
def load_gan_model():
    """Load the generator once and cache it in the module globals.

    Sets ``model`` by loading ``MODEL_PATH`` and ``latent_dim`` from the
    model's input shape. If ``model`` is already set, nothing happens.
    """
    global model, latent_dim
    if model is not None:
        return  # already loaded; keep the cached instance

    print(f"Loading face generation GAN model from {MODEL_PATH}...")
    model = load_model(MODEL_PATH)
    # Index 1 of input_shape is taken as the latent vector length
    # (assumes a (batch, latent_dim) input layer).
    latent_dim = model.input_shape[1]
    print(f"Model loaded successfully! Latent dimension: {latent_dim}")
# Load model on startup: this runs once when the module is imported, before
# any request is served. There is no error handling here, so a failure in
# load_gan_model() propagates and aborts the import.
load_gan_model()
@app.route("/")
def index():
    """Render the ``index.html`` template that hosts the web interface."""
    page = 'index.html'
    return render_template(page)
@app.route("/api")
def root():
    """Describe the API: its name, run status, model id and latent size."""
    info = dict(
        message="Face Generator API",
        status="running",
        model="face-gen-gan",
        latent_dim=latent_dim,
    )
    return jsonify(info)
@app.route("/health")
def health():
    """Health-check endpoint reporting whether the generator is loaded."""
    loaded = model is not None
    return jsonify(
        {"status": "healthy", "model_loaded": loaded, "latent_dim": latent_dim}
    )
# Upper bound on faces per /generate request (keeps inference cost bounded).
_MAX_SAMPLES = 16
# NumPy's legacy RandomState accepts integer seeds in [0, 2**32 - 1].
_MAX_SEED = 2**32 - 1


def _read_int(raw, field):
    """Convert a client-supplied value to int; raise ValueError with a client-facing message."""
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        raise ValueError(f"'{field}' must be an integer") from None


def _sample_latent_points(n_samples, seed=None):
    """Draw an (n_samples, latent_dim) standard-normal latent batch.

    A seeded request uses its own RandomState instead of reseeding the
    process-wide NumPy RNG, so concurrent requests cannot disturb each other
    and later unseeded requests stay random. RandomState(seed).randn yields the
    same stream as np.random.seed(seed); np.random.randn, so a given seed still
    produces the same faces.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)
    return rng.randn(n_samples, latent_dim)


def _to_uint8(images):
    """Map generator output from [-1, 1] to uint8 [0, 255].

    Values are clipped first so anything outside [-1, 1] saturates instead of
    wrapping around in the uint8 cast.
    """
    return np.clip((images + 1) / 2.0 * 255, 0, 255).astype(np.uint8)


def _tile_images(images, background=255):
    """Arrange an (N, H, W[, C]) uint8 batch in a near-square grid.

    Uses ceil(sqrt(N)) columns and only as many rows as needed, so there are
    no all-blank trailing rows. Unused cells are filled with ``background``.
    For N == 1 the single image is returned unchanged.
    """
    count, height, width = images.shape[:3]
    if count == 1:
        return images[0]
    cols = int(np.ceil(np.sqrt(count)))
    rows = (count + cols - 1) // cols  # ceiling division
    # Trailing dims (channels, if any) are taken from the images, not hard-coded.
    canvas = np.full(
        (rows * height, cols * width) + images.shape[3:], background, dtype=np.uint8
    )
    for index, image in enumerate(images):
        row, col = divmod(index, cols)
        canvas[row * height:(row + 1) * height, col * width:(col + 1) * width] = image
    return canvas


def _png_response(array):
    """Encode a uint8 image array as PNG and return it via send_file."""
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]  # PIL wants (H, W) for single-channel images
    buf = io.BytesIO()
    Image.fromarray(array).save(buf, format='PNG')
    buf.seek(0)
    return send_file(buf, mimetype='image/png')


@app.route("/generate", methods=["POST"])
def generate_faces():
    """
    Generate face images using the GAN model.

    Optional JSON object body:
        n_samples -- number of faces, clamped to 1-16 (default 1)
        seed      -- integer in [0, 2**32 - 1] for reproducible output

    Returns a PNG image: the face itself for one sample, otherwise a grid of
    faces with white filler in unused cells. Responds 400 for invalid
    parameters and 500 when the model is not loaded or generation fails.
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    # silent=True: a missing or non-JSON body falls back to defaults instead of
    # raising a werkzeug HTTP error (previously reported as a 500).
    data = request.get_json(silent=True) or {}
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    try:
        raw_n = data.get("n_samples")
        n_samples = 1 if raw_n is None else _read_int(raw_n, "n_samples")
        n_samples = max(1, min(n_samples, _MAX_SAMPLES))
        seed = data.get("seed")
        if seed is not None:
            seed = _read_int(seed, "seed")
            if not 0 <= seed <= _MAX_SEED:
                raise ValueError(f"'seed' must be between 0 and {_MAX_SEED}")
    except ValueError as e:
        return jsonify({"error": str(e)}), 400

    try:
        latent_points = _sample_latent_points(n_samples, seed)
        generated_images = _to_uint8(model.predict(latent_points, verbose=0))
        return _png_response(_tile_images(generated_images))
    except Exception as e:
        app.logger.exception("Face generation failed")
        return jsonify({"error": str(e)}), 500
@app.route("/generate-single")
def generate_single_face():
    """
    Quick endpoint to generate a single face.

    Query parameters:
        seed -- optional integer in [0, 2**32 - 1] for a reproducible face

    Returns the face as a PNG image. Responds 400 for an invalid seed and 500
    when the model is not loaded or generation fails.
    """
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500

    seed = request.args.get('seed', None)
    rng = np.random  # unseeded requests draw from the shared global generator
    if seed is not None:
        try:
            seed = int(seed)
        except ValueError:
            return jsonify({"error": "'seed' must be an integer"}), 400
        if not 0 <= seed <= 2**32 - 1:
            return jsonify({"error": "'seed' must be between 0 and 4294967295"}), 400
        # A private RandomState gives the same stream np.random.seed(seed) did,
        # without reseeding the process-wide RNG that other requests share.
        rng = np.random.RandomState(seed)

    try:
        latent_points = rng.randn(1, latent_dim)
        generated_images = model.predict(latent_points, verbose=0)
        # Scale from [-1, 1] to [0, 255]; clip so stray values saturate
        # instead of wrapping around in the uint8 cast.
        face = np.clip((generated_images[0] + 1) / 2.0 * 255, 0, 255).astype(np.uint8)
        if face.ndim == 3 and face.shape[-1] == 1:
            face = face[..., 0]  # PIL wants (H, W) for single-channel images
        buf = io.BytesIO()
        Image.fromarray(face).save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')
    except Exception as e:
        app.logger.exception("Single face generation failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 8002, debug mode off.
    server_options = {"host": "0.0.0.0", "port": 8002, "debug": False}
    app.run(**server_options)