# GeoGen-API / app.py
# Author: SaniaE — commit cb1158f ("added seed during generate", verified)
import os
import io
import torch
from torch import nn
from PIL import Image
import torchvision.utils as vutils
from fastapi import FastAPI, Response, HTTPException, Query
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import hf_hub_download, login
from models import Generator
# FastAPI application instance; endpoints are registered below.
app = FastAPI()

# CORS Configuration
# Wide-open CORS: any origin/method/header may call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration Constants
Z_DIM = 100                                  # latent-vector dimensionality fed to the generator
DEVICE = torch.device("cpu")                 # inference runs on CPU only
REPO_ID = "SaniaE/GeoGen"                    # Hugging Face Hub repo holding the checkpoint
FILENAME = "dcgans_model_checkpoint.pt"      # checkpoint file name inside the repo

# Loaded generator; stays None until the startup hook succeeds, and
# endpoints answer 503 while it is None.
gen_model = None
@app.on_event("startup")
def load_model():
    """Download the DCGAN checkpoint from the Hub and load the generator.

    Runs once at application startup. On any failure the error is logged
    and ``gen_model`` stays ``None``, so the endpoints can return 503
    instead of crashing the whole app.
    """
    global gen_model
    try:
        # Fix: the original read only "HF_Token" while the log message
        # talked about "HF_TOKEN" — the conventional upper-case env var
        # was silently ignored. Accept both spellings.
        token = os.getenv("HF_TOKEN") or os.getenv("HF_Token")
        if token:
            login(token=token)
            print("Login successful.")
        else:
            print("No HF_TOKEN found - attempting public download.")
        model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
        print(f"File downloaded to: {model_path}")
        checkpoint = torch.load(model_path, map_location=DEVICE)
        gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
        # strict=False tolerates architecture drift; report what differed.
        missing, unexpected = gen_model.load_state_dict(
            checkpoint["gen_state_dict"], strict=False
        )
        print("Unexpected keys: ", unexpected)
        print("Missing keys: ", missing)
        gen_model.eval()  # inference mode: disable dropout/batch-norm updates
        print("SUCCESS: Petrol Pump GAN is live!")
    except Exception as e:
        # Startup boundary: log and keep serving; endpoints check gen_model.
        print(f"Error loading model: {e}")
def postprocess_image(tensor):
    """Convert a tanh-range generator output into a PIL RGB image.

    Rescales from [-1, 1] to [0, 1], tiles the batch into a single grid,
    then rounds to uint8 HWC pixels for PIL.
    """
    # Unnormalize: tanh output [-1, 1] -> [0, 1]
    scaled = tensor.add(1).div(2).clamp(0, 1)
    grid = vutils.make_grid(scaled, padding=0, normalize=False)
    pixels = (
        grid.mul(255)
        .add_(0.5)          # +0.5 before clamp/cast rounds to nearest
        .clamp_(0, 255)
        .permute(1, 2, 0)   # CHW -> HWC
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(pixels)
def get_image_stream(tensor):
    """Helper to convert a generator output tensor to a streaming-ready PNG.

    Fix: the original duplicated postprocess_image's entire
    unnormalize/grid/uint8 pipeline line-for-line (make_grid's
    ``normalize`` already defaults to False, so the two were identical);
    reuse the helper instead.
    """
    pil_img = postprocess_image(tensor)
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    buf.seek(0)  # rewind so StreamingResponse reads from the start
    return buf
@app.get("/")
def read_root():
    """Health-check endpoint: reports service status and the model repo."""
    payload = {"status": "online", "model": REPO_ID}
    return payload
@app.get("/generate")
def generate_random(seed: int = Query(None)):
    """Endpoint 1: Fixed context generation for a session."""
    if gen_model is None:
        raise HTTPException(status_code=503)
    # Reuse the caller's session seed when given; otherwise draw a fresh
    # non-deterministic one (torch.seed() both sets and returns it).
    if seed is None:
        active_seed = torch.seed()
    else:
        active_seed = seed
    torch.manual_seed(active_seed)
    with torch.inference_mode():
        latent = torch.randn(1, Z_DIM, device=DEVICE)
        generated = gen_model(latent)
    return StreamingResponse(get_image_stream(generated), media_type="image/png")
@app.get("/explore")
def explore_latent(
    seed: int,
    x_shift: float = Query(0.0, ge=-5.0, le=5.0),
    y_shift: float = Query(0.0, ge=-5.0, le=5.0),
):
    """Endpoint 2: Controlled generation for 'Tuning'.

    Re-seeds the RNG with the caller's seed so the same
    (seed, x_shift, y_shift) triple always produces the same image, then
    perturbs fixed latent dimensions plus a seeded random direction.
    """
    if gen_model is None:
        raise HTTPException(status_code=503)
    try:
        with torch.inference_mode():
            torch.manual_seed(seed)
            # DEVICE is CPU here, but seed CUDA too in case that changes.
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            noise = torch.randn(1, Z_DIM, device=DEVICE)
            # Structured control: shift two fixed 10-dim slices of the latent.
            noise[:, :10] += x_shift
            noise[:, 10:20] += y_shift
            # Seeded random direction, scaled by the total shift magnitude
            # so zero shift leaves the base latent untouched.
            direction = torch.randn_like(noise)
            noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))
            # Fix: removed leftover debug print of the noise vector that
            # ran on every request.
            fake_img = gen_model(noise)
        return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))