Spaces:
Running on Zero
Running on Zero
| import os | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
# Heavy ML stacks (torch, tensorflow, scikit-image) are imported lazily inside the
# functions that need them so the app boots even when one stack is missing. cv2 is
# imported here but guarded. A missing dep surfaces as a clear error on the method
# that needs it -- never a silent fallback to a different method.
try:
    import cv2
    _CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    _CV2_AVAILABLE = False

# ZeroGPU scheduling decorator. Outside HF Spaces (or if `spaces` cannot be
# imported for any reason) substitute a no-op so local runs keep working.
try:
    import spaces
    GPU = spaces.GPU
except Exception:  # deliberate best-effort: GPU scheduling is optional
    def GPU(fn=None, **_kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Accepts both decorator forms used with ``spaces.GPU``:
        ``@GPU`` (returns ``fn`` unchanged) and ``@GPU(duration=...)``
        (keyword options are ignored; returns an identity decorator).
        """
        if fn is None:
            return lambda f: f
        return fn
# =============================================================================
# CONSTANTS
# =============================================================================
# --- Method registry ---
M_SAUVOLA = "Sauvola (classical)"
M_NIBLACK = "Niblack (classical)"
M_OTSU = "Otsu (classical)"
M_ADAPTIVE = "Adaptive Gaussian (classical)"
M_B5 = "Tzefa b5 — MAnet/mit_b5 (neural)"
M_SBB = "SBB ResNet50-UNet (neural)"
METHODS = [M_SAUVOLA, M_NIBLACK, M_OTSU, M_ADAPTIVE, M_B5, M_SBB]

# --- Sauvola params --- (window forced odd at use site)
SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, SAUVOLA_WINDOW_STEP, SAUVOLA_WINDOW_DEFAULT = 3, 99, 2, 25
SAUVOLA_K_MIN, SAUVOLA_K_MAX, SAUVOLA_K_STEP, SAUVOLA_K_DEFAULT = 0.0, 1.0, 0.01, 0.2
SAUVOLA_R_MIN, SAUVOLA_R_MAX, SAUVOLA_R_STEP, SAUVOLA_R_DEFAULT = 1, 256, 1, 128

# --- Niblack params ---
# skimage's threshold_niblack computes T = mean - k * std. The classic Niblack
# setting k = -0.2 is for the T = mean + k * std form, so the equivalent
# default here is +0.2 (a negative k would push T *above* the local mean).
NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, NIBLACK_WINDOW_STEP, NIBLACK_WINDOW_DEFAULT = 3, 99, 2, 25
NIBLACK_K_MIN, NIBLACK_K_MAX, NIBLACK_K_STEP, NIBLACK_K_DEFAULT = -1.0, 1.0, 0.01, 0.2

# --- Otsu params ---
OTSU_BLUR_MIN, OTSU_BLUR_MAX, OTSU_BLUR_STEP, OTSU_BLUR_DEFAULT = 0, 31, 1, 5  # Gaussian pre-blur kernel; 0 disables

# --- Adaptive Gaussian params --- (block forced odd at use site)
ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, ADAPTIVE_BLOCK_STEP, ADAPTIVE_BLOCK_DEFAULT = 3, 99, 2, 31
ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, ADAPTIVE_C_STEP, ADAPTIVE_C_DEFAULT = -30, 30, 1, 10

# --- Tzefa b5 (neural) ---
B5_REPO = "WARAJA/b5_model"
B5_WEIGHTS_FILE = "b5_model.pth"
B5_ENCODER = "mit_b5"
B5_TILE_SIZE = 640
B5_DECODER_CHANNELS = (256, 128, 64, 32, 16)
# ImageNet normalisation used at training time (per RGB channel).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Foreground-probability cut-off slider, shared by both neural methods.
THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEP, THRESHOLD_DEFAULT = 0.01, 0.99, 0.01, 0.5

# --- SBB (neural) ---
SBB_REPO = "SBB/sbb_binarization"
SBB_DEFAULT_PATCH = 256  # used only if the model has no fixed input size
SBB_OVERLAP_FRACTION = 0.25  # patch overlap to suppress seam artifacts
SBB_PAD_VALUE = 255  # white (uint8); patches are padded with white, i.e. 1.0 after /255 scaling
# =============================================================================
# MODEL ARCHITECTURE (Tzefa b5) — mirrors the author's reference Space
# =============================================================================
def _build_highres_manet():
    """Build an untrained HighResMAnet: MAnet plus a full-resolution side branch.

    The class is defined inside this function so torch and
    segmentation_models_pytorch are only imported when the b5 method is used.

    Submodule attribute names and the layer order inside each nn.Sequential
    define the state-dict keys, so they must stay in sync with the checkpoint
    that _load_b5 loads via load_state_dict.

    Returns:
        ``HighResMAnet()`` with its defaults: encoder B5_ENCODER, no pretrained
        encoder weights, 1 output class.
    """
    # Imported here (not at module top) so the app boots and classical methods
    # work even when torch is absent. `forward` closes over this `torch`.
    import torch
    import torch.nn as nn
    import segmentation_models_pytorch as smp  # provides smp.MAnet used below

    class HighResMAnet(nn.Module):
        """MAnet whose decoder output is fused with a stride-1 convolutional stem."""

        def __init__(self, encoder_name=B5_ENCODER, encoder_weights=None, classes=1):
            super().__init__()
            # Only .encoder and .decoder of this MAnet are called in forward.
            self.base_model = smp.MAnet(
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=3,
                classes=classes,
                encoder_depth=5,
                decoder_channels=B5_DECODER_CHANNELS,
            )
            # Two 3x3 / stride-1 / padding-1 convs: keeps the input's H x W and
            # produces 32 channels of full-resolution features.
            self.high_res_stem = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            )
            # Input is 16 + 32 channels: presumably the decoder emits
            # B5_DECODER_CHANNELS[-1] == 16 channels, concatenated with the
            # 32 stem channels; output is `classes` logit maps.
            self.final_fusion = nn.Sequential(
                nn.Conv2d(16 + 32, 16, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, classes, kernel_size=1),
            )

        def forward(self, x):
            """Return raw logits (N, classes, H, W) for an (N, 3, H, W) input.

            The torch.cat below only succeeds when the decoder output has the
            same H x W as the stem output, i.e. as the input.
            """
            high_res_features = self.high_res_stem(x)
            features = self.base_model.encoder(x)
            # NOTE(review): passes the feature list as one argument; some smp
            # releases expect decoder(*features) -- confirm against the pinned version.
            decoder_output = self.base_model.decoder(features)
            combined = torch.cat([decoder_output, high_res_features], dim=1)
            return self.final_fusion(combined)

    return HighResMAnet()
# =============================================================================
# MODEL CACHES & LOADERS
# =============================================================================
# Process-wide caches: each model is loaded at most once, then reused.
_b5_model = None
_b5_device = None
_sbb_model = None


def _load_b5():
    """Load (once) the gated Tzefa b5 checkpoint and return ``(model, device)``.

    The model is built by _build_highres_manet, loaded strictly from the
    checkpoint, moved to ``device`` ("cuda" if available, else "cpu") and set
    to eval mode. Both values are cached for later calls.

    Raises:
        RuntimeError: if HF_TOKEN is unset, or torch / huggingface_hub are missing.
        Download and checkpoint-loading errors propagate unchanged (no fallback).
    """
    global _b5_model, _b5_device
    if _b5_model is not None:
        return _b5_model, _b5_device

    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            f"{B5_REPO} is gated. Set the HF_TOKEN secret (with access granted) to run this model."
        )
    try:
        import torch
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise RuntimeError(f"Tzefa b5 needs torch + huggingface_hub: {e}") from e

    # NOTE(review): the device is resolved once and cached. On ZeroGPU hardware
    # CUDA may only be usable inside a spaces.GPU-decorated call -- confirm placement.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weights = hf_hub_download(repo_id=B5_REPO, filename=B5_WEIGHTS_FILE, token=token, repo_type="model")
    model = _build_highres_manet()
    # NOTE(security): torch.load unpickles the downloaded file. Whether arbitrary
    # objects may be deserialized depends on the installed torch's `weights_only`
    # default -- keep B5_REPO pointed at a trusted repository.
    ckpt = torch.load(weights, map_location=device)
    # Accept either a raw state dict or a training checkpoint wrapping one.
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        state_dict = ckpt["model_state_dict"]
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)  # strict: key mismatches raise
    model = model.to(device).eval()
    _b5_model, _b5_device = model, device
    return _b5_model, _b5_device
| def _load_sbb(): | |
| """Load the SBB tf-keras SavedModel. Raises on any failure (no fallback).""" | |
| global _sbb_model | |
| if _sbb_model is not None: | |
| return _sbb_model | |
| try: | |
| import tensorflow as tf | |
| from huggingface_hub import snapshot_download | |
| except ImportError as e: | |
| raise RuntimeError(f"SBB needs tensorflow + huggingface_hub: {e}") | |
| local_dir = snapshot_download(repo_id=SBB_REPO, repo_type="model") | |
| # Locate the directory that actually contains saved_model.pb | |
| saved_dir = None | |
| for root, _dirs, files in os.walk(local_dir): | |
| if "saved_model.pb" in files: | |
| saved_dir = root | |
| break | |
| if saved_dir is None: | |
| raise RuntimeError(f"No saved_model.pb found in {SBB_REPO}") | |
| _sbb_model = tf.keras.models.load_model(saved_dir, compile=False) | |
| return _sbb_model | |
| # ============================================================================= | |
| # CLASSICAL METHODS | |
| # ============================================================================= | |
| def _to_gray(image_pil): | |
| return np.array(image_pil.convert("L"), dtype=np.uint8) | |
def run_sauvola(image_pil, window_size, k, r):
    """Binarize with scikit-image's Sauvola local threshold.

    Args:
        image_pil: input PIL image; converted to 8-bit grayscale.
        window_size: neighbourhood size; its lowest bit is set so it is odd.
        k, r: passed (as floats) straight to ``threshold_sauvola``.

    Returns:
        Grayscale PIL image: 255 where a pixel exceeds its local threshold, else 0.
    """
    from skimage.filters import threshold_sauvola

    gray = _to_gray(image_pil)
    odd_window = int(window_size) | 1
    local_t = threshold_sauvola(gray, window_size=odd_window, k=float(k), r=float(r))
    return Image.fromarray(np.where(gray > local_t, 255, 0).astype(np.uint8))
def run_niblack(image_pil, window_size, k):
    """Binarize with scikit-image's Niblack local threshold.

    Args:
        image_pil: input PIL image; converted to 8-bit grayscale.
        window_size: neighbourhood size; its lowest bit is set so it is odd.
        k: passed (as a float) straight to ``threshold_niblack`` -- see its
            documentation for the sign convention.

    Returns:
        Grayscale PIL image: 255 where a pixel exceeds its local threshold, else 0.
    """
    from skimage.filters import threshold_niblack

    gray = _to_gray(image_pil)
    odd_window = int(window_size) | 1
    local_t = threshold_niblack(gray, window_size=odd_window, k=float(k))
    return Image.fromarray(np.where(gray > local_t, 255, 0).astype(np.uint8))
def run_otsu(image_pil, blur_ksize):
    """Binarize with a global Otsu threshold, optionally after a Gaussian pre-blur.

    Args:
        image_pil: input PIL image; converted to 8-bit grayscale.
        blur_ksize: Gaussian kernel size; values > 0 are forced odd, 0 (or less)
            skips the blur.

    Returns:
        Grayscale PIL image from cv2's Otsu threshold (255 above it, 0 otherwise).

    Raises:
        RuntimeError: if OpenCV is not installed. Both the Otsu threshold and
            the optional blur use cv2, so this applies even when blur is off.
    """
    # Check up front: cv2.threshold below is always needed, not only the blur.
    if not _CV2_AVAILABLE:
        raise RuntimeError("Otsu needs opencv (cv2). Install opencv-python-headless.")
    gray = _to_gray(image_pil)
    blur_ksize = int(blur_ksize)
    if blur_ksize > 0:
        blur_ksize |= 1  # GaussianBlur requires an odd kernel
        gray = cv2.GaussianBlur(gray, (blur_ksize, blur_ksize), 0)
    # With THRESH_OTSU the passed threshold (0) is ignored and computed instead.
    _t, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(binary)
def run_adaptive_gaussian(image_pil, block_size, c):
    """Binarize with OpenCV's Gaussian-weighted adaptive threshold.

    Args:
        image_pil: input PIL image; converted to 8-bit grayscale.
        block_size: neighbourhood size; its lowest bit is set so it is odd.
        c: integer offset handed to cv2.adaptiveThreshold.

    Returns:
        Grayscale PIL image (255 / 0).

    Raises:
        RuntimeError: if OpenCV is not installed.
    """
    if _CV2_AVAILABLE:
        gray = _to_gray(image_pil)
        odd_block = int(block_size) | 1
        offset = int(c)
        result = cv2.adaptiveThreshold(
            gray,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            odd_block,
            offset,
        )
        return Image.fromarray(result)
    raise RuntimeError("Adaptive Gaussian needs opencv (cv2). Install opencv-python-headless.")
# =============================================================================
# NEURAL METHODS
# =============================================================================
def _b5_preprocess(tile_pil):
    """Turn an RGB tile into a normalized channels-first float32 tensor.

    Pixel values are scaled to [0, 1] and normalized with IMAGENET_MEAN /
    IMAGENET_STD along the last (channel) axis. The H x W x 3 layout is then
    transposed to 3 x H x W. No batch dimension is added.
    """
    import torch

    scaled = np.array(tile_pil).astype(np.float32) / 255.0
    normalized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    chw = np.transpose(normalized, (2, 0, 1))
    return torch.from_numpy(chw)
def run_b5(image_pil, threshold):
    """Binarize with the Tzefa b5 model using non-overlapping B5_TILE_SIZE tiles.

    The RGB image is padded with white up to a multiple of the tile size. Each
    tile gets a sigmoid foreground-probability map, and the map is cropped back
    to the original size.

    Args:
        image_pil: input PIL image (converted to RGB).
        threshold: pixels with foreground probability below this become white
            (255); the rest become black (0).

    Returns:
        Grayscale PIL image at the input's size.
    """
    import torch
    import torch.nn.functional as F

    model, device = _load_b5()
    rgb = image_pil.convert("RGB")
    width, height = rgb.size
    tile = B5_TILE_SIZE

    # Round each side up to a whole number of tiles.
    full_w = width + (-width) % tile
    full_h = height + (-height) % tile
    canvas = Image.new("RGB", (full_w, full_h), (255, 255, 255))
    canvas.paste(rgb, (0, 0))

    probs = np.zeros((full_h, full_w), dtype=np.float32)
    with torch.no_grad():
        for top in range(0, full_h, tile):
            for left in range(0, full_w, tile):
                patch = canvas.crop((left, top, left + tile, top + tile))
                batch = _b5_preprocess(patch).unsqueeze(0).to(device).float()
                logits = model(batch)
                # Resize only if the network's output is not tile-sized.
                if logits.shape[-2:] != (tile, tile):
                    logits = F.interpolate(logits, size=(tile, tile), mode="bilinear")
                probs[top:top + tile, left:left + tile] = torch.sigmoid(logits).cpu().numpy()[0, 0]

    probs = probs[:height, :width]
    # High foreground probability -> black (0) text; everything else white.
    return Image.fromarray(((probs < float(threshold)) * 255).astype(np.uint8))
| def run_sbb(image_pil, threshold, invert): | |
| """Patch-tiled tf-keras inference -> foreground probability -> threshold.""" | |
| model = _load_sbb() | |
| rgb = np.array(image_pil.convert("RGB"), dtype=np.float32) / 255.0 | |
| H, W = rgb.shape[:2] | |
| # Determine the model's native patch size; fall back to a sane default. | |
| try: | |
| in_shape = model.inputs[0].shape | |
| ph = int(in_shape[1]) if in_shape[1] is not None else SBB_DEFAULT_PATCH | |
| pw = int(in_shape[2]) if in_shape[2] is not None else SBB_DEFAULT_PATCH | |
| except Exception: | |
| ph = pw = SBB_DEFAULT_PATCH | |
| step_h = max(1, int(ph * (1 - SBB_OVERLAP_FRACTION))) | |
| step_w = max(1, int(pw * (1 - SBB_OVERLAP_FRACTION))) | |
| # White-pad so the image fully tiles. | |
| pad_h = (step_h - (H - ph) % step_h) % step_h if H > ph else ph - H | |
| pad_w = (step_w - (W - pw) % step_w) % step_w if W > pw else pw - W | |
| padded = np.ones((H + max(0, pad_h), W + max(0, pad_w), 3), dtype=np.float32) | |
| padded[:H, :W] = rgb | |
| PH, PW = padded.shape[:2] | |
| prob_sum = np.zeros((PH, PW), dtype=np.float32) | |
| count = np.zeros((PH, PW), dtype=np.float32) | |
| for y in range(0, PH - ph + 1, step_h): | |
| for x in range(0, PW - pw + 1, step_w): | |
| patch = padded[y:y + ph, x:x + pw][None, ...] | |
| pred = model.predict(patch, verbose=0)[0] # (h, w, 2) | |
| fg = pred[..., 1] if pred.ndim == 3 and pred.shape[-1] >= 2 else np.squeeze(pred) | |
| prob_sum[y:y + ph, x:x + pw] += fg | |
| count[y:y + ph, x:x + pw] += 1.0 | |
| count[count == 0] = 1.0 | |
| prob_map = (prob_sum / count)[:H, :W] | |
| if invert: | |
| prob_map = 1.0 - prob_map | |
| binary = ((prob_map < float(threshold)) * 255).astype(np.uint8) | |
| return Image.fromarray(binary) | |
# =============================================================================
# DISPATCH
# =============================================================================
def process_image(
    input_img, algo,
    sauvola_w, sauvola_k, sauvola_r,
    niblack_w, niblack_k,
    otsu_blur,
    adaptive_block, adaptive_c,
    threshold, sbb_invert,
):
    """Run the method named by ``algo`` on ``input_img`` and return its output image.

    Only the parameters belonging to the selected method are used. Failures
    are reported as gr.Error; there is never a fallback to another method.

    Raises:
        gr.Error: no image, unknown method, or the selected method failed. The
            original exception is chained as the cause, and its traceback is
            printed to stderr (the Space logs).
    """
    if input_img is None:
        raise gr.Error("Upload an image first.")

    # Deferred calls so only the selected method runs.
    runners = {
        M_SAUVOLA: lambda: run_sauvola(input_img, sauvola_w, sauvola_k, sauvola_r),
        M_NIBLACK: lambda: run_niblack(input_img, niblack_w, niblack_k),
        M_OTSU: lambda: run_otsu(input_img, otsu_blur),
        M_ADAPTIVE: lambda: run_adaptive_gaussian(input_img, adaptive_block, adaptive_c),
        M_B5: lambda: run_b5(input_img, threshold),
        M_SBB: lambda: run_sbb(input_img, threshold, sbb_invert),
    }
    runner = runners.get(algo)
    if runner is None:
        raise gr.Error(f"Unknown method: {algo}")
    try:
        return runner()
    except gr.Error:
        raise
    except Exception as e:
        # Surface the real failure; do NOT fall back to another method.
        import traceback
        traceback.print_exc()
        raise gr.Error(f"{algo} failed: {e}") from e
# =============================================================================
# UI
# =============================================================================
def _visibility(algo):
    """Return one gr.update per parameter group, showing only the groups ``algo`` uses.

    Order: Sauvola, Niblack, Otsu, Adaptive, shared neural threshold, SBB-only.
    This matches the ``outputs`` list wired to ``algo.change``.
    """
    shown = (
        algo == M_SAUVOLA,
        algo == M_NIBLACK,
        algo == M_OTSU,
        algo == M_ADAPTIVE,
        algo in (M_B5, M_SBB),
        algo == M_SBB,
    )
    return [gr.update(visible=flag) for flag in shown]
# Gradio UI: upload, method picker and per-method parameter groups on the left,
# binarized output on the right. Component creation order defines the layout.
with gr.Blocks(title="Xibi Binarization") as demo:
    gr.Markdown("# 📄 Document Image Binarization Suite")
    gr.Markdown(
        "Compare classical adaptive thresholding against neural binarizers. "
        "Each method runs on its own — if a model fails to load it reports the error, "
        "it never silently substitutes a different method."
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Document")
            algo = gr.Dropdown(choices=METHODS, value=M_SAUVOLA, label="Binarization Method")
            # Only the Sauvola group starts visible, matching the dropdown default.
            with gr.Group(visible=True) as g_sauvola:
                s_w = gr.Slider(SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, value=SAUVOLA_WINDOW_DEFAULT,
                                step=SAUVOLA_WINDOW_STEP, label="Window size")
                s_k = gr.Slider(SAUVOLA_K_MIN, SAUVOLA_K_MAX, value=SAUVOLA_K_DEFAULT,
                                step=SAUVOLA_K_STEP, label="k")
                s_r = gr.Slider(SAUVOLA_R_MIN, SAUVOLA_R_MAX, value=SAUVOLA_R_DEFAULT,
                                step=SAUVOLA_R_STEP, label="r (dynamic range)")
            with gr.Group(visible=False) as g_niblack:
                n_w = gr.Slider(NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, value=NIBLACK_WINDOW_DEFAULT,
                                step=NIBLACK_WINDOW_STEP, label="Window size")
                n_k = gr.Slider(NIBLACK_K_MIN, NIBLACK_K_MAX, value=NIBLACK_K_DEFAULT,
                                step=NIBLACK_K_STEP, label="k")
            with gr.Group(visible=False) as g_otsu:
                o_blur = gr.Slider(OTSU_BLUR_MIN, OTSU_BLUR_MAX, value=OTSU_BLUR_DEFAULT,
                                   step=OTSU_BLUR_STEP, label="Gaussian pre-blur (0 = off)")
            with gr.Group(visible=False) as g_adaptive:
                a_block = gr.Slider(ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, value=ADAPTIVE_BLOCK_DEFAULT,
                                    step=ADAPTIVE_BLOCK_STEP, label="Block size")
                a_c = gr.Slider(ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, value=ADAPTIVE_C_DEFAULT,
                                step=ADAPTIVE_C_STEP, label="C (mean offset)")
            # Threshold slider shared by both neural methods (M_B5 and M_SBB).
            with gr.Group(visible=False) as g_neural:
                thr = gr.Slider(THRESHOLD_MIN, THRESHOLD_MAX, value=THRESHOLD_DEFAULT, step=THRESHOLD_STEP,
                                label="Binarization threshold",
                                info="Higher = thinner strokes, lower = thicker")
            # SBB-only option.
            with gr.Group(visible=False) as g_sbb:
                sbb_inv = gr.Checkbox(value=False, label="Invert SBB output (if foreground/background flipped)")
            submit_btn = gr.Button("Binarize", variant="primary")
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Binarized Output")
    # Output order must match the list returned by _visibility.
    algo.change(
        fn=_visibility,
        inputs=algo,
        outputs=[g_sauvola, g_niblack, g_otsu, g_adaptive, g_neural, g_sbb],
    )
    # Input order must match process_image's positional parameters.
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image, algo,
            s_w, s_k, s_r,
            n_w, n_k,
            o_blur,
            a_block, a_c,
            thr, sbb_inv,
        ],
        outputs=output_image,
    )

if __name__ == "__main__":
    # NOTE(review): `theme` is passed to launch() rather than gr.Blocks(); this
    # matches newer Gradio releases -- confirm against the pinned gradio version.
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False)