File size: 15,518 Bytes
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
 
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8890ad0
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
 
 
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
 
3e09c97
 
 
f1664f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e09c97
 
 
 
 
 
 
ba6cfa8
3e09c97
 
 
ba6cfa8
3e09c97
 
 
ba6cfa8
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
 
 
3e09c97
 
 
f1664f8
3e09c97
 
 
 
 
f1664f8
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73e2bd
 
 
 
 
 
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8890ad0
 
 
 
3e09c97
8890ad0
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
 
 
 
 
 
 
 
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1664f8
3e09c97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d73e2bd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
"""
MONAI WholeBody CT Segmentation - Hugging Face Space
Segments 104 anatomical structures from CT scans using MONAI's SegResNet model
"""

import os
import tempfile
import numpy as np
import gradio as gr
import torch
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from huggingface_hub import hf_hub_download
from monai.networks.nets import SegResNet
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    Orientation,
    Spacing,
    ScaleIntensityRange,
    CropForeground,
    Activations,
    AsDiscrete,
)
from monai.inferers import sliding_window_inference
from labels import LABEL_NAMES, get_color_map, get_label_name, get_organ_categories
import trimesh
from skimage import measure

# Constants
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_REPO = "MONAI/wholeBody_ct_segmentation"  # Hugging Face repo passed to hf_hub_download
SPATIAL_SIZE = (96,) * 3  # sliding-window ROI size (voxels)
PIXDIM = (3.0,) * 3  # Low-res model spacing

# Global model variable: filled in lazily by load_model()
model = None


def load_model():
    """Return the cached SegResNet, downloading and initialising it on first use.

    The low-resolution checkpoint ``models/model_lowres.pt`` is fetched from
    ``MODEL_REPO`` first. If that download raises, ``models/model.pt`` from the
    same repo is tried instead. The network is moved to ``DEVICE``, switched to
    eval mode and cached in the module-level ``model``.

    The cache is only populated after the weights have loaded successfully. A
    failed load is therefore retried on the next call, instead of silently
    handing out an untrained network.

    Returns:
        The ready-to-use SegResNet.

    Raises:
        Whatever ``hf_hub_download`` raises if both downloads fail, or what
        ``torch.load`` / ``load_state_dict`` raise for a bad checkpoint.
    """
    global model
    if model is not None:
        return model

    print("Downloading model weights...")
    try:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model_lowres.pt",
        )
    except Exception as e:
        print(f"Failed to download from HF, trying alternative: {e}")
        # Fallback: the other checkpoint in the same Hugging Face repo.
        # NOTE(review): presumably the full-resolution weights; PIXDIM is not
        # adjusted when this fallback is used -- confirm spacing expectations.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model.pt",
        )

    print(f"Loading model from {model_path}...")

    # 105 output channels: background + 104 anatomical classes.
    net = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=32,
        in_channels=1,
        out_channels=105,
        dropout_prob=0.2,
    )

    # NOTE: torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    # Accept both a bare state dict and a {"state_dict": ...} wrapper.
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint
    net.load_state_dict(state_dict)

    net.to(DEVICE)
    net.eval()

    # Publish to the cache only now that the weights are in place.
    model = net
    print(f"Model loaded successfully on {DEVICE}")

    return model


def get_preprocessing_transforms():
    """Build the MONAI array pipeline that turns a NIfTI path into model input.

    Steps, in order:
    1. Load the image from disk (data only).
    2. Make it channel-first.
    3. Reorient to RAS.
    4. Resample to ``PIXDIM`` with bilinear interpolation.
    5. Map intensities linearly from [-1024, 1024] to [0, 1], clipping values
       outside that window.
    """
    intensity_window = dict(
        a_min=-1024, a_max=1024,
        b_min=0.0, b_max=1.0,
        clip=True,
    )
    pipeline = [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=PIXDIM, mode="bilinear"),
        ScaleIntensityRange(**intensity_window),
    ]
    return Compose(pipeline)


def get_postprocessing_transforms():
    """Build the MONAI pipeline that converts network logits into a label map.

    It applies softmax over channels, then argmax down to a single integer
    channel. MONAI applies both along dim 0 by default, so feed it one
    un-batched ``(C, ...)`` prediction rather than a batched tensor.
    """
    steps = (
        Activations(softmax=True),
        AsDiscrete(argmax=True),
    )
    return Compose(list(steps))


def run_inference(image_path: str, progress=gr.Progress()):
    """Run whole-body segmentation on a CT volume.

    Args:
        image_path: Path to a NIfTI CT file (.nii / .nii.gz).
        progress: Gradio progress tracker, called with a fraction and a description.

    Returns:
        ``(ct_data, seg_data)``: two NumPy arrays of identical shape, both on the
        model's working grid (RAS orientation, ``PIXDIM`` spacing). Slice
        indices and voxel positions therefore agree between them.
        ``ct_data`` is float32 intensity clipped to [-1024, 1024] (the
        preprocessing window). ``seg_data`` is a uint8 label map with
        0 = background.
    """
    progress(0.1, desc="Loading model...")
    model = load_model()

    progress(0.2, desc="Preprocessing image...")
    preprocess = get_preprocessing_transforms()
    postprocess = get_postprocessing_transforms()

    # Channel-first (1, H, W, D), reoriented/resampled and scaled to [0, 1].
    image = preprocess(image_path)

    progress(0.4, desc="Running segmentation (this may take a few minutes)...")

    with torch.no_grad():
        # Sliding-window inference handles volumes larger than the ROI.
        outputs = sliding_window_inference(
            image.unsqueeze(0).to(DEVICE),  # add batch dimension
            roi_size=SPATIAL_SIZE,
            sw_batch_size=4,
            predictor=model,
            overlap=0.5,
        )

    progress(0.8, desc="Post-processing...")

    # Activations/AsDiscrete work along dim 0 by default, i.e. on a single
    # un-batched (C, H, W, D) prediction. On the batched (1, C, H, W, D) output
    # they would softmax/argmax over the size-1 batch axis and return all
    # zeros, so drop the batch axis first.
    labels = postprocess(outputs[0])  # (1, H, W, D)
    seg_data = labels[0].cpu().numpy().astype(np.uint8)

    # Build the display volume from the *preprocessed* image so it shares the
    # segmentation's orientation, spacing and shape. The raw file data is on a
    # different grid, so overlaying the labels on it would be misaligned.
    # Undo the [-1024, 1024] -> [0, 1] intensity window used in preprocessing.
    window_min, window_max = -1024.0, 1024.0
    ct_data = image[0].cpu().numpy().astype(np.float32) * (window_max - window_min) + window_min

    progress(1.0, desc="Complete!")

    return ct_data, seg_data


def generate_3d_mesh(seg_data, step_size=2):
    """Build one surface mesh of all segmented structures and save it as GLB.

    Args:
        seg_data: Integer label map (0 = background), or None.
        step_size: Marching-cubes step in voxels. Values > 1 coarsen the mesh
            but make generation much faster, which matters on CPU-only Spaces.

    Returns:
        Path to a temporary ``.glb`` file, which is not deleted automatically.
        Returns None if there is nothing to mesh or mesh generation fails.
    """
    # Nothing labelled (also covers empty arrays, where np.max would raise).
    if seg_data is None or not np.any(np.asarray(seg_data) > 0):
        return None

    try:
        # Union of every labelled structure as a 0/1 float volume; the
        # iso-surface at level 0.5 lies between background and foreground.
        mask = (np.asarray(seg_data) > 0).astype(np.float32)

        verts, faces, normals, _ = measure.marching_cubes(mask, level=0.5, step_size=step_size)

        mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)

        # Reserve a unique path, then close our handle before trimesh opens it
        # again: reopening a still-open NamedTemporaryFile is not portable
        # (it fails on Windows).
        with tempfile.NamedTemporaryFile(suffix=".glb", delete=False) as handle:
            glb_path = handle.name
        mesh.export(glb_path)

        return glb_path
    except Exception as e:
        # Best-effort: a missing 3D view must not break the rest of the UI.
        print(f"Error generating 3D mesh: {e}")
        return None


def create_slice_visualization(ct_data, seg_data, axis, slice_idx, alpha=0.5, show_overlay=True):
    """Render one CT slice, optionally with the segmentation overlaid.

    Args:
        ct_data: 3-D CT volume.
        seg_data: Integer label map on the same grid as ``ct_data``, or None.
            A label map of a different shape is not drawn.
        axis: "Axial" slices the last array axis and "Coronal" the second;
            any other value is treated as "Sagittal" and slices the first axis.
        slice_idx: Requested slice index, clamped to the valid range.
        alpha: Overlay opacity for labelled pixels; background stays transparent.
        show_overlay: Whether to draw the segmentation.

    Returns:
        A ``matplotlib.figure.Figure``.
    """
    from matplotlib.figure import Figure

    array_axis = {"Axial": 2, "Coronal": 1}.get(axis, 0)
    slice_idx = max(0, min(slice_idx, ct_data.shape[array_axis] - 1))
    ct_slice = np.take(ct_data, slice_idx, axis=array_axis)

    # Only overlay labels that live on the CT's voxel grid; anything else would
    # be misaligned or out of range for this slice index.
    seg_slice = None
    if seg_data is not None:
        if np.shape(seg_data) == np.shape(ct_data):
            seg_slice = np.take(seg_data, slice_idx, axis=array_axis)
        else:
            print(
                f"Skipping overlay: segmentation shape {np.shape(seg_data)} "
                f"does not match CT shape {np.shape(ct_data)}"
            )

    # Build the figure without pyplot. A new figure is made on every slider
    # move, and pyplot would keep each one alive in its global registry until
    # it is explicitly closed.
    fig = Figure(figsize=(8, 8))
    ax = fig.subplots()

    # Clip to the display window, then stretch this slice to [0, 1].
    ct_normalized = np.clip(ct_slice, -1024, 1024)
    ct_normalized = (ct_normalized - ct_normalized.min()) / (ct_normalized.max() - ct_normalized.min() + 1e-8)
    ax.imshow(ct_normalized.T, cmap='gray', origin='lower')

    if show_overlay and seg_slice is not None and np.any(seg_slice > 0):
        labels = seg_slice.astype(int)
        # Use only the RGB part of the palette (works for RGB or RGBA
        # palettes); the alpha channel is built from `alpha` below.
        palette = np.asarray(get_color_map(), dtype=float)[:, :3] / 255.0
        seg_rgba = np.zeros((*labels.shape, 4))
        seg_rgba[..., :3] = palette[labels]
        seg_rgba[..., 3] = np.where(labels > 0, alpha, 0.0)  # transparent background
        ax.imshow(seg_rgba.transpose(1, 0, 2), origin='lower')

    ax.axis('off')
    ax.set_title(f"{axis} View - Slice {slice_idx}")

    fig.tight_layout()
    return fig


def get_detected_structures(seg_data):
    """Summarise which anatomical labels appear in a segmentation.

    Args:
        seg_data: Integer label map (0 = background).

    Returns:
        One line per non-background label, in ascending label order, formatted
        ``"• <name> (Label <id>)"``. Returns "No structures detected" when no
        label is present.
    """
    labels = np.unique(seg_data)
    lines = [
        f"• {get_label_name(label)} (Label {label})"
        for label in labels[labels > 0]  # exclude background
    ]
    return "\n".join(lines) if lines else "No structures detected"


# Module-level visualization state: the latest CT volume and its label map.
# process_upload writes these; update_visualization reads them.
current_ct_data = current_seg_data = None


def process_upload(file_path, progress=gr.Progress()):
    """Segment an uploaded CT file and refresh every output widget.

    Args:
        file_path: Path of the uploaded NIfTI file, or None if nothing was uploaded.
        progress: Gradio progress tracker, forwarded to run_inference.

    Returns:
        A 6-tuple matching the outputs wired to the "Run Segmentation" button:
        (slice figure, structures text, 3D mesh path, axial slider update,
        coronal slider update, sagittal slider update). Every path returns six
        values. On failure the figure and mesh are None and the text carries
        the message.
    """
    global current_ct_data, current_seg_data

    if file_path is None:
        # Must supply all six wired outputs; None clears the plot and 3D viewer.
        return (
            None,
            "Please upload a NIfTI file",
            None,
            gr.update(maximum=1),
            gr.update(maximum=1),
            gr.update(maximum=1),
        )

    try:
        ct_data, seg_data = run_inference(file_path, progress)
        current_ct_data = ct_data
        current_seg_data = seg_data

        # Open each view on the middle slice of its axis.
        mid_axial = ct_data.shape[2] // 2
        mid_coronal = ct_data.shape[1] // 2
        mid_sagittal = ct_data.shape[0] // 2

        fig = create_slice_visualization(ct_data, seg_data, "Axial", mid_axial)
        structures = get_detected_structures(seg_data)

        # Generate 3D mesh (this might take a few seconds)
        mesh_path = generate_3d_mesh(seg_data)

        return (
            fig,
            structures,
            mesh_path,
            gr.update(maximum=ct_data.shape[2] - 1, value=mid_axial),
            gr.update(maximum=ct_data.shape[1] - 1, value=mid_coronal),
            gr.update(maximum=ct_data.shape[0] - 1, value=mid_sagittal),
        )
    except Exception as e:
        # UI boundary: show the message to the user, but keep the full
        # traceback in the server log so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        return None, f"Error processing file: {str(e)}", None, gr.update(), gr.update(), gr.update()


def update_visualization(axis, slice_idx, alpha, show_overlay):
    """Re-render the most recently segmented scan for new view settings.

    Returns None until a scan has been processed. Otherwise returns the figure
    built by create_slice_visualization for the requested axis and slice; the
    slider value is truncated to an int.
    """
    # Module state is only read here, so no `global` declaration is needed.
    if current_ct_data is not None:
        return create_slice_visualization(
            current_ct_data, current_seg_data, axis, int(slice_idx), alpha, show_overlay
        )
    return None


def load_example(example_name):
    """Return the path of a bundled example scan, or None if it does not exist.

    The name is resolved against the ``examples`` folder next to this module.
    """
    candidate = os.path.join(os.path.dirname(__file__), "examples", example_name)
    return candidate if os.path.exists(candidate) else None


# Create Gradio interface
with gr.Blocks(
    title="MONAI WholeBody CT Segmentation",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px !important}
    .output-image {min-height: 500px}
    """
) as demo:
    gr.Markdown("""
    # 🏥 MONAI WholeBody CT Segmentation
    
    **Automatic segmentation of 104 anatomical structures from CT scans**
    
    This application uses MONAI's pre-trained SegResNet model trained on the TotalSegmentator dataset.
    Upload a CT scan in NIfTI format (.nii or .nii.gz) to get started.
    
    > ⚡ **Note**: Processing may take 1-5 minutes depending on the CT volume size.
    """)
    
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📤 Upload CT Scan")
            file_input = gr.File(
                label="Upload NIfTI file (.nii, .nii.gz)",
                file_types=[".nii", ".nii.gz", ".gz"],
                type="filepath"
            )
            
            # Example files
            gr.Markdown("### 📁 Example Files")
            
            # List every .nii.gz file in the examples folder. This runs at import
            # time, before the __main__ block could create the folder, so a
            # missing folder must not crash the app. Sorted for a stable order.
            example_files = (
                [
                    [os.path.join("examples", f)]
                    for f in sorted(os.listdir("examples"))
                    if f.endswith(".nii.gz")
                ]
                if os.path.isdir("examples")
                else []
            )
            
            if example_files:
                example_gallery = gr.Examples(
                    examples=example_files,
                    inputs=[file_input],
                    label="Click to load example"
                )
            else:
                example_gallery = None
                gr.Markdown("_No example scans are available._")
            
            process_btn = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
            
            # Visualization controls
            gr.Markdown("### 🎛️ Visualization Controls")
            
            view_axis = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value="Axial",
                label="View Axis"
            )
            
            with gr.Row():
                axial_slider = gr.Slider(0, 100, value=50, step=1, label="Axial Slice")
                coronal_slider = gr.Slider(0, 100, value=50, step=1, label="Coronal Slice")
                sagittal_slider = gr.Slider(0, 100, value=50, step=1, label="Sagittal Slice")
            
            alpha_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Overlay Opacity")
            show_overlay = gr.Checkbox(value=True, label="Show Segmentation Overlay")
            
        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### 🖼️ Segmentation Result")
            output_image = gr.Plot(label="CT with Segmentation Overlay")
            
            gr.Markdown("### 📋 Detected Structures")
            structures_output = gr.Textbox(
                label="Anatomical Structures Found",
                lines=10,
                max_lines=20
            )
            
            # 3D Model Output
            gr.Markdown("### 🧊 3D View")
            model_3d_output = gr.Model3D(
                label="3D Segmentation Mesh",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                camera_position=(90, 90, 3)
            )
    
    # Model info section
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        ### About the Model
        
        This model is based on **SegResNet** architecture from MONAI, trained on the **TotalSegmentator** dataset.
        
        **Capabilities:**
        - Segments 104 distinct anatomical structures
        - Works on whole-body CT scans
        - Uses 3.0mm isotropic spacing (low-resolution model for faster inference)
        
        **Segmented Structures include:**
        - **Major Organs**: Liver, Spleen, Kidneys, Pancreas, Gallbladder, Stomach, Bladder
        - **Cardiovascular**: Heart chambers, Aorta, Vena Cava, Portal Vein
        - **Respiratory**: Lung lobes, Trachea
        - **Skeletal**: Vertebrae (C1-L5), Ribs, Hip bones, Femur, Humerus, Scapula
        - **Muscles**: Gluteal muscles, Iliopsoas
        - And many more...
        
        **References:**
        - [MONAI Model Zoo](https://monai.io/model-zoo.html)
        - [TotalSegmentator Paper](https://pubs.rsna.org/doi/10.1148/ryai.230024)
        """)
    
    # Event handlers: process_upload returns exactly these six outputs.
    process_btn.click(
        fn=process_upload,
        inputs=[file_input],
        outputs=[output_image, structures_output, model_3d_output, axial_slider, coronal_slider, sagittal_slider]
    )
    
    # Update visualization when controls change, using the slider of the selected axis
    for control in [view_axis, alpha_slider, show_overlay]:
        control.change(
            fn=lambda axis, alpha, overlay, ax_s, cor_s, sag_s: update_visualization(
                axis,
                ax_s if axis == "Axial" else (cor_s if axis == "Coronal" else sag_s),
                alpha,
                overlay
            ),
            inputs=[view_axis, alpha_slider, show_overlay, axial_slider, coronal_slider, sagittal_slider],
            outputs=[output_image]
        )
    
    # Update when sliders change: each slider renders its own axis
    axial_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Axial", s, alpha, overlay),
        inputs=[axial_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    
    coronal_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Coronal", s, alpha, overlay),
        inputs=[coronal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    
    sagittal_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Sagittal", s, alpha, overlay),
        inputs=[sagittal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )


if __name__ == "__main__":
    # Make sure the examples folder exists for later runs. The example list for
    # this launch was already built when the module was imported, above.
    os.makedirs("examples", exist_ok=True)

    # Start the Gradio server.
    demo.launch()