Spaces:
Running
Running
| from flask import Flask, jsonify, request, send_file, render_template | |
| from flask_cors import CORS | |
| import numpy as np | |
| from keras.models import load_model | |
| from PIL import Image | |
| import io | |
# Location of the generator weights, plus module-level slots that
# load_gan_model() populates with the loaded model and its latent size.
MODEL_PATH = "./models/face-gen-gan/generator_model_100.h5"
model = latent_dim = None

# The Flask application; CORS is switched on for every route.
app = Flask(__name__)
CORS(app)
def load_gan_model():
    """Load the GAN generator once and cache it in module globals.

    Calls after the first successful load return immediately. Sets the
    globals ``model`` (the Keras model read from ``MODEL_PATH``) and
    ``latent_dim`` (the second entry of the model's ``input_shape``).
    """
    global model, latent_dim
    if model is not None:
        return
    print(f"Loading face generation GAN model from {MODEL_PATH}...")
    model = load_model(MODEL_PATH)
    latent_dim = model.input_shape[1]
    print(f"Model loaded successfully! Latent dimension: {latent_dim}")
# Load the model eagerly at import time so the first request does not pay the
# load cost. A failure is logged instead of raised so the server still starts;
# the endpoints' `model is None` checks then report the missing model
# (500 "Model not loaded" / "model_loaded": false on the health check).
try:
    load_gan_model()
except Exception:
    app.logger.exception("Failed to load GAN model from %s", MODEL_PATH)
def index():
    """Render the browser front-end from the ``index.html`` template."""
    page = render_template('index.html')
    return page
def root():
    """Describe the API as JSON: name, status, model id and latent size."""
    info = {
        "message": "Face Generator API",
        "status": "running",
        "model": "face-gen-gan",
    }
    # latent_dim stays None until load_gan_model() has run successfully.
    info["latent_dim"] = latent_dim
    return jsonify(info)
def health():
    """Report liveness, whether the GAN model is loaded, and its latent size."""
    loaded = model is not None
    return jsonify(dict(status="healthy", model_loaded=loaded, latent_dim=latent_dim))
| def generate_faces(): | |
| """ | |
| Generate face images using the GAN model | |
| Returns a PNG image (single face or grid of faces) | |
| """ | |
| if model is None: | |
| return jsonify({"error": "Model not loaded"}), 500 | |
| try: | |
| # Get request data | |
| data = request.get_json() or {} | |
| n_samples = data.get("n_samples", 1) | |
| seed = data.get("seed", None) | |
| # Validate n_samples | |
| n_samples = max(1, min(int(n_samples), 16)) # Limit to 1-16 | |
| # Set seed if provided | |
| if seed is not None: | |
| np.random.seed(int(seed)) | |
| # Generate random latent points | |
| latent_points = np.random.randn(n_samples, latent_dim) | |
| # Generate images | |
| generated_images = model.predict(latent_points, verbose=0) | |
| # Scale from [-1, 1] to [0, 255] | |
| generated_images = ((generated_images + 1) / 2.0 * 255).astype(np.uint8) | |
| if n_samples == 1: | |
| # Single image | |
| img = Image.fromarray(generated_images[0]) | |
| else: | |
| # Create a grid | |
| grid_size = int(np.ceil(np.sqrt(n_samples))) | |
| img_height, img_width = generated_images.shape[1:3] | |
| # Create blank canvas | |
| grid_img = np.ones((grid_size * img_height, grid_size * img_width, 3), dtype=np.uint8) * 255 | |
| # Fill grid with generated images | |
| for i in range(n_samples): | |
| row = i // grid_size | |
| col = i % grid_size | |
| grid_img[row*img_height:(row+1)*img_height, | |
| col*img_width:(col+1)*img_width] = generated_images[i] | |
| img = Image.fromarray(grid_img) | |
| # Convert to bytes | |
| buf = io.BytesIO() | |
| img.save(buf, format='PNG') | |
| buf.seek(0) | |
| return send_file(buf, mimetype='image/png') | |
| except Exception as e: | |
| return jsonify({"error": str(e)}), 500 | |
def generate_single_face():
    """
    Quick endpoint to generate a single face.

    Accepts an optional ``seed`` query parameter; when given, the latent
    vector is drawn deterministically from it. Returns the face as a PNG,
    400 for an invalid seed, or 500 if the model is missing or generation
    fails.
    """
    seed = request.args.get('seed', None)
    if model is None:
        return jsonify({"error": "Model not loaded"}), 500
    if seed is not None:
        try:
            seed = int(seed)
        except ValueError:
            return jsonify({"error": f"Invalid seed: {seed!r}"}), 400
        # numpy's legacy seeding only accepts values in [0, 2**32).
        if not 0 <= seed < 2**32:
            return jsonify({"error": "seed must be between 0 and 2**32 - 1"}), 400
    try:
        # A private RandomState avoids reseeding numpy's global RNG, which is
        # shared by every request; for a given seed it draws the same values
        # np.random.seed(seed) + np.random.randn(...) did.
        rng = np.random.RandomState(seed) if seed is not None else np.random
        latent_points = rng.randn(1, latent_dim)
        generated = model.predict(latent_points, verbose=0)
        # Scale from [-1, 1] to [0, 255]; clip so out-of-range outputs
        # saturate instead of overflowing in the uint8 cast.
        pixels = np.clip((generated + 1) / 2.0 * 255, 0, 255).astype(np.uint8)
        buf = io.BytesIO()
        Image.fromarray(pixels[0]).save(buf, format='PNG')
        buf.seek(0)
        return send_file(buf, mimetype='image/png')
    except Exception as e:
        app.logger.exception("Single face generation failed")
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Built-in Flask server on every interface, port 8002, debug mode off.
    server_options = {"host": "0.0.0.0", "port": 8002, "debug": False}
    app.run(**server_options)