File size: 17,115 Bytes
13c1c10
 
c505a87
13c1c10
 
b53a535
13c1c10
 
 
b53a535
 
13c1c10
b53a535
 
13c1c10
b53a535
 
 
 
 
 
 
 
 
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b53a535
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b53a535
 
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b53a535
13c1c10
 
 
b53a535
13c1c10
 
b53a535
13c1c10
 
 
 
 
 
 
c505a87
553e5b1
b53a535
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b53a535
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c505a87
 
13c1c10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c505a87
13c1c10
 
 
 
 
 
 
 
c505a87
13c1c10
 
 
 
 
 
 
 
 
 
c505a87
 
13c1c10
b53a535
13c1c10
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
import os

import numpy as np
import gradio as gr
from PIL import Image

# Optional heavy deps are imported lazily inside the loaders so the app boots
# even when one stack is missing. A missing dep surfaces as a clear error on the
# method that needs it -- never a silent fallback to a different method.
try:
    import cv2
    _CV2_AVAILABLE = True
except ImportError:
    cv2 = None
    _CV2_AVAILABLE = False

try:
    import spaces
    GPU = spaces.GPU
except Exception:
    def GPU(fn):
        return fn


# =============================================================================
# CONSTANTS
# =============================================================================

# --- Method registry ---
# The display names double as dispatch keys in process_image and _visibility.
M_SAUVOLA = "Sauvola (classical)"
M_NIBLACK = "Niblack (classical)"
M_OTSU = "Otsu (classical)"
M_ADAPTIVE = "Adaptive Gaussian (classical)"
M_B5 = "Tzefa b5 — MAnet/mit_b5 (neural)"
M_SBB = "SBB ResNet50-UNet (neural)"

METHODS = [M_SAUVOLA, M_NIBLACK, M_OTSU, M_ADAPTIVE, M_B5, M_SBB]

# Slider tuples below are (min, max, step, default). Window/block sizes step
# by 2 from an odd minimum so the slider stays odd; the runners additionally
# force odd values with `| 1`.

# --- Sauvola params ---
SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, SAUVOLA_WINDOW_STEP, SAUVOLA_WINDOW_DEFAULT = 3, 99, 2, 25
SAUVOLA_K_MIN, SAUVOLA_K_MAX, SAUVOLA_K_STEP, SAUVOLA_K_DEFAULT = 0.0, 1.0, 0.01, 0.2
SAUVOLA_R_MIN, SAUVOLA_R_MAX, SAUVOLA_R_STEP, SAUVOLA_R_DEFAULT = 1, 256, 1, 128

# --- Niblack params ---
# scikit-image's threshold_niblack computes T = mean - k * std. Classic
# Niblack writes T = mean + k * std with k = -0.2, which is k = +0.2 here.
NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, NIBLACK_WINDOW_STEP, NIBLACK_WINDOW_DEFAULT = 3, 99, 2, 25
NIBLACK_K_MIN, NIBLACK_K_MAX, NIBLACK_K_STEP, NIBLACK_K_DEFAULT = -1.0, 1.0, 0.01, 0.2

# --- Otsu params ---
OTSU_BLUR_MIN, OTSU_BLUR_MAX, OTSU_BLUR_STEP, OTSU_BLUR_DEFAULT = 0, 31, 1, 5  # Gaussian pre-blur kernel; 0 disables

# --- Adaptive Gaussian params ---
ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, ADAPTIVE_BLOCK_STEP, ADAPTIVE_BLOCK_DEFAULT = 3, 99, 2, 31
ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, ADAPTIVE_C_STEP, ADAPTIVE_C_DEFAULT = -30, 30, 1, 10

# --- Tzefa b5 (neural) ---
B5_REPO = "WARAJA/b5_model"
B5_WEIGHTS_FILE = "b5_model.pth"
B5_ENCODER = "mit_b5"
B5_TILE_SIZE = 640
B5_DECODER_CHANNELS = (256, 128, 64, 32, 16)
# ImageNet normalisation used at training time.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Shared neural threshold slider: foreground probability at/above it -> black.
THRESHOLD_MIN, THRESHOLD_MAX, THRESHOLD_STEP, THRESHOLD_DEFAULT = 0.01, 0.99, 0.01, 0.5

# --- SBB (neural) ---
SBB_REPO = "SBB/sbb_binarization"
SBB_DEFAULT_PATCH = 256          # used only if the model has no fixed input size
SBB_OVERLAP_FRACTION = 0.25      # patch overlap to suppress seam artifacts
SBB_PAD_VALUE = 255              # white (0-255 scale) used to pad SBB input tiles


# =============================================================================
# MODEL ARCHITECTURE (Tzefa b5) — mirrors the author's reference Space
# =============================================================================

def _build_highres_manet():
    """Construct an HighResMAnet: smp MAnet plus a full-resolution conv stem.

    The class is defined inside this function so torch and
    segmentation_models_pytorch are only imported when the b5 model is
    actually requested. No pretrained weights are fetched here
    (``encoder_weights=None``); the caller loads a checkpoint afterwards.

    Returns:
        A freshly constructed HighResMAnet instance.
    """
    # Imported here (not at module top) so the app boots and classical methods
    # work even when torch is absent. `forward` closes over this `torch`.
    import torch
    import torch.nn as nn
    import segmentation_models_pytorch as smp  # noqa: F401  (pulls timm encoders)

    class HighResMAnet(nn.Module):
        """MAnet whose decoder output is fused with shallow full-resolution features.

        NOTE: the attribute names (base_model, high_res_stem, final_fusion) and
        the layer order inside each nn.Sequential determine the state_dict
        keys, so they presumably must stay exactly as-is for the published
        checkpoint to load.
        """

        def __init__(self, encoder_name=B5_ENCODER, encoder_weights=None, classes=1):
            super().__init__()
            self.base_model = smp.MAnet(
                encoder_name=encoder_name,
                encoder_weights=encoder_weights,
                in_channels=3,
                classes=classes,
                encoder_depth=5,
                decoder_channels=B5_DECODER_CHANNELS,
            )
            # Two 3x3 stride-1, padding-1 convs (3 -> 16 -> 32 channels): keeps
            # the input's spatial size so fine stroke detail survives.
            self.high_res_stem = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(16),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(32),
                nn.ReLU(inplace=True),
            )
            # Input channels 16 + 32: presumably the decoder's last channel
            # count (B5_DECODER_CHANNELS[-1] == 16) plus the stem's 32.
            self.final_fusion = nn.Sequential(
                nn.Conv2d(16 + 32, 16, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, classes, kernel_size=1),
            )

        def forward(self, x):
            """Return raw per-pixel logits (no activation is applied here).

            Bypasses MAnet's own segmentation head: the encoder/decoder output
            is concatenated channel-wise with the high-res stem features, which
            assumes the decoder output has the same H x W as ``x`` -- TODO confirm.
            """
            high_res_features = self.high_res_stem(x)
            features = self.base_model.encoder(x)
            # NOTE(review): the feature list is passed as ONE argument; confirm
            # this matches the decoder signature of the pinned
            # segmentation_models_pytorch version (older releases took *features).
            decoder_output = self.base_model.decoder(features)
            combined = torch.cat([decoder_output, high_res_features], dim=1)
            return self.final_fusion(combined)

    return HighResMAnet()


# =============================================================================
# MODEL CACHES & LOADERS
# =============================================================================

# Lazily populated model caches; each loader fills its slot on first success.
_b5_model = _b5_device = None  # Tzefa b5 module and the device string it lives on
_sbb_model = None              # SBB Keras model


def _load_b5():
    """Load the gated Tzefa b5 checkpoint, caching it for subsequent calls.

    Returns:
        ``(model, device)``: the HighResMAnet in eval mode, and the device
        string (``"cuda"`` when available, else ``"cpu"``) it was moved to.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is unset, or torch / huggingface_hub
            cannot be imported (the ImportError is chained as ``__cause__``).
            Download and checkpoint-loading errors propagate unchanged.
            There is deliberately no fallback to another model.
    """
    global _b5_model, _b5_device
    if _b5_model is not None:
        return _b5_model, _b5_device

    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            f"{B5_REPO} is gated. Set the HF_TOKEN secret (with access granted) to run this model."
        )

    try:
        import torch
        from huggingface_hub import hf_hub_download
    except ImportError as e:
        # Chain the cause so the missing package shows up in the traceback.
        raise RuntimeError(f"Tzefa b5 needs torch + huggingface_hub: {e}") from e

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weights = hf_hub_download(repo_id=B5_REPO, filename=B5_WEIGHTS_FILE, token=token, repo_type="model")
    model = _build_highres_manet()
    ckpt = torch.load(weights, map_location=device)
    # Accept either a training checkpoint dict or a bare state_dict.
    state_dict = ckpt["model_state_dict"] if isinstance(ckpt, dict) and "model_state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model = model.to(device).eval()

    _b5_model, _b5_device = model, device
    return _b5_model, _b5_device


def _load_sbb():
    """Load the SBB tf-keras SavedModel, caching it for subsequent calls.

    Downloads a snapshot of ``SBB_REPO`` and loads the first directory, in
    lexically sorted top-down walk order, that contains ``saved_model.pb``.

    Returns:
        The loaded Keras model (``compile=False``).

    Raises:
        RuntimeError: if tensorflow / huggingface_hub cannot be imported (the
            ImportError is chained as ``__cause__``) or the snapshot holds no
            ``saved_model.pb``. Download and load errors propagate unchanged.
            There is deliberately no fallback.
    """
    global _sbb_model
    if _sbb_model is not None:
        return _sbb_model

    try:
        import tensorflow as tf
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise RuntimeError(f"SBB needs tensorflow + huggingface_hub: {e}") from e

    local_dir = snapshot_download(repo_id=SBB_REPO, repo_type="model")
    # Locate the directory that actually contains saved_model.pb. os.walk's
    # sibling order is filesystem-dependent; sorting `dirs` in place makes the
    # traversal lexical, so a snapshot holding several SavedModels always
    # resolves to the same one.
    saved_dir = None
    for root, dirs, files in os.walk(local_dir):
        dirs.sort()
        if "saved_model.pb" in files:
            saved_dir = root
            break
    if saved_dir is None:
        raise RuntimeError(f"No saved_model.pb found in {SBB_REPO}")

    # NOTE(review): Keras 3 (TF >= 2.16) load_model() no longer reads
    # SavedModel directories; this presumably relies on tf-keras / Keras 2.
    _sbb_model = tf.keras.models.load_model(saved_dir, compile=False)
    return _sbb_model


# =============================================================================
# CLASSICAL METHODS
# =============================================================================

def _to_gray(image_pil):
    return np.array(image_pil.convert("L"), dtype=np.uint8)


def run_sauvola(image_pil, window_size, k, r):
    """Binarize with Sauvola's local threshold (scikit-image implementation).

    Args:
        image_pil: PIL image; converted to 8-bit grayscale first.
        window_size: local window side; even values are bumped to the next odd.
        k: Sauvola sensitivity parameter.
        r: dynamic range of the standard deviation.

    Returns:
        PIL image with 255 where a pixel is brighter than its local threshold
        and 0 elsewhere.
    """
    from skimage.filters import threshold_sauvola

    luminance = _to_gray(image_pil)
    odd_window = int(window_size) | 1  # skimage expects an odd window size
    local_threshold = threshold_sauvola(luminance, window_size=odd_window, k=float(k), r=float(r))
    is_background = luminance > local_threshold
    return Image.fromarray(np.where(is_background, 255, 0).astype(np.uint8))


def run_niblack(image_pil, window_size, k):
    """Binarize with Niblack's local threshold (scikit-image implementation).

    Args:
        image_pil: PIL image; converted to 8-bit grayscale first.
        window_size: local window side; even values are bumped to the next odd.
        k: weight of the local standard deviation, in scikit-image's
            convention ``T = mean - k * std``.

    Returns:
        PIL image with 255 where a pixel is brighter than its local threshold
        and 0 elsewhere.
    """
    from skimage.filters import threshold_niblack

    luminance = _to_gray(image_pil)
    odd_window = int(window_size) | 1
    local_threshold = threshold_niblack(luminance, window_size=odd_window, k=float(k))
    is_background = luminance > local_threshold
    return Image.fromarray(np.where(is_background, 255, 0).astype(np.uint8))


def run_otsu(image_pil, blur_ksize):
    """Binarize with Otsu's global threshold, optionally after a Gaussian blur.

    Args:
        image_pil: PIL image; converted to 8-bit grayscale first.
        blur_ksize: Gaussian pre-blur kernel size. Values <= 0 skip the blur;
            even values are bumped to the next odd size, as OpenCV requires.

    Returns:
        PIL image with 255 for pixels above the Otsu threshold and 0 otherwise.

    Raises:
        RuntimeError: if OpenCV is unavailable. Both the optional blur and the
            Otsu threshold itself use cv2, so the check happens up front
            (matching run_adaptive_gaussian).
    """
    if not _CV2_AVAILABLE:
        raise RuntimeError("Otsu needs opencv (cv2). Install opencv-python-headless.")
    gray = _to_gray(image_pil)
    blur_ksize = int(blur_ksize)
    if blur_ksize > 0:
        blur_ksize |= 1  # odd kernel
        gray = cv2.GaussianBlur(gray, (blur_ksize, blur_ksize), 0)
    # With THRESH_OTSU the passed threshold (0) is ignored and computed instead.
    _t, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return Image.fromarray(binary)


def run_adaptive_gaussian(image_pil, block_size, c):
    """Binarize with OpenCV's Gaussian-weighted adaptive threshold.

    Args:
        image_pil: PIL image; converted to 8-bit grayscale first.
        block_size: neighbourhood side; even values are bumped to the next odd.
        c: constant subtracted from the Gaussian-weighted local mean.

    Returns:
        PIL image with 255 where a pixel exceeds (local mean - c), else 0.

    Raises:
        RuntimeError: if OpenCV is not installed.
    """
    if not _CV2_AVAILABLE:
        raise RuntimeError("Adaptive Gaussian needs opencv (cv2). Install opencv-python-headless.")

    luminance = _to_gray(image_pil)
    odd_block = int(block_size) | 1  # OpenCV requires an odd neighbourhood
    offset = int(c)
    thresholded = cv2.adaptiveThreshold(
        luminance,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        odd_block,
        offset,
    )
    return Image.fromarray(thresholded)


# =============================================================================
# NEURAL METHODS
# =============================================================================

def _b5_preprocess(tile_pil):
    import torch
    arr = np.array(tile_pil).astype(np.float32) / 255.0
    arr = (arr - IMAGENET_MEAN) / IMAGENET_STD
    return torch.from_numpy(arr.transpose(2, 0, 1))


@GPU
def run_b5(image_pil, threshold):
    """Binarize with the Tzefa b5 network using non-overlapping tiles.

    The RGB image is white-padded up to a multiple of B5_TILE_SIZE, each tile
    is run through the model, and the sigmoid of the logits is stitched into
    a foreground-probability map that is cropped back to the input size.

    Args:
        image_pil: PIL image to binarize (converted to RGB).
        threshold: pixels with foreground probability >= this become black
            (0); everything else becomes white (255).

    Returns:
        Single-channel uint8 PIL image.
    """
    import torch
    import torch.nn.functional as F

    model, device = _load_b5()
    rgb = image_pil.convert("RGB")
    width, height = rgb.size
    tile = B5_TILE_SIZE

    # Ceil-divide so the canvas holds a whole number of tiles in each direction.
    canvas_w = -(-width // tile) * tile
    canvas_h = -(-height // tile) * tile
    canvas = Image.new("RGB", (canvas_w, canvas_h), (255, 255, 255))
    canvas.paste(rgb, (0, 0))

    probs = np.zeros((canvas_h, canvas_w), dtype=np.float32)
    with torch.no_grad():
        for top in range(0, canvas_h, tile):
            for left in range(0, canvas_w, tile):
                crop = canvas.crop((left, top, left + tile, top + tile))
                batch = _b5_preprocess(crop).unsqueeze(0).to(device).float()
                logits = model(batch)
                # Resize back to the tile if the network's output resolution differs.
                if logits.shape[-2:] != (tile, tile):
                    logits = F.interpolate(logits, size=(tile, tile), mode="bilinear")
                probs[top:top + tile, left:left + tile] = torch.sigmoid(logits).cpu().numpy()[0, 0]

    probs = probs[:height, :width]
    # High foreground probability -> black (0) text pixel.
    binary = np.where(probs < float(threshold), 255, 0).astype(np.uint8)
    return Image.fromarray(binary)


def run_sbb(image_pil, threshold, invert):
    """Binarize with the SBB ResNet50-UNet using overlapping patch inference.

    The image is scaled to [0, 1] RGB and white-padded so that overlapping
    patches tile it exactly. Every patch runs through the model, and the
    foreground probabilities of overlapping patches are averaged before
    thresholding.

    Args:
        image_pil: PIL image to binarize (converted to RGB).
        threshold: pixels with foreground probability >= this become black
            (0); the rest become white (255).
        invert: if truthy, flip the probability map (p -> 1 - p) first, for
            a model whose class order turns out to be reversed.

    Returns:
        Single-channel uint8 PIL image.
    """
    model = _load_sbb()
    rgb = np.array(image_pil.convert("RGB"), dtype=np.float32) / 255.0
    H, W = rgb.shape[:2]

    # Determine the model's native patch size; fall back to a sane default.
    try:
        in_shape = model.inputs[0].shape
        ph = int(in_shape[1]) if in_shape[1] is not None else SBB_DEFAULT_PATCH
        pw = int(in_shape[2]) if in_shape[2] is not None else SBB_DEFAULT_PATCH
    except Exception:
        ph = pw = SBB_DEFAULT_PATCH

    step_h = max(1, int(ph * (1 - SBB_OVERLAP_FRACTION)))
    step_w = max(1, int(pw * (1 - SBB_OVERLAP_FRACTION)))

    # Pad so the last patch ends exactly on the padded edge (or, for images
    # smaller than one patch, up to a single patch). Both branches are >= 0.
    pad_h = (step_h - (H - ph) % step_h) % step_h if H > ph else ph - H
    pad_w = (step_w - (W - pw) % step_w) % step_w if W > pw else pw - W
    padded = np.full((H + pad_h, W + pad_w, 3), SBB_PAD_VALUE / 255.0, dtype=np.float32)
    padded[:H, :W] = rgb
    PH, PW = padded.shape[:2]

    prob_sum = np.zeros((PH, PW), dtype=np.float32)
    count = np.zeros((PH, PW), dtype=np.float32)

    for y in range(0, PH - ph + 1, step_h):
        for x in range(0, PW - pw + 1, step_w):
            patch = padded[None, y:y + ph, x:x + pw]
            # Direct call instead of model.predict(): Keras documents predict()
            # as meant for large batched inputs, with heavy per-call overhead
            # when used inside a loop over single samples.
            pred = np.asarray(model(patch, training=False))[0]  # (h, w, 2)
            fg = pred[..., 1] if pred.ndim == 3 and pred.shape[-1] >= 2 else np.squeeze(pred)
            prob_sum[y:y + ph, x:x + pw] += fg
            count[y:y + ph, x:x + pw] += 1.0

    count[count == 0] = 1.0  # defensive: every pixel should be covered at least once
    prob_map = (prob_sum / count)[:H, :W]

    if invert:
        prob_map = 1.0 - prob_map
    binary = ((prob_map < float(threshold)) * 255).astype(np.uint8)
    return Image.fromarray(binary)


# =============================================================================
# DISPATCH
# =============================================================================

def process_image(
    input_img, algo,
    sauvola_w, sauvola_k, sauvola_r,
    niblack_w, niblack_k,
    otsu_blur,
    adaptive_block, adaptive_c,
    threshold, sbb_invert,
):
    """Run the selected binarization method on *input_img*.

    Every widget value arrives on each click; only the parameters that belong
    to *algo* are forwarded to its runner.

    Raises:
        gr.Error: if no image is uploaded, *algo* is unknown, or the chosen
            method fails. For failures the original exception is chained as
            ``__cause__`` so its traceback is preserved. There is no fallback
            to another method.
    """
    if input_img is None:
        raise gr.Error("Upload an image first.")

    # Lambdas defer the call so only the selected method ever runs.
    runners = {
        M_SAUVOLA: lambda: run_sauvola(input_img, sauvola_w, sauvola_k, sauvola_r),
        M_NIBLACK: lambda: run_niblack(input_img, niblack_w, niblack_k),
        M_OTSU: lambda: run_otsu(input_img, otsu_blur),
        M_ADAPTIVE: lambda: run_adaptive_gaussian(input_img, adaptive_block, adaptive_c),
        M_B5: lambda: run_b5(input_img, threshold),
        M_SBB: lambda: run_sbb(input_img, threshold, sbb_invert),
    }
    runner = runners.get(algo)
    if runner is None:
        raise gr.Error(f"Unknown method: {algo}")

    try:
        return runner()
    except gr.Error:
        raise
    except Exception as e:
        # Surface the real failure; do NOT fall back to another method.
        raise gr.Error(f"{algo} failed: {e}") from e


# =============================================================================
# UI
# =============================================================================

def _visibility(algo):
    """Return Gradio visibility updates for each parameter group, given *algo*.

    Order: Sauvola, Niblack, Otsu, Adaptive, shared neural threshold, SBB-only.
    """
    shown = (
        algo == M_SAUVOLA,
        algo == M_NIBLACK,
        algo == M_OTSU,
        algo == M_ADAPTIVE,
        algo in (M_B5, M_SBB),  # the threshold slider serves both neural methods
        algo == M_SBB,
    )
    return [gr.update(visible=flag) for flag in shown]


# Gradio UI: image input, method dropdown, one parameter group per method
# (only the active method's group is visible), and the binarized output.
with gr.Blocks(title="Xibi Binarization") as demo:
    gr.Markdown("# 📄 Document Image Binarization Suite")
    gr.Markdown(
        "Compare classical adaptive thresholding against neural binarizers. "
        "Each method runs on its own — if a model fails to load it reports the error, "
        "it never silently substitutes a different method."
    )

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil", label="Upload Document")
            algo = gr.Dropdown(choices=METHODS, value=M_SAUVOLA, label="Binarization Method")

            # Initial visibility mirrors _visibility(M_SAUVOLA): only the
            # Sauvola group starts visible because it is the dropdown default.
            with gr.Group(visible=True) as g_sauvola:
                s_w = gr.Slider(SAUVOLA_WINDOW_MIN, SAUVOLA_WINDOW_MAX, value=SAUVOLA_WINDOW_DEFAULT,
                                step=SAUVOLA_WINDOW_STEP, label="Window size")
                s_k = gr.Slider(SAUVOLA_K_MIN, SAUVOLA_K_MAX, value=SAUVOLA_K_DEFAULT,
                                step=SAUVOLA_K_STEP, label="k")
                s_r = gr.Slider(SAUVOLA_R_MIN, SAUVOLA_R_MAX, value=SAUVOLA_R_DEFAULT,
                                step=SAUVOLA_R_STEP, label="r (dynamic range)")

            with gr.Group(visible=False) as g_niblack:
                n_w = gr.Slider(NIBLACK_WINDOW_MIN, NIBLACK_WINDOW_MAX, value=NIBLACK_WINDOW_DEFAULT,
                                step=NIBLACK_WINDOW_STEP, label="Window size")
                n_k = gr.Slider(NIBLACK_K_MIN, NIBLACK_K_MAX, value=NIBLACK_K_DEFAULT,
                                step=NIBLACK_K_STEP, label="k")

            with gr.Group(visible=False) as g_otsu:
                o_blur = gr.Slider(OTSU_BLUR_MIN, OTSU_BLUR_MAX, value=OTSU_BLUR_DEFAULT,
                                   step=OTSU_BLUR_STEP, label="Gaussian pre-blur (0 = off)")

            with gr.Group(visible=False) as g_adaptive:
                a_block = gr.Slider(ADAPTIVE_BLOCK_MIN, ADAPTIVE_BLOCK_MAX, value=ADAPTIVE_BLOCK_DEFAULT,
                                    step=ADAPTIVE_BLOCK_STEP, label="Block size")
                a_c = gr.Slider(ADAPTIVE_C_MIN, ADAPTIVE_C_MAX, value=ADAPTIVE_C_DEFAULT,
                                step=ADAPTIVE_C_STEP, label="C (mean offset)")

            # Probability threshold shared by both neural methods (b5 and SBB).
            with gr.Group(visible=False) as g_neural:
                thr = gr.Slider(THRESHOLD_MIN, THRESHOLD_MAX, value=THRESHOLD_DEFAULT, step=THRESHOLD_STEP,
                                label="Binarization threshold",
                                info="Higher = thinner strokes, lower = thicker")

            # SBB-only option.
            with gr.Group(visible=False) as g_sbb:
                sbb_inv = gr.Checkbox(value=False, label="Invert SBB output (if foreground/background flipped)")

            submit_btn = gr.Button("Binarize", variant="primary")

        with gr.Column(scale=1):
            output_image = gr.Image(type="pil", label="Binarized Output")

    # Toggle parameter groups when the method changes. This outputs order
    # must match the order of the list returned by _visibility.
    algo.change(
        fn=_visibility,
        inputs=algo,
        outputs=[g_sauvola, g_niblack, g_otsu, g_adaptive, g_neural, g_sbb],
    )

    # Inputs are passed positionally, so this order must match
    # process_image's parameter list exactly.
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image, algo,
            s_w, s_k, s_r,
            n_w, n_k,
            o_blur,
            a_block, a_c,
            thr, sbb_inv,
        ],
        outputs=output_image,
    )


if __name__ == "__main__":
    # NOTE(review): passing `theme` to launch() (rather than gr.Blocks) depends
    # on the installed Gradio version -- confirm it matches the pinned release.
    demo.launch(theme=gr.themes.Soft(), ssr_mode=False)