# Hugging Face Space status: Running
"""
FastAPI Binary Segmentation Service
Hugging Face Space compatible

Removes image backgrounds with a selectable segmentation model
(u2netp, birefnet or rmbg) and returns the result either as PNG files
(RGBA cut-out or binary mask) or as base64 ``data:`` URLs in JSON.
When run directly it serves on port 7860.
"""
import base64
import io
import logging
import os
import threading
from typing import Literal, Optional
from urllib.parse import quote

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse, Response
from fastapi.staticfiles import StaticFiles
from PIL import Image

from binary_segmentation import BinarySegmenter
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="Binary Segmentation API",
    description="Remove background from images using AI models",
    version="1.0.0"
)

# Add CORS middleware so browser clients on any origin can call the API.
# Credentials are disabled: nothing in this service uses cookies or auth,
# and combining allow_credentials=True with a wildcard origin makes
# Starlette echo back the caller's Origin, i.e. it would accept credentialed
# cross-origin requests from *any* site.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount static files (optional web UI) only when "static" is a directory,
# resolved relative to the current working directory. A plain file with that
# name would otherwise make StaticFiles fail at startup.
if os.path.isdir("static"):
    app.mount("/static", StaticFiles(directory="static"), name="static")

# Global model instances (lazy loading): model name -> loaded segmenter,
# filled on first use by get_segmenter().
segmenter_cache: dict[str, BinarySegmenter] = {}
# Guards model construction so that concurrent first requests for a model
# (e.g. when called from worker threads) load it only once.
_segmenter_lock = threading.Lock()


def get_segmenter(model_type: str = "u2netp") -> BinarySegmenter:
    """Get or create the shared segmenter instance for ``model_type``.

    The first call for a given model constructs ``BinarySegmenter`` and
    stores it in ``segmenter_cache``; later calls return the cached object.

    Args:
        model_type: Model name forwarded to ``BinarySegmenter(model_type=...)``.

    Returns:
        The cached ``BinarySegmenter`` for that model.
    """
    # Fast path: already loaded, no locking needed.
    segmenter = segmenter_cache.get(model_type)
    if segmenter is not None:
        return segmenter

    with _segmenter_lock:
        # Re-check: another thread may have finished loading while we waited.
        segmenter = segmenter_cache.get(model_type)
        if segmenter is None:
            logger.info("Loading %s model...", model_type)
            segmenter = BinarySegmenter(model_type=model_type)
            segmenter_cache[model_type] = segmenter
            logger.info("%s model loaded successfully", model_type)
    return segmenter
@app.get("/")
async def root():
    """Serve the web interface, or describe the API if no UI is bundled.

    Returns:
        ``static/index.html`` (relative to the working directory) when that
        file exists, otherwise a JSON summary of the available endpoints.
    """
    # isfile (not exists): FileResponse cannot serve a directory.
    if os.path.isfile("static/index.html"):
        return FileResponse("static/index.html")
    # Fallback to API info
    return {
        "name": "Binary Segmentation API",
        "version": "1.0.0",
        "endpoints": {
            "/segment": "POST - Segment image and return PNG with transparency",
            "/segment/mask": "POST - Return binary mask only",
            "/segment/base64": "POST - Return base64 encoded results",
            "/segment/batch": "POST - Segment up to 10 images, return base64 encoded results",
            "/health": "GET - Health check",
            "/models": "GET - List available models"
        }
    }
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        ``status`` (always ``"healthy"`` when the app responds) and
        ``models_loaded``, the names of models already in the cache.
    """
    return {
        "status": "healthy",
        "models_loaded": list(segmenter_cache)
    }
@app.get("/models")
async def list_models():
    """List available segmentation models.

    Returns:
        A static catalogue of the supported models (name, description,
        relative speed/accuracy, approximate size, extra requirements) and
        the name of the default model.
    """
    return {
        "models": [
            {
                "name": "u2netp",
                "description": "Lightweight, fast model (1.1M params)",
                "speed": "⚡⚡⚡",
                "accuracy": "⭐⭐",
                "size": "4.7 MB"
            },
            {
                "name": "birefnet",
                "description": "High accuracy model",
                "speed": "⚡",
                "accuracy": "⭐⭐⭐",
                "size": "~400 MB",
                "requires": "transformers package"
            },
            {
                "name": "rmbg",
                "description": "Balanced model",
                "speed": "⚡⚡",
                "accuracy": "⭐⭐⭐",
                "size": "~200 MB",
                "requires": "transformers package"
            }
        ],
        "default": "u2netp"
    }
@app.post("/segment")
async def segment_image(
    file: UploadFile = File(..., description="Image file to segment"),
    model: str = Form("u2netp", description="Model to use: u2netp, birefnet, or rmbg"),
    threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
):
    """
    Segment image and return PNG with transparent background.

    Args:
        file: Uploaded image in any format ``cv2.imdecode`` can read.
        model: One of ``u2netp``, ``birefnet`` or ``rmbg``.
        threshold: Segmentation threshold in [0, 1].

    Returns: PNG image with transparency, offered as a download named
    ``segmented_<original name without extension>.png``.

    Raises:
        HTTPException: 400 for an unknown model or an empty/undecodable
            upload; 500 if segmentation or encoding fails.
    """
    try:
        # Validate model
        if model not in ("u2netp", "birefnet", "rmbg"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
            )

        # Read image. cv2.imdecode raises (rather than returning None) on an
        # empty buffer, so reject empty uploads here as a client error.
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file")
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        def _segment_to_png() -> bytes:
            # Model loading, inference and PNG encoding are blocking work;
            # running them in a worker thread keeps the event loop free to
            # serve other requests (e.g. /health) in the meantime.
            segmenter = get_segmenter(model)
            logger.info("Segmenting with model=%s, threshold=%s", model, threshold)
            _, rgba = segmenter.segment(image, threshold=threshold, return_type="rgba")
            if rgba is None:
                raise HTTPException(status_code=500, detail="Segmentation failed")
            img_byte_arr = io.BytesIO()
            rgba.save(img_byte_arr, format='PNG')
            return img_byte_arr.getvalue()

        png_bytes = await run_in_threadpool(_segment_to_png)
        logger.info("Segmentation successful")

        # The body is always PNG, so name the download *.png. The uploaded
        # filename is client-controlled: it may be missing, contain quotes or
        # CR/LF, or contain characters that cannot be encoded in a header, so
        # send an ASCII-safe fallback plus an RFC 5987 percent-encoded name.
        stem = os.path.splitext(os.path.basename(file.filename or "image"))[0]
        download_name = f"segmented_{stem}.png"
        ascii_name = "".join(
            c if " " <= c <= "~" and c not in '"\\' else "_" for c in download_name
        )
        disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(download_name, safe='')}"
        )
        return Response(
            content=png_bytes,
            media_type="image/png",
            headers={"Content-Disposition": disposition}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in segmentation: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/segment/mask")
async def segment_mask(
    file: UploadFile = File(..., description="Image file to segment"),
    model: str = Form("u2netp", description="Model to use"),
    threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
):
    """
    Segment image and return binary mask only.

    Args:
        file: Uploaded image in any format ``cv2.imdecode`` can read.
        model: One of ``u2netp``, ``birefnet`` or ``rmbg``.
        threshold: Segmentation threshold in [0, 1].

    Returns: PNG image (binary mask - black and white), offered as a
    download named ``mask_<original name without extension>.png``.

    Raises:
        HTTPException: 400 for an unknown model or an empty/undecodable
            upload; 500 if segmentation or PNG encoding fails.
    """
    try:
        # Validate model
        if model not in ("u2netp", "birefnet", "rmbg"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
            )

        # Read image. cv2.imdecode raises (rather than returning None) on an
        # empty buffer, so reject empty uploads here as a client error.
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file")
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        def _mask_to_png() -> bytes:
            # Blocking model load / inference / encoding, run in a worker
            # thread so the event loop stays responsive.
            segmenter = get_segmenter(model)
            logger.info("Generating mask with model=%s, threshold=%s", model, threshold)
            mask, _ = segmenter.segment(image, threshold=threshold, return_type="mask")
            if mask is None:
                raise HTTPException(status_code=500, detail="Segmentation failed")
            # imencode reports failure via its first return value, not an exception.
            ok, buffer = cv2.imencode('.png', mask)
            if not ok:
                raise HTTPException(status_code=500, detail="Failed to encode mask as PNG")
            return buffer.tobytes()

        png_bytes = await run_in_threadpool(_mask_to_png)
        logger.info("Mask generation successful")

        # The body is always PNG; the uploaded filename is client-controlled
        # (may be missing or contain quotes, CR/LF or non-Latin-1 characters),
        # so send an ASCII-safe fallback plus an RFC 5987 encoded name.
        stem = os.path.splitext(os.path.basename(file.filename or "image"))[0]
        download_name = f"mask_{stem}.png"
        ascii_name = "".join(
            c if " " <= c <= "~" and c not in '"\\' else "_" for c in download_name
        )
        disposition = (
            f'attachment; filename="{ascii_name}"; '
            f"filename*=UTF-8''{quote(download_name, safe='')}"
        )
        return Response(
            content=png_bytes,
            media_type="image/png",
            headers={"Content-Disposition": disposition}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in mask generation: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/segment/base64")
async def segment_base64(
    file: UploadFile = File(..., description="Image file to segment"),
    model: str = Form("u2netp", description="Model to use"),
    threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0),
    return_type: str = Form("rgba", description="Return type: rgba, mask, or both")
):
    """
    Segment image and return base64 encoded results.

    Args:
        file: Uploaded image in any format ``cv2.imdecode`` can read.
        model: One of ``u2netp``, ``birefnet`` or ``rmbg``.
        threshold: Segmentation threshold in [0, 1].
        return_type: ``rgba``, ``mask`` or ``both``; selects which images
            are included in the response.

    Returns: JSON with ``success``, ``model``, ``threshold`` and the requested
    ``mask`` and/or ``rgba`` entries as ``data:image/png;base64,...`` URLs.

    Raises:
        HTTPException: 400 for invalid ``model``/``return_type`` or an
            empty/undecodable upload; 500 if segmentation does not produce
            a requested image or encoding fails.
    """
    try:
        # Validate inputs
        if model not in ("u2netp", "birefnet", "rmbg"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
            )
        if return_type not in ("rgba", "mask", "both"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid return_type: {return_type}. Choose from: rgba, mask, both"
            )

        # Read image. cv2.imdecode raises (rather than returning None) on an
        # empty buffer, so reject empty uploads here as a client error.
        contents = await file.read()
        if not contents:
            raise HTTPException(status_code=400, detail="Invalid image file")
        image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        want_mask = return_type in ("mask", "both")
        want_rgba = return_type in ("rgba", "both")

        def _segment_and_encode() -> dict:
            # Blocking model load / inference / encoding, run in a worker
            # thread so the event loop stays responsive.
            segmenter = get_segmenter(model)
            logger.info(
                "Segmenting (base64) with model=%s, threshold=%s, return_type=%s",
                model, threshold, return_type,
            )
            mask, rgba = segmenter.segment(image, threshold=threshold, return_type=return_type)
            # Match /segment and /segment/mask: fail instead of answering
            # success=True while silently omitting a requested image.
            if (want_mask and mask is None) or (want_rgba and rgba is None):
                raise HTTPException(status_code=500, detail="Segmentation failed")

            response = {
                "success": True,
                "model": model,
                "threshold": threshold
            }
            if want_mask:
                ok, buffer = cv2.imencode('.png', mask)
                if not ok:
                    raise HTTPException(status_code=500, detail="Failed to encode mask as PNG")
                mask_base64 = base64.b64encode(buffer.tobytes()).decode('utf-8')
                response["mask"] = f"data:image/png;base64,{mask_base64}"
            if want_rgba:
                img_byte_arr = io.BytesIO()
                rgba.save(img_byte_arr, format='PNG')
                rgba_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
                response["rgba"] = f"data:image/png;base64,{rgba_base64}"
            return response

        response = await run_in_threadpool(_segment_and_encode)
        logger.info("Base64 encoding successful")
        return JSONResponse(content=response)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in base64 encoding: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/segment/batch")
async def segment_batch(
    files: list[UploadFile] = File(..., description="Multiple image files"),
    model: str = Form("u2netp", description="Model to use"),
    threshold: float = Form(0.5, description="Segmentation threshold (0.0-1.0)", ge=0.0, le=1.0)
):
    """
    Segment multiple images and return base64 encoded results.

    Files are processed independently: a file that cannot be decoded or
    segmented yields an entry with ``success: False`` and an ``error``
    message instead of failing the whole batch.

    Args:
        files: Up to 10 uploaded images.
        model: One of ``u2netp``, ``birefnet`` or ``rmbg``.
        threshold: Segmentation threshold in [0, 1].

    Returns: JSON with ``total``, ``model``, ``threshold`` and ``results``
    (one entry per file, in upload order); successful entries carry an
    ``rgba`` PNG data URL.

    Raises:
        HTTPException: 400 for an unknown model or more than 10 files;
            500 for unexpected failures outside per-file processing
            (e.g. the model cannot be loaded).
    """
    try:
        # Validate model
        if model not in ("u2netp", "birefnet", "rmbg"):
            raise HTTPException(
                status_code=400,
                detail=f"Invalid model: {model}. Choose from: u2netp, birefnet, rmbg"
            )

        # Limit batch size
        if len(files) > 10:
            raise HTTPException(
                status_code=400,
                detail="Maximum batch size is 10 images"
            )

        # Model loading is blocking; keep it off the event loop.
        segmenter = await run_in_threadpool(get_segmenter, model)

        def _segment_to_png(image) -> Optional[bytes]:
            # Blocking inference + PNG encoding, run in a worker thread.
            # Returns None when the segmenter produces no RGBA image.
            _, rgba = segmenter.segment(image, threshold=threshold, return_type="rgba")
            if rgba is None:
                return None
            img_byte_arr = io.BytesIO()
            rgba.save(img_byte_arr, format='PNG')
            return img_byte_arr.getvalue()

        results = []
        for idx, file in enumerate(files, start=1):
            try:
                contents = await file.read()
                # cv2.imdecode raises on an empty buffer, so treat an empty
                # upload the same as an undecodable one.
                image = (
                    cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
                    if contents else None
                )
                if image is None:
                    results.append({
                        "filename": file.filename,
                        "success": False,
                        "error": "Invalid image file"
                    })
                    continue

                logger.info("Processing batch image %d/%d: %s", idx, len(files), file.filename)
                png_bytes = await run_in_threadpool(_segment_to_png, image)
                if png_bytes is None:
                    results.append({
                        "filename": file.filename,
                        "success": False,
                        "error": "Segmentation failed"
                    })
                    continue

                rgba_base64 = base64.b64encode(png_bytes).decode('utf-8')
                results.append({
                    "filename": file.filename,
                    "success": True,
                    "rgba": f"data:image/png;base64,{rgba_base64}"
                })
            except Exception as e:
                # Best-effort per file: record the failure and continue.
                logger.exception("Error processing %s: %s", file.filename, e)
                results.append({
                    "filename": file.filename,
                    "success": False,
                    "error": str(e)
                })

        logger.info("Batch processing complete: %d images", len(results))
        return JSONResponse(content={
            "total": len(files),
            "results": results,
            "model": model,
            "threshold": threshold
        })
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error in batch processing: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # For local development: auto-reload on code changes. The "app:app"
    # import string assumes this module is saved as app.py. The port defaults
    # to 7860 and can be overridden with the PORT environment variable.
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        reload=True
    )