import io
import os
from typing import Optional

import torch
import torchvision.utils as vutils
from fastapi import FastAPI, HTTPException, Query, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from huggingface_hub import hf_hub_download, login
from PIL import Image

from models import Generator

app = FastAPI()

# CORS: wide-open so a browser-based demo frontend can call the API directly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configuration constants
Z_DIM = 100                   # latent-vector dimensionality the Generator expects
DEVICE = torch.device("cpu")  # inference runs on CPU
REPO_ID = "SaniaE/GeoGen"
FILENAME = "dcgans_model_checkpoint.pt"

# Populated by the startup hook; stays None if loading fails, and the
# endpoints then answer 503.
gen_model: Optional[torch.nn.Module] = None


@app.on_event("startup")
def load_model() -> None:
    """Download the checkpoint from the HF Hub and load the generator.

    Runs once at application startup. On any failure the error is logged
    and ``gen_model`` remains ``None`` so the endpoints report 503 instead
    of crashing the server.
    """
    global gen_model
    try:
        # BUGFIX: the original only read "HF_Token", but env vars are
        # case-sensitive and the conventional name (used by the log message
        # below) is HF_TOKEN. Check the standard name first, keep the old
        # spelling as a backward-compatible fallback.
        token = os.getenv("HF_TOKEN") or os.getenv("HF_Token")
        if token:
            login(token=token)
            print("Login successful.")
        else:
            print("No HF_TOKEN found - attempting public download.")

        model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME, token=token)
        print(f"File downloaded to: {model_path}")

        checkpoint = torch.load(model_path, map_location=DEVICE)
        gen_model = Generator(z_dim=Z_DIM).to(DEVICE)
        # strict=False tolerates minor key drift between checkpoint and model;
        # the mismatches are printed so drift is visible in the logs.
        missing, unexpected = gen_model.load_state_dict(
            checkpoint["gen_state_dict"], strict=False
        )
        print("Unexpected keys: ", unexpected)
        print("Missing keys: ", missing)

        gen_model.eval()
        print("SUCCESS: Petrol Pump GAN is live!")
    except Exception as e:
        # Never let a load failure kill startup; endpoints degrade to 503.
        print(f"Error loading model: {e}")


def postprocess_image(tensor: torch.Tensor) -> Image.Image:
    """Convert a generator output tensor (tanh range [-1, 1]) to a PIL image."""
    img_tensor = (tensor + 1) / 2  # unnormalize tanh output to [0, 1]
    img_tensor = img_tensor.clamp(0, 1)
    grid = vutils.make_grid(img_tensor, padding=0, normalize=False)
    ndarr = (
        grid.mul(255)
        .add_(0.5)          # round-to-nearest before the uint8 truncation
        .clamp_(0, 255)
        .permute(1, 2, 0)   # CHW -> HWC for PIL
        .to("cpu", torch.uint8)
        .numpy()
    )
    return Image.fromarray(ndarr)


def get_image_stream(tensor: torch.Tensor) -> io.BytesIO:
    """Encode a generator output tensor as an in-memory PNG, ready to stream."""
    # Reuse the shared tensor -> PIL conversion instead of duplicating it
    # (make_grid's default normalize=False makes the two paths identical).
    pil_img = postprocess_image(tensor)
    buf = io.BytesIO()
    pil_img.save(buf, format="PNG")
    buf.seek(0)
    return buf


@app.get("/")
def read_root():
    """Health check / liveness probe."""
    return {"status": "online", "model": REPO_ID}


@app.get("/generate")
def generate_random(seed: Optional[int] = Query(None)):
    """Endpoint 1: Fixed context generation for a session.

    A client may pass ``seed`` to reproduce the same image across calls;
    otherwise a fresh random seed is drawn.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    # Use the provided session seed or fall back to a fresh random one.
    active_seed = seed if seed is not None else torch.seed()
    torch.manual_seed(active_seed)

    with torch.inference_mode():
        noise = torch.randn(1, Z_DIM, device=DEVICE)
        fake_img = gen_model(noise)

    return StreamingResponse(get_image_stream(fake_img), media_type="image/png")


@app.get("/explore")
def explore_latent(
    seed: int,
    x_shift: float = Query(0.0, ge=-5.0, le=5.0),
    y_shift: float = Query(0.0, ge=-5.0, le=5.0),
):
    """Endpoint 2: Controlled generation for 'Tuning'.

    Re-seeds with ``seed`` so the base image is reproducible, then nudges
    the latent vector: the first 10 dims by ``x_shift``, the next 10 by
    ``y_shift``, plus a seeded random direction scaled by the total shift.
    """
    if gen_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    try:
        with torch.inference_mode():
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

            noise = torch.randn(1, Z_DIM, device=DEVICE)
            # Structured control: dedicated latent slices per axis.
            noise[:, :10] += x_shift
            noise[:, 10:20] += y_shift
            # Random (but seed-deterministic) direction, scaled by how far
            # the user has pushed the sliders.
            direction = torch.randn_like(noise)
            noise = noise + 0.3 * direction * (abs(x_shift) + abs(y_shift))

            fake_img = gen_model(noise)

        return StreamingResponse(get_image_stream(fake_img), media_type="image/png")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))