# Xbits — commit 13c1c10:
# "Rework binarization methods: real tunable params, no silent fallbacks"
import os
import numpy as np
import gradio as gr
from PIL import Image
# Optional heavy deps are imported lazily inside the loaders so the app boots
# even when one stack is missing. A missing dep surfaces as a clear error on the
# method that needs it -- never a silent fallback to a different method.
try:
import cv2
_CV2_AVAILABLE = True
except ImportError:
cv2 = None
_CV2_AVAILABLE = False
try:
    import spaces

    # ZeroGPU decorator from the Hugging Face `spaces` package.
    GPU = spaces.GPU
except Exception:
    # Outside Spaces (or if `spaces` cannot be imported/used) degrade to an
    # identity decorator so @GPU-decorated functions run unchanged.
    def GPU(fn):
        """Identity decorator: hand *fn* back untouched."""
        return fn
# =============================================================================
# CONSTANTS
# =============================================================================
# --- Method registry ---
# Dropdown labels; process_image dispatches on these exact strings.
M_SAUVOLA = "Sauvola (classical)"
M_NIBLACK = "Niblack (classical)"
M_OTSU = "Otsu (classical)"
M_ADAPTIVE = "Adaptive Gaussian (classical)"
M_B5 = "Tzefa b5 — MAnet/mit_b5 (neural)"
M_SBB = "SBB ResNet50-UNet (neural)"
METHODS = [M_SAUVOLA, M_NIBLACK, M_OTSU, M_ADAPTIVE, M_B5, M_SBB]
# Slider specs below are laid out as (MIN, MAX, STEP, DEFAULT).
# --- Sauvola params ---
SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, SAUVOLA_WINDOW_STEP, SAUVOLA_WINDOW_DEFAULT = 3, 99, 2, 25
SAUVOLA_K_MIN, SAUVOLA_K_MAX, SAUVOLA_K_STEP, SAUVOLA_K_DEFAULT = 0.0, 1.0, 0.01, 0.2
SAUVOLA_R_MIN, SAUVOLA_R_MAX, SAUVOLA_R_STEP, SAUVOLA_R_DEFAULT = 1, 256, 1, 128
# --- Niblack params ---
NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, NIBLACK_WINDOW_STEP, NIBLACK_WINDOW_DEFAULT = 3, 99, 2, 25
# skimage.filters.threshold_niblack thresholds at T = mean - k * std, so the
# textbook Niblack setting k = -0.2 (stated for T = mean + k * std) is
# k = +0.2 in this convention. A negative k here raises T above the local
# mean and turns more background black.
NIBLACK_K_MIN, NIBLACK_K_MAX, NIBLACK_K_STEP, NIBLACK_K_DEFAULT = -1.0, 1.0, 0.01, 0.2
# --- Otsu params ---
OTSU_BLUR_MIN, OTSU_BLUR_MAX, OTSU_BLUR_STEP, OTSU_BLUR_DEFAULT = 0, 31, 1, 5  # Gaussian pre-blur kernel; 0 disables
# --- Adaptive Gaussian params ---
ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, ADAPTIVE_BLOCK_STEP, ADAPTIVE_BLOCK_DEFAULT = 3, 99, 2, 31
ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, ADAPTIVE_C_STEP, ADAPTIVE_C_DEFAULT = -30, 30, 1, 10
# --- Tzefa b5 (neural) ---
B5_REPO = "WARAJA/b5_model"
B5_WEIGHTS_FILE = "b5_model.pth"
B5_ENCODER = "mit_b5"
B5_TILE_SIZE = 640
B5_DECODER_CHANNELS = (256, 128, 64, 32, 16)
# ImageNet normalisation used at training time.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# --- Neural probability threshold (shared by b5 and SBB) ---
THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEP, THRESHOLD_DEFAULT = 0.01, 0.99, 0.01, 0.5
# --- SBB (neural) ---
SBB_REPO = "SBB/sbb_binarization"
SBB_DEFAULT_PATCH = 256  # used only if the model has no fixed input size
SBB_OVERLAP_FRACTION = 0.25  # patch overlap to suppress seam artifacts
SBB_PAD_VALUE = 255  # padding colour: white on the 0-255 scale
# =============================================================================
# MODEL ARCHITECTURE (Tzefa b5) — mirrors the author's reference Space
# =============================================================================
def _build_highres_manet():
    """Build an untrained HighResMAnet matching the Tzefa b5 checkpoint layout.

    The network runs an smp.MAnet (encoder ``B5_ENCODER``, no pretrained
    encoder weights) alongside a shallow stride-1 conv stem. It concatenates
    the MAnet decoder output (16 channels, the last entry of
    ``B5_DECODER_CHANNELS``) with the stem's 32 channels and fuses them into
    ``classes`` logit maps.

    The attribute names (``base_model``, ``high_res_stem``, ``final_fusion``)
    and the layer order inside each ``nn.Sequential`` define the state_dict
    keys. They must stay in sync with the checkpoint that ``_load_b5`` loads.

    Returns:
        A freshly initialised HighResMAnet instance; the caller loads weights.
    """
    # Imported here (not at module top) so the app boots and classical methods
    # work even when torch is absent. `forward` closes over this `torch`.
    import torch
    import torch.nn as nn
    import segmentation_models_pytorch as smp  # provides smp.MAnet

    class HighResMAnet(nn.Module):
        def __init__(self, encoder_name=B5_ENCODER, encoder_weights=None, classes=1):
            super().__init__()
            # MAnet backbone; forward() uses only its encoder and decoder
            # (the MAnet segmentation head is bypassed).
            self.base_model = smp.MAnet(
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=3,
                classes=classes,
                encoder_depth=5,
                decoder_channels=B5_DECODER_CHANNELS,
            )
            # 3x3 convs with padding=1, stride=1: spatial size of the input is kept.
            self.high_res_stem = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            )
            # 16 decoder channels + 32 stem channels -> `classes` logit maps.
            self.final_fusion = nn.Sequential(
                nn.Conv2d(16 + 32, 16, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, classes, kernel_size=1),
            )

        def forward(self, x):
            high_res_features = self.high_res_stem(x)
            features = self.base_model.encoder(x)
            # NOTE(review): the MAnet decoder call signature has differed across
            # segmentation_models_pytorch releases (decoder(features) vs
            # decoder(*features)) -- confirm against the installed version.
            decoder_output = self.base_model.decoder(features)
            # torch.cat needs matching H x W, so this assumes the decoder
            # returns features at the full input resolution (like the stem).
            combined = torch.cat([decoder_output, high_res_features], dim=1)
            return self.final_fusion(combined)

    return HighResMAnet()
# =============================================================================
# MODEL CACHES & LOADERS
# =============================================================================
_b5_model = None
_b5_device = None
_sbb_model = None
def _load_b5():
    """Load (once) and cache the gated Tzefa b5 model.

    Returns:
        ``(model, device)``: the HighResMAnet in eval mode and the device
        string (``"cuda"`` or ``"cpu"``) it was moved to.

    Raises:
        RuntimeError: ``HF_TOKEN`` is unset, or a required package is missing.
            Download, deserialisation and ``load_state_dict`` errors propagate
            unchanged. There is no fallback.
    """
    global _b5_model, _b5_device
    if _b5_model is not None:
        return _b5_model, _b5_device
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            f"{B5_REPO} is gated. Set the HF_TOKEN secret (with access granted) to run this model."
        )
    # Verify every dependency up front -- including segmentation_models_pytorch,
    # which _build_highres_manet imports -- so a missing package is reported
    # before the (large) checkpoint download instead of after it.
    try:
        import torch
        import segmentation_models_pytorch  # noqa: F401
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        raise RuntimeError(
            f"Tzefa b5 needs torch + segmentation_models_pytorch + huggingface_hub: {e}"
        ) from e
    device = "cuda" if torch.cuda.is_available() else "cpu"
    weights = hf_hub_download(repo_id=B5_REPO, filename=B5_WEIGHTS_FILE, token=token, repo_type="model")
    model = _build_highres_manet()
    # NOTE(security): torch.load unpickles the file. On older torch releases
    # this is not restricted to tensors, so only load checkpoints from a
    # trusted repo.
    ckpt = torch.load(weights, map_location=device)
    # Accept both a training checkpoint dict and a bare state_dict.
    state_dict = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model = model.to(device).eval()
    # Cache only after everything succeeded, so a failed load is retried next call.
    _b5_model, _b5_device = model, device
    return _b5_model, _b5_device
def _load_sbb():
    """Load (once) and cache the SBB tf-keras SavedModel.

    Downloads the ``SBB_REPO`` snapshot and picks the first directory (in a
    sorted top-down walk) that contains ``saved_model.pb``. It loads that
    directory with ``compile=False``.

    Returns:
        The loaded Keras model, also cached in ``_sbb_model``.

    Raises:
        RuntimeError: tensorflow / huggingface_hub is missing, or the snapshot
            contains no ``saved_model.pb``. Download and load errors propagate
            unchanged. There is no fallback.
    """
    global _sbb_model
    if _sbb_model is not None:
        return _sbb_model
    try:
        import tensorflow as tf
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise RuntimeError(f"SBB needs tensorflow + huggingface_hub: {e}") from e
    local_dir = snapshot_download(repo_id=SBB_REPO, repo_type="model")
    # Locate the directory that actually contains saved_model.pb.
    saved_dir = None
    for root, dirs, files in os.walk(local_dir):
        # os.walk visits subdirectories in arbitrary listdir order. Sorting in
        # place (honoured in top-down mode) makes the pick reproducible when
        # the snapshot holds more than one SavedModel.
        dirs.sort()
        if "saved_model.pb" in files:
            saved_dir = root
            break
    if saved_dir is None:
        raise RuntimeError(f"No saved_model.pb found in {SBB_REPO}")
    # NOTE(review): Keras 3 (TF >= 2.16) load_model no longer reads SavedModel
    # directories; this relies on tf.keras being Keras 2 -- confirm the pins.
    _sbb_model = tf.keras.models.load_model(saved_dir, compile=False)
    return _sbb_model
# =============================================================================
# CLASSICAL METHODS
# =============================================================================
def _to_gray(image_pil):
return np.array(image_pil.convert("L"), dtype=np.uint8)
def run_sauvola(image_pil, window_size, k, r):
    """Binarize with skimage's Sauvola local threshold.

    ``window_size`` is forced odd. Pixels brighter than their local threshold
    become white (255); the rest become black (0).
    """
    from skimage.filters import threshold_sauvola

    gray = _to_gray(image_pil)
    local_t = threshold_sauvola(
        gray,
        window_size=int(window_size) | 1,  # force odd
        k=float(k),
        r=float(r),
    )
    return Image.fromarray(np.where(gray > local_t, 255, 0).astype(np.uint8))
def run_niblack(image_pil, window_size, k):
    """Binarize with skimage's Niblack local threshold.

    skimage's threshold is ``mean - k * std`` over a ``window_size`` square
    neighbourhood (window forced odd). Pixels brighter than the threshold
    become white (255); the rest become black (0).
    """
    from skimage.filters import threshold_niblack

    gray = _to_gray(image_pil)
    odd_window = int(window_size) | 1
    local_t = threshold_niblack(gray, window_size=odd_window, k=float(k))
    return Image.fromarray(np.where(gray > local_t, 255, 0).astype(np.uint8))
def run_otsu(image_pil, blur_ksize):
    """Binarize with a global Otsu threshold, optionally after a Gaussian blur.

    Args:
        image_pil: PIL image; converted to grayscale.
        blur_ksize: Gaussian kernel size. Values <= 0 disable the blur; other
            values are forced odd before blurring.

    Raises:
        RuntimeError: opencv is unavailable. Both the optional blur and the
            Otsu threshold itself are computed with cv2.
    """
    # cv2 is needed unconditionally (cv2.threshold below), not just for the blur.
    if not _CV2_AVAILABLE:
        raise RuntimeError("Otsu needs opencv (cv2). Install opencv-python-headless.")
    gray = _to_gray(image_pil)
    blur_ksize = int(blur_ksize)
    if blur_ksize > 0:
        blur_ksize |= 1  # odd kernel
        gray = cv2.GaussianBlur(gray, (blur_ksize, blur_ksize), 0)
    # The threshold argument 0 is ignored: THRESH_OTSU computes it from the histogram.
    _t, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(binary)
def run_adaptive_gaussian(image_pil, block_size, c):
    """Binarize with OpenCV's Gaussian-weighted adaptive threshold.

    ``block_size`` is the (forced odd) neighbourhood size. ``c`` is the
    constant OpenCV subtracts from the weighted local mean.

    Raises:
        RuntimeError: opencv is unavailable.
    """
    if not _CV2_AVAILABLE:
        raise RuntimeError("Adaptive Gaussian needs opencv (cv2). Install opencv-python-headless.")
    binary = cv2.adaptiveThreshold(
        _to_gray(image_pil),
        255,  # value written for pixels above their local threshold
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        int(block_size) | 1,  # force odd
        int(c),
    )
    return Image.fromarray(binary)
# =============================================================================
# NEURAL METHODS
# =============================================================================
def _b5_preprocess(tile_pil):
    """Turn an RGB tile into an ImageNet-normalised CHW float32 torch tensor.

    Assumes ``tile_pil`` yields an (H, W, 3) array (run_b5 passes RGB crops).
    """
    import torch

    scaled = np.array(tile_pil).astype(np.float32) / 255.0
    normalised = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    channels_first = normalised.transpose(2, 0, 1)  # HWC -> CHW
    return torch.from_numpy(channels_first)
@GPU
def run_b5(image_pil, threshold):
    """Binarize with the Tzefa b5 network over non-overlapping tiles.

    The image is white-padded on the right/bottom up to a multiple of
    ``B5_TILE_SIZE``. Each tile's sigmoid output is stitched into a
    foreground-probability map, and the padding is cropped off. Pixels whose
    probability is not below ``threshold`` become black (0); the rest become
    white (255).
    """
    import torch
    import torch.nn.functional as F

    model, device = _load_b5()
    rgb = image_pil.convert("RGB")
    width, height = rgb.size
    tile = B5_TILE_SIZE

    # Round each dimension up to the next multiple of the tile size.
    padded_w = width + (tile - width % tile) % tile
    padded_h = height + (tile - height % tile) % tile
    canvas = Image.new("RGB", (padded_w, padded_h), (255, 255, 255))
    canvas.paste(rgb, (0, 0))

    probs = np.zeros((padded_h, padded_w), dtype=np.float32)
    for top in range(0, padded_h, tile):
        for left in range(0, padded_w, tile):
            crop = canvas.crop((left, top, left + tile, top + tile))
            batch = _b5_preprocess(crop).unsqueeze(0).to(device).float()
            with torch.no_grad():
                logits = model(batch)
            if logits.shape[-2:] != (tile, tile):
                logits = F.interpolate(logits, size=(tile, tile), mode="bilinear")
            probs[top:top + tile, left:left + tile] = torch.sigmoid(logits).cpu().numpy()[0, 0]

    probs = probs[:height, :width]
    # High foreground probability -> black ink (0); low -> white paper (255).
    return Image.fromarray(np.where(probs < float(threshold), 255, 0).astype(np.uint8))
def run_sbb(image_pil, threshold, invert):
    """Binarize with the SBB tf-keras model over overlapping patches.

    The RGB image is scaled to [0, 1] and white-padded so that patches, with
    stride ``patch * (1 - SBB_OVERLAP_FRACTION)``, tile it exactly. Per-pixel
    foreground probabilities are averaged where patches overlap. Pixels whose
    (optionally inverted) probability is not below ``threshold`` become
    black (0); the rest become white (255).

    Args:
        image_pil: PIL image to binarize.
        threshold: probability cut-off.
        invert: flip the probability map first (for a flipped fg/bg class).
    """
    model = _load_sbb()
    rgb = np.array(image_pil.convert("RGB"), dtype=np.float32) / 255.0
    # NOTE(review): the upstream sbb_binarize tool reads images with
    # cv2.imread (BGR order) -- confirm which channel order this model expects.
    H, W = rgb.shape[:2]

    # Native patch size from the model's input signature; fall back to the
    # default when the spatial dims are dynamic or the shape cannot be read.
    try:
        in_shape = model.inputs[0].shape
        ph = int(in_shape[1]) if in_shape[1] is not None else SBB_DEFAULT_PATCH
        pw = int(in_shape[2]) if in_shape[2] is not None else SBB_DEFAULT_PATCH
    except Exception:
        ph = pw = SBB_DEFAULT_PATCH
    step_h = max(1, int(ph * (1 - SBB_OVERLAP_FRACTION)))
    step_w = max(1, int(pw * (1 - SBB_OVERLAP_FRACTION)))

    # Pad so the last patch ends exactly on the padded border; images smaller
    # than one patch are padded up to the patch size. Both pads are >= 0.
    pad_h = (step_h - (H - ph) % step_h) % step_h if H > ph else ph - H
    pad_w = (step_w - (W - pw) % step_w) % step_w if W > pw else pw - W
    padded = np.full((H + pad_h, W + pad_w, 3), SBB_PAD_VALUE / 255.0, dtype=np.float32)
    padded[:H, :W] = rgb
    PH, PW = padded.shape[:2]

    prob_sum = np.zeros((PH, PW), dtype=np.float32)
    count = np.zeros((PH, PW), dtype=np.float32)
    for y in range(0, PH - ph + 1, step_h):
        for x in range(0, PW - pw + 1, step_w):
            patch = padded[y:y + ph, x:x + pw][None, ...]
            # Call the model directly: Model.predict() is built for large
            # batched datasets and is documented as not intended for per-sample
            # Python loops (heavy per-call overhead).
            pred = np.asarray(model(patch, training=False))[0]  # (h, w, 2) expected
            fg = pred[..., 1] if pred.ndim == 3 and pred.shape[-1] >= 2 else np.squeeze(pred)
            prob_sum[y:y + ph, x:x + pw] += fg
            count[y:y + ph, x:x + pw] += 1.0
    count[count == 0] = 1.0  # defensive; padding guarantees full coverage
    prob_map = (prob_sum / count)[:H, :W]
    if invert:
        prob_map = 1.0 - prob_map
    binary = ((prob_map < float(threshold)) * 255).astype(np.uint8)
    return Image.fromarray(binary)
# =============================================================================
# DISPATCH
# =============================================================================
def process_image(
    input_img, algo,
    sauvola_w, sauvola_k, sauvola_r,
    niblack_w, niblack_k,
    otsu_blur,
    adaptive_block, adaptive_c,
    threshold, sbb_invert,
):
    """Run the selected binarization method on *input_img*.

    Parameters arrive positionally from the Gradio ``submit_btn.click``
    inputs. Only the arguments belonging to *algo* are used.

    Raises:
        gr.Error: no image, unknown method, or the method failed. The original
            exception is chained as ``__cause__`` so the server log keeps the
            full traceback. Never falls back to another method.
    """
    if input_img is None:
        raise gr.Error("Upload an image first.")
    runners = {
        M_SAUVOLA: lambda: run_sauvola(input_img, sauvola_w, sauvola_k, sauvola_r),
        M_NIBLACK: lambda: run_niblack(input_img, niblack_w, niblack_k),
        M_OTSU: lambda: run_otsu(input_img, otsu_blur),
        M_ADAPTIVE: lambda: run_adaptive_gaussian(input_img, adaptive_block, adaptive_c),
        M_B5: lambda: run_b5(input_img, threshold),
        M_SBB: lambda: run_sbb(input_img, threshold, sbb_invert),
    }
    runner = runners.get(algo)
    if runner is None:
        raise gr.Error(f"Unknown method: {algo}")
    try:
        return runner()
    except gr.Error:
        raise
    except Exception as e:
        # Surface the real failure (type included: e.g. a bare KeyError's str
        # is just the key); do NOT fall back to another method.
        raise gr.Error(f"{algo} failed: {type(e).__name__}: {e}") from e
# =============================================================================
# UI
# =============================================================================
def _visibility(algo):
    """Return one ``gr.update`` per parameter group, showing only *algo*'s groups.

    Order: Sauvola, Niblack, Otsu, Adaptive, neural threshold (b5 and SBB),
    SBB invert. It must match the outputs list wired to ``algo.change``.
    """
    is_neural = algo in (M_B5, M_SBB)
    shown = (
        algo == M_SAUVOLA,
        algo == M_NIBLACK,
        algo == M_OTSU,
        algo == M_ADAPTIVE,
        is_neural,
        algo == M_SBB,
    )
    return [gr.update(visible=flag) for flag in shown]
# Build the Gradio UI. Component creation order inside the `with` blocks
# determines the on-screen layout, and event wiring must stay inside the
# Blocks context.
with gr.Blocks(title="Xibi Binarization") as demo:
    gr.Markdown("# 📄 Document Image Binarization Suite")
    gr.Markdown(
        "Compare classical adaptive thresholding against neural binarizers. "
        "Each method runs on its own — if a model fails to load it reports the error, "
        "it never silently substitutes a different method."
    )
    with gr.Row():
        # Left column: input image, method picker and per-method parameter groups.
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Document")
            algo = gr.Dropdown(choices=METHODS, value=M_SAUVOLA, label="Binarization Method")
            # Only the selected method's group is shown (see _visibility). Sauvola
            # is the initial dropdown value, so its group starts visible.
            with gr.Group(visible=True) as g_sauvola:
                s_w = gr.Slider(SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, value=SAUVOLA_WINDOW_DEFAULT,
                                step=SAUVOLA_WINDOW_STEP, label="Window size")
                s_k = gr.Slider(SAUVOLA_K_MIN, SAUVOLA_K_MAX, value=SAUVOLA_K_DEFAULT,
                                step=SAUVOLA_K_STEP, label="k")
                s_r = gr.Slider(SAUVOLA_R_MIN, SAUVOLA_R_MAX, value=SAUVOLA_R_DEFAULT,
                                step=SAUVOLA_R_STEP, label="r (dynamic range)")
            with gr.Group(visible=False) as g_niblack:
                n_w = gr.Slider(NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, value=NIBLACK_WINDOW_DEFAULT,
                                step=NIBLACK_WINDOW_STEP, label="Window size")
                n_k = gr.Slider(NIBLACK_K_MIN, NIBLACK_K_MAX, value=NIBLACK_K_DEFAULT,
                                step=NIBLACK_K_STEP, label="k")
            with gr.Group(visible=False) as g_otsu:
                o_blur = gr.Slider(OTSU_BLUR_MIN, OTSU_BLUR_MAX, value=OTSU_BLUR_DEFAULT,
                                   step=OTSU_BLUR_STEP, label="Gaussian pre-blur (0 = off)")
            with gr.Group(visible=False) as g_adaptive:
                a_block = gr.Slider(ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, value=ADAPTIVE_BLOCK_DEFAULT,
                                    step=ADAPTIVE_BLOCK_STEP, label="Block size")
                a_c = gr.Slider(ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, value=ADAPTIVE_C_DEFAULT,
                                step=ADAPTIVE_C_STEP, label="C (mean offset)")
            # Probability threshold shared by both neural methods.
            with gr.Group(visible=False) as g_neural:
                thr = gr.Slider(THRESHOLD_MIN, THRESHOLD_MAX, value=THRESHOLD_DEFAULT, step=THRESHOLD_STEP,
                                label="Binarization threshold",
                                info="Higher = thinner strokes, lower = thicker")
            with gr.Group(visible=False) as g_sbb:
                sbb_inv = gr.Checkbox(value=False, label="Invert SBB output (if foreground/background flipped)")
            submit_btn = gr.Button("Binarize", variant="primary")
        # Right column: the binarized result.
        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Binarized Output")
    # Toggle parameter groups on method change; this outputs order must match
    # the list returned by _visibility.
    algo.change(
        fn=_visibility,
        inputs=algo,
        outputs=[g_sauvola, g_niblack, g_otsu, g_adaptive, g_neural, g_sbb],
    )
    # Input order must match process_image's positional parameters.
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image, algo,
            s_w, s_k, s_r,
            n_w, n_k,
            o_blur,
            a_block, a_c,
            thr, sbb_inv,
        ],
        outputs=output_image,
    )
if __name__ == "__main__":
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False)