import os

import numpy as np
import gradio as gr
from PIL import Image

# Optional heavy deps are imported lazily inside the loaders so the app boots
# even when one stack is missing. A missing dep surfaces as a clear error on the
# method that needs it -- never a silent fallback to a different method.
try:
    import cv2

    _CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    _CV2_AVAILABLE = False

try:
    import spaces

    GPU = spaces.GPU
except Exception:
    # Deliberately broad: outside HF Spaces `spaces` may be missing or may fail
    # to initialise; either way we fall back to a no-op decorator.
    def GPU(fn):
        """No-op stand-in for `spaces.GPU` when not running on HF Spaces."""
        return fn


# =============================================================================
# CONSTANTS
# =============================================================================

# --- Method registry ---
M_SAUVOLA = "Sauvola (classical)"
M_NIBLACK = "Niblack (classical)"
M_OTSU = "Otsu (classical)"
M_ADAPTIVE = "Adaptive Gaussian (classical)"
M_B5 = "Tzefa b5 — MAnet/mit_b5 (neural)"
M_SBB = "SBB ResNet50-UNet (neural)"
METHODS = [M_SAUVOLA, M_NIBLACK, M_OTSU, M_ADAPTIVE, M_B5, M_SBB]

# --- Sauvola params ---
SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, SAUVOLA_WINDOW_STEP, SAUVOLA_WINDOW_DEFAULT = 3, 99, 2, 25
SAUVOLA_K_MIN, SAUVOLA_K_MAX, SAUVOLA_K_STEP, SAUVOLA_K_DEFAULT = 0.0, 1.0, 0.01, 0.2
SAUVOLA_R_MIN, SAUVOLA_R_MAX, SAUVOLA_R_STEP, SAUVOLA_R_DEFAULT = 1, 256, 1, 128

# --- Niblack params ---
NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, NIBLACK_WINDOW_STEP, NIBLACK_WINDOW_DEFAULT = 3, 99, 2, 25
NIBLACK_K_MIN, NIBLACK_K_MAX, NIBLACK_K_STEP, NIBLACK_K_DEFAULT = -1.0, 1.0, 0.01, -0.2

# --- Otsu params ---
OTSU_BLUR_MIN, OTSU_BLUR_MAX, OTSU_BLUR_STEP, OTSU_BLUR_DEFAULT = 0, 31, 1, 5  # Gaussian pre-blur kernel; 0 disables

# --- Adaptive Gaussian params ---
ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, ADAPTIVE_BLOCK_STEP, ADAPTIVE_BLOCK_DEFAULT = 3, 99, 2, 31
ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, ADAPTIVE_C_STEP, ADAPTIVE_C_DEFAULT = -30, 30, 1, 10

# --- Tzefa b5 (neural) ---
B5_REPO = "WARAJA/b5_model"
B5_WEIGHTS_FILE = "b5_model.pth"
B5_ENCODER = "mit_b5"
B5_TILE_SIZE = 640
B5_DECODER_CHANNELS = (256, 128, 64, 32, 16)
# ImageNet normalisation used at training time.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEP, THRESHOLD_DEFAULT = 0.01, 0.99, 0.01, 0.5

# --- SBB (neural) ---
SBB_REPO = "SBB/sbb_binarization"
SBB_DEFAULT_PATCH = 256  # used only if the model has no fixed input size
SBB_OVERLAP_FRACTION = 0.25  # patch overlap to suppress seam artifacts
SBB_PAD_VALUE = 255  # white (background) padding, on the 0-255 scale, used to extend the image to whole patches
SBB_BATCH_SIZE = 8  # patches per model.predict call (bounds memory while avoiding per-patch call overhead)


# =============================================================================
# MODEL ARCHITECTURE (Tzefa b5) — mirrors the author's reference Space
# =============================================================================

def _build_highres_manet():
    """Build an (untrained) HighResMAnet: smp.MAnet plus a full-resolution conv stem.

    The stem's features are concatenated with the MAnet decoder output and
    fused by a small conv head into ``classes`` output channels.
    """
    # Imported here (not at module top) so the app boots and classical methods
    # work even when torch is absent. `forward` closes over this `torch`.
    import torch
    import torch.nn as nn
    import segmentation_models_pytorch as smp  # noqa: F401 (pulls timm encoders)

    class HighResMAnet(nn.Module):
        def __init__(self, encoder_name=B5_ENCODER, encoder_weights=None, classes=1):
            super().__init__()
            self.base_model = smp.MAnet(
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=3,
                classes=classes,
                encoder_depth=5,
                decoder_channels=B5_DECODER_CHANNELS,
            )
            self.high_res_stem = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            )
            # 16 = last decoder channel count, 32 = stem output channels.
            self.final_fusion = nn.Sequential(
                nn.Conv2d(16 + 32, 16, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, classes, kernel_size=1),
            )

        def forward(self, x):
            high_res_features = self.high_res_stem(x)
            features = self.base_model.encoder(x)
            decoder_output = self.base_model.decoder(features)
            combined = torch.cat([decoder_output, high_res_features], dim=1)
            return self.final_fusion(combined)

    return HighResMAnet()


# =============================================================================
# MODEL CACHES & LOADERS
# =============================================================================

_b5_model = None
_b5_device = None
_sbb_model = None


def _load_b5():
    """Load (once) and cache the gated Tzefa b5 checkpoint.

    Returns:
        ``(model, device)`` where device is ``"cuda"`` if available else ``"cpu"``.

    Raises:
        RuntimeError: if HF_TOKEN is unset or torch/huggingface_hub are missing.
        Any download / deserialisation error propagates unchanged (no fallback).
    """
    global _b5_model, _b5_device
    if _b5_model is not None:
        return _b5_model, _b5_device

    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            f"{B5_REPO} is gated. Set the HF_TOKEN secret (with access granted) to run this model."
        )
    try:
        import torch
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise RuntimeError(f"Tzefa b5 needs torch + huggingface_hub: {e}") from e

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weights = hf_hub_download(repo_id=B5_REPO, filename=B5_WEIGHTS_FILE, token=token, repo_type="model")
    model = _build_highres_manet()
    ckpt = torch.load(weights, map_location=device)
    # Accept both a full training checkpoint and a bare state_dict.
    state_dict = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    _b5_model, _b5_device = model, device
    return _b5_model, _b5_device


def _load_sbb():
    """Load (once) and cache the SBB tf-keras SavedModel. Raises on any failure (no fallback)."""
    global _sbb_model
    if _sbb_model is not None:
        return _sbb_model
    try:
        import tensorflow as tf
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise RuntimeError(f"SBB needs tensorflow + huggingface_hub: {e}") from e

    local_dir = snapshot_download(repo_id=SBB_REPO, repo_type="model")
    # Locate the first directory (in os.walk order) that contains saved_model.pb.
    saved_dir = next(
        (root for root, _dirs, files in os.walk(local_dir) if "saved_model.pb" in files),
        None,
    )
    if saved_dir is None:
        raise RuntimeError(f"No saved_model.pb found in {SBB_REPO}")
    _sbb_model = tf.keras.models.load_model(saved_dir, compile=False)
    return _sbb_model


# =============================================================================
# CLASSICAL METHODS
# =============================================================================

def _to_gray(image_pil):
    """Convert a PIL image to an 8-bit grayscale numpy array."""
    return np.array(image_pil.convert("L"), dtype=np.uint8)


def run_sauvola(image_pil, window_size, k, r):
    """Sauvola local thresholding; pixels brighter than the local threshold become white."""
    from skimage.filters import threshold_sauvola

    gray = _to_gray(image_pil)
    window_size = int(window_size) | 1  # force odd
    thresh = threshold_sauvola(gray, window_size=window_size, k=float(k), r=float(r))
    binary = (gray > thresh).astype(np.uint8) * 255
    return Image.fromarray(binary)


def run_niblack(image_pil, window_size, k):
    """Niblack local thresholding; pixels brighter than the local threshold become white."""
    from skimage.filters import threshold_niblack

    gray = _to_gray(image_pil)
    window_size = int(window_size) | 1  # force odd
    thresh = threshold_niblack(gray, window_size=window_size, k=float(k))
    binary = (gray > thresh).astype(np.uint8) * 255
    return Image.fromarray(binary)


def run_otsu(image_pil, blur_ksize):
    """Global Otsu threshold, optionally after a Gaussian pre-blur.

    Args:
        blur_ksize: Gaussian kernel size; 0 disables the blur, other values
            are forced odd.

    Raises:
        RuntimeError: if OpenCV is unavailable (both blur and Otsu use cv2).
    """
    if not _CV2_AVAILABLE:
        raise RuntimeError("Otsu needs opencv (cv2). Install opencv-python-headless.")
    gray = _to_gray(image_pil)
    blur_ksize = int(blur_ksize)
    if blur_ksize > 0:
        blur_ksize |= 1  # odd kernel
        gray = cv2.GaussianBlur(gray, (blur_ksize, blur_ksize), 0)
    _t, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(binary)


def run_adaptive_gaussian(image_pil, block_size, c):
    """OpenCV Gaussian-weighted adaptive threshold (block size forced odd)."""
    if not _CV2_AVAILABLE:
        raise RuntimeError("Adaptive Gaussian needs opencv (cv2). Install opencv-python-headless.")
    gray = _to_gray(image_pil)
    block_size = int(block_size) | 1
    binary = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, block_size, int(c)
    )
    return Image.fromarray(binary)


# =============================================================================
# NEURAL METHODS
# =============================================================================

def _prob_to_binary(prob_map, threshold):
    """Map a foreground-probability array to a black-text-on-white PIL image.

    Pixels with probability below ``threshold`` become white (255); the rest
    (likely foreground / ink) become black (0).
    """
    binary = ((prob_map < float(threshold)) * 255).astype(np.uint8)
    return Image.fromarray(binary)


def _b5_preprocess(tile_pil):
    """RGB PIL tile -> ImageNet-normalised float32 CHW torch tensor."""
    import torch

    arr = np.array(tile_pil).astype(np.float32) / 255.0
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return torch.from_numpy(arr.transpose(2, 0, 1))


@GPU
def run_b5(image_pil, threshold):
    """Binarize with the Tzefa b5 model via non-overlapping tiled inference.

    The image is white-padded to a multiple of ``B5_TILE_SIZE``, each tile's
    logits are passed through a sigmoid, the stitched probability map is cropped
    back to the original size and thresholded (high foreground prob -> black).
    """
    import torch
    import torch.nn.functional as F

    # NOTE(review): the first call downloads and loads the weights inside this
    # GPU-decorated function; on ZeroGPU that presumably counts against the call's
    # time budget — consider preloading if first requests time out.
    model, device = _load_b5()
    image_pil = image_pil.convert("RGB")
    orig_w, orig_h = image_pil.size
    pad_w = (B5_TILE_SIZE - (orig_w % B5_TILE_SIZE)) % B5_TILE_SIZE
    pad_h = (B5_TILE_SIZE - (orig_h % B5_TILE_SIZE)) % B5_TILE_SIZE
    padded = Image.new("RGB", (orig_w + pad_w, orig_h + pad_h), (255, 255, 255))
    padded.paste(image_pil, (0, 0))
    new_w, new_h = padded.size

    prob_map = np.zeros((new_h, new_w), dtype=np.float32)
    with torch.no_grad():
        for y in range(0, new_h, B5_TILE_SIZE):
            for x in range(0, new_w, B5_TILE_SIZE):
                tile = padded.crop((x, y, x + B5_TILE_SIZE, y + B5_TILE_SIZE))
                inp = _b5_preprocess(tile).unsqueeze(0).to(device).float()
                logits = model(inp)
                if logits.shape[-2:] != (B5_TILE_SIZE, B5_TILE_SIZE):
                    logits = F.interpolate(logits, size=(B5_TILE_SIZE, B5_TILE_SIZE), mode="bilinear")
                prob_map[y:y + B5_TILE_SIZE, x:x + B5_TILE_SIZE] = torch.sigmoid(logits).cpu().numpy()[0, 0]

    prob_map = prob_map[:orig_h, :orig_w]
    return _prob_to_binary(prob_map, threshold)


def _sbb_patch_size(model):
    """Return the model's native (height, width) input size, defaulting when unknown."""
    try:
        in_shape = model.inputs[0].shape
        ph = int(in_shape[1]) if in_shape[1] is not None else SBB_DEFAULT_PATCH
        pw = int(in_shape[2]) if in_shape[2] is not None else SBB_DEFAULT_PATCH
    except Exception:
        # Best-effort introspection: any unexpected model structure -> default size.
        ph = pw = SBB_DEFAULT_PATCH
    return ph, pw


def _sbb_foreground(pred):
    """Extract the foreground-probability map from one patch prediction.

    Uses channel 1 when the prediction has >= 2 channels; otherwise squeezes
    a single-channel output.
    """
    # presumably (h, w, 2) softmax with class 1 = ink — TODO confirm against the model card
    if pred.ndim == 3 and pred.shape[-1] >= 2:
        return pred[..., 1]
    return np.squeeze(pred)


def run_sbb(image_pil, threshold, invert):
    """Binarize with the SBB tf-keras model using overlapping, averaged patches.

    Args:
        threshold: probability cut-off; below -> white, otherwise black.
        invert: if True, use ``1 - prob`` before thresholding (for models
            whose foreground/background classes are flipped).
    """
    model = _load_sbb()
    # NOTE(review): this feeds RGB in [0, 1]; upstream SBB tooling reads images
    # with OpenCV, which is BGR — confirm the channel order the model expects.
    rgb = np.array(image_pil.convert("RGB"), dtype=np.float32) / 255.0
    H, W = rgb.shape[:2]

    ph, pw = _sbb_patch_size(model)
    step_h = max(1, int(ph * (1 - SBB_OVERLAP_FRACTION)))
    step_w = max(1, int(pw * (1 - SBB_OVERLAP_FRACTION)))

    # White-pad so the stride grid lands exactly on the far edge; images smaller
    # than one patch are padded up to a single patch. Both pads are >= 0.
    pad_h = (step_h - (H - ph) % step_h) % step_h if H > ph else ph - H
    pad_w = (step_w - (W - pw) % step_w) % step_w if W > pw else pw - W
    padded = np.full((H + pad_h, W + pad_w, 3), SBB_PAD_VALUE / 255.0, dtype=np.float32)
    padded[:H, :W] = rgb
    PH, PW = padded.shape[:2]

    origins = [
        (y, x)
        for y in range(0, PH - ph + 1, step_h)
        for x in range(0, PW - pw + 1, step_w)
    ]
    prob_sum = np.zeros((PH, PW), dtype=np.float32)
    count = np.zeros((PH, PW), dtype=np.float32)
    # Predict in bounded batches: one Keras predict() call per patch is dominated
    # by call overhead, while one call for all patches can exhaust memory.
    for start in range(0, len(origins), SBB_BATCH_SIZE):
        chunk = origins[start:start + SBB_BATCH_SIZE]
        batch = np.stack([padded[y:y + ph, x:x + pw] for y, x in chunk])
        preds = model.predict(batch, batch_size=len(chunk), verbose=0)
        for (y, x), pred in zip(chunk, preds):
            prob_sum[y:y + ph, x:x + pw] += _sbb_foreground(pred)
            count[y:y + ph, x:x + pw] += 1.0

    count[count == 0] = 1.0  # guard against division by zero for uncovered pixels
    prob_map = (prob_sum / count)[:H, :W]
    if invert:
        prob_map = 1.0 - prob_map
    return _prob_to_binary(prob_map, threshold)


# =============================================================================
# DISPATCH
# =============================================================================

def process_image(
    input_img,
    algo,
    sauvola_w, sauvola_k, sauvola_r,
    niblack_w, niblack_k,
    otsu_blur,
    adaptive_block, adaptive_c,
    threshold,
    sbb_invert,
):
    """Route the request to exactly one method; any failure becomes a gr.Error.

    Raises:
        gr.Error: on missing input, an unknown method, or any method failure
            (the original exception is chained as the cause).
    """
    if input_img is None:
        raise gr.Error("Upload an image first.")
    try:
        if algo == M_SAUVOLA:
            return run_sauvola(input_img, sauvola_w, sauvola_k, sauvola_r)
        if algo == M_NIBLACK:
            return run_niblack(input_img, niblack_w, niblack_k)
        if algo == M_OTSU:
            return run_otsu(input_img, otsu_blur)
        if algo == M_ADAPTIVE:
            return run_adaptive_gaussian(input_img, adaptive_block, adaptive_c)
        if algo == M_B5:
            return run_b5(input_img, threshold)
        if algo == M_SBB:
            return run_sbb(input_img, threshold, sbb_invert)
        raise gr.Error(f"Unknown method: {algo}")
    except gr.Error:
        raise
    except Exception as e:
        # Surface the real failure; do NOT fall back to another method.
        raise gr.Error(f"{algo} failed: {e}") from e


# =============================================================================
# UI
# =============================================================================

def _visibility(algo):
    """Visibility updates for the per-method parameter groups, in UI output order."""
    return [
        gr.update(visible=algo == M_SAUVOLA),
        gr.update(visible=algo == M_NIBLACK),
        gr.update(visible=algo == M_OTSU),
        gr.update(visible=algo == M_ADAPTIVE),
        gr.update(visible=algo in (M_B5, M_SBB)),
        gr.update(visible=algo == M_SBB),
    ]


with gr.Blocks(title="Xibi Binarization") as demo:
    gr.Markdown("# 📄 Document Image Binarization Suite")
    gr.Markdown(
        "Compare classical adaptive thresholding against neural binarizers. "
        "Each method runs on its own — if a model fails to load it reports the error, "
        "it never silently substitutes a different method."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Document")
            algo = gr.Dropdown(choices=METHODS, value=M_SAUVOLA, label="Binarization Method")

            with gr.Group(visible=True) as g_sauvola:
                s_w = gr.Slider(SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, value=SAUVOLA_WINDOW_DEFAULT,
                                step=SAUVOLA_WINDOW_STEP, label="Window size")
                s_k = gr.Slider(SAUVOLA_K_MIN, SAUVOLA_K_MAX, value=SAUVOLA_K_DEFAULT,
                                step=SAUVOLA_K_STEP, label="k")
                s_r = gr.Slider(SAUVOLA_R_MIN, SAUVOLA_R_MAX, value=SAUVOLA_R_DEFAULT,
                                step=SAUVOLA_R_STEP, label="r (dynamic range)")

            with gr.Group(visible=False) as g_niblack:
                n_w = gr.Slider(NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, value=NIBLACK_WINDOW_DEFAULT,
                                step=NIBLACK_WINDOW_STEP, label="Window size")
                n_k = gr.Slider(NIBLACK_K_MIN, NIBLACK_K_MAX, value=NIBLACK_K_DEFAULT,
                                step=NIBLACK_K_STEP, label="k")

            with gr.Group(visible=False) as g_otsu:
                o_blur = gr.Slider(OTSU_BLUR_MIN, OTSU_BLUR_MAX, value=OTSU_BLUR_DEFAULT,
                                   step=OTSU_BLUR_STEP, label="Gaussian pre-blur (0 = off)")

            with gr.Group(visible=False) as g_adaptive:
                a_block = gr.Slider(ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, value=ADAPTIVE_BLOCK_DEFAULT,
                                    step=ADAPTIVE_BLOCK_STEP, label="Block size")
                a_c = gr.Slider(ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, value=ADAPTIVE_C_DEFAULT,
                                step=ADAPTIVE_C_STEP, label="C (mean offset)")

            with gr.Group(visible=False) as g_neural:
                thr = gr.Slider(THRESHOLD_MIN, THRESHOLD_MAX, value=THRESHOLD_DEFAULT,
                                step=THRESHOLD_STEP, label="Binarization threshold",
                                info="Higher = thinner strokes, lower = thicker")

            with gr.Group(visible=False) as g_sbb:
                sbb_inv = gr.Checkbox(value=False, label="Invert SBB output (if foreground/background flipped)")

            submit_btn = gr.Button("Binarize", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Binarized Output")

    algo.change(
        fn=_visibility,
        inputs=algo,
        outputs=[g_sauvola, g_niblack, g_otsu, g_adaptive, g_neural, g_sbb],
    )
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image, algo,
            s_w, s_k, s_r,
            n_w, n_k,
            o_blur,
            a_block, a_c,
            thr,
            sbb_inv,
        ],
        outputs=output_image,
    )


if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False)