---
license: mit
tags:
  - medical-imaging
  - image-segmentation
  - white-matter-hyperintensities
  - mri
  - flair
  - deep-learning
  - tensorflow
  - keras
  - neurology
  - multiple-sclerosis
datasets:
  - custom
  - msseg2016
metrics:
  - dice-coefficient
  - hausdorff-distance
library_name: tensorflow
pipeline_tag: image-segmentation
---
# WMH Segmentation: Normal vs Abnormal Classification

Pre-trained models for **white matter hyperintensity (WMH) segmentation** that explicitly distinguish normal periventricular changes from pathological lesions.

## Model Description

This repository contains 8 pre-trained deep learning models (4 architectures × 2 training scenarios) for automated WMH segmentation from FLAIR MRI images. The models implement a novel **three-class approach** that distinguishes between:

- **Class 0**: Background
- **Class 1**: Normal WMH (aging-related periventricular changes)
- **Class 2**: Abnormal WMH (pathologically significant lesions)

This approach targets false positive detections in periventricular regions. It raises the Dice coefficient by up to **0.271** in absolute terms (U-Net: 0.497 → 0.768, a 54.5% relative gain) compared with traditional binary segmentation.
## Model Architectures

| Architecture | Parameters | Best Dice (3-Class) | Binary Baseline | Relative Improvement |
|--------------|-----------|---------------------|-----------------|----------------------|
| **U-Net** ⭐ | 31.0M | **0.768** | 0.497 | **+54.5%** |
| **Attention U-Net** | 34.9M | 0.740 | 0.486 | +52.1% |
| **TransUNet** | 105.3M | 0.700 | 0.510 | +37.3% |
| **DeepLabV3Plus** | 40.3M | 0.586 | 0.374 | +56.7% |

⭐ **Recommended**: U-Net with Scenario 2 (three-class) for the best Dice score.
## Repository Structure

```
models/
├── unet/models/
│   ├── scenario1_binary_model.h5       # Binary: Background vs Abnormal
│   └── scenario2_multiclass_model.h5   # 3-Class: Background, Normal, Abnormal
├── attention_unet/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
├── deeplabv3plus/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
└── transunet/models/
    ├── scenario1_binary_model.h5
    └── scenario2_multiclass_model.h5
```
## Quick Start

### Installation

```bash
pip install huggingface_hub tensorflow numpy nibabel
```

### Download Models

```python
from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model

# Download the best-performing model (U-Net, three-class)
model_path = hf_hub_download(
    repo_id="Bawil/wmh_leverage_normal_abnormal_segmentation",
    filename="unet/models/scenario2_multiclass_model.h5",
)

# Load the model from the local cache path returned above
model = load_model(model_path)
```
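
The eight model files follow a regular naming pattern, so the `filename` argument can be built rather than typed by hand. The sketch below assumes every file sits at `<architecture>/models/<file>`, as in the U-Net example above. The helper name `model_filename` is hypothetical and not part of the repository.

```python
# Hypothetical helper: assumes all models follow the
# "<architecture>/models/<file>" pattern used in the U-Net example.
ARCHITECTURES = ("unet", "attention_unet", "deeplabv3plus", "transunet")
SCENARIO_FILES = {
    1: "scenario1_binary_model.h5",      # binary: background vs abnormal
    2: "scenario2_multiclass_model.h5",  # 3-class: background, normal, abnormal
}


def model_filename(architecture: str, scenario: int) -> str:
    """Return the in-repo path for one architecture/scenario pair."""
    if architecture not in ARCHITECTURES:
        raise ValueError(f"unknown architecture: {architecture!r}")
    if scenario not in SCENARIO_FILES:
        raise ValueError(f"unknown scenario: {scenario!r}")
    return f"{architecture}/models/{SCENARIO_FILES[scenario]}"


# All 8 combinations (4 architectures x 2 scenarios)
all_filenames = [model_filename(a, s) for a in ARCHITECTURES for s in SCENARIO_FILES]
```

Each returned string can be passed as the `filename` argument of `hf_hub_download`.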
### Inference Example

```python
import numpy as np
from tensorflow.keras.models import load_model

# Load the pre-trained model (model_path comes from the download step)
model = load_model(model_path)

# Prepare the input: 256x256 grayscale FLAIR slices, normalized,
# with shape (batch_size, 256, 256, 1).
# preprocess_flair is a placeholder for your own preprocessing function.
input_image = preprocess_flair(your_flair_image)

# Run inference
predictions = model.predict(input_image)

# Get per-pixel class predictions
#   0: Background
#   1: Normal WMH (periventricular)
#   2: Abnormal WMH (pathological)
predicted_classes = np.argmax(predictions, axis=-1)

# Extract pathological lesions only
abnormal_mask = (predicted_classes == 2).astype(np.uint8)
```
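
The example above calls `preprocess_flair` without defining it, and this card does not document the authors' preprocessing. The following is only a minimal sketch of one plausible version: it center-crops or zero-pads a single 2D slice to 256×256, applies min-max scaling to [0, 1], and adds the batch and channel axes. The crop/pad strategy and the normalization scheme are both assumptions, so match them to the training pipeline before relying on the output.

```python
import numpy as np


def preprocess_flair(slice_2d, size=256):
    """Illustrative preprocessing (assumed, not the authors' pipeline).

    Center-crops or zero-pads a 2D slice to size x size, min-max scales
    the image to [0, 1], and returns shape (1, size, size, 1).
    """
    img = np.asarray(slice_2d, dtype=np.float32)
    if img.ndim != 2:
        raise ValueError("expected a single 2D slice")

    # Center-crop any axis that is larger than the target size.
    h, w = img.shape
    ch, cw = min(h, size), min(w, size)
    src_y, src_x = (h - ch) // 2, (w - cw) // 2
    crop = img[src_y:src_y + ch, src_x:src_x + cw]

    # Min-max normalization; a constant slice maps to all zeros.
    lo, hi = float(crop.min()), float(crop.max())
    crop = (crop - lo) / (hi - lo) if hi > lo else np.zeros_like(crop)

    # Zero-pad any axis that is smaller than the target size.
    out = np.zeros((size, size), dtype=np.float32)
    dst_y, dst_x = (size - ch) // 2, (size - cw) // 2
    out[dst_y:dst_y + ch, dst_x:dst_x + cw] = crop

    return out[np.newaxis, ..., np.newaxis]
```

Cropping and padding keep the voxel spacing unchanged. Resampling to 256×256 is an alternative if the training data were resized instead.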
## Training Data

### Dataset Composition

- **Local Dataset**: 100 MS patients (2,000 FLAIR MRI slices)
  - Demographics: 26 males, 74 females
  - Age range: 18-68 years
  - Scanner: 1.5-Tesla TOSHIBA Vantage
- **Public Dataset**: MSSEG2016 (15 patients, 750 FLAIR slices)

### Annotations

- Expert annotations by board-certified neuroradiologists (20+ years of experience)
- Three-class labeling: Background, Normal WMH, Abnormal WMH
- Approved by Ethics Committee (IR.TBZMED.REC.1402.902)

### Data Split

- **Training**: 80% of patients (local) + 60% of patients (public)
- **Validation**: 10% of patients (local) + 20% of patients (public)
- **Testing**: 10% of patients (local) + 20% of patients (public)
- **Strategy**: Patient-level stratified split (no slice-level leakage)
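
A patient-level split can be sketched as follows. Patient IDs are split first, and every slice then follows its patient, so no patient contributes slices to more than one subset. The IDs, the seed, and the `split_patients` helper are hypothetical. The stratification mentioned above is omitted because the card does not say which variable was used.

```python
import random


def split_patients(patient_ids, train_frac, val_frac, seed=0):
    """Shuffle unique patient IDs and cut them into train/val/test lists."""
    ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train_frac)
    n_val = round(len(ids) * val_frac)
    return ids[:n_train], ids[n_train:n_train + n_val], ids[n_train + n_val:]


# Hypothetical IDs: 100 local patients (80/10/10), 15 public patients (60/20/20)
local = [f"local_{i:03d}" for i in range(100)]
public = [f"public_{i:02d}" for i in range(15)]

local_train, local_val, local_test = split_patients(local, 0.8, 0.1)
public_train, public_val, public_test = split_patients(public, 0.6, 0.2)
```

With these fractions the local set splits into 80/10/10 patients and the public set into 9/3/3.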
## Model Training

### Configuration

- **Framework**: TensorFlow 2.11, Keras
- **Optimizer**: Adam (learning rate: 0.0001)
- **Loss Functions**:
  - Scenario 1: Weighted binary cross-entropy
  - Scenario 2: Weighted categorical cross-entropy
- **Epochs**: 50 (with early stopping)
- **Batch Size**: 8
- **Input Size**: 256×256×1
- **Data Augmentation**: Rotation, flipping, elastic deformation
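
To make the Scenario 2 loss concrete, here is a NumPy sketch of a weighted categorical cross-entropy: each pixel contributes `-sum_c w_c * y_true_c * log(y_pred_c)`, averaged over all pixels. The class weights below are invented for illustration, since the card does not state the values used in training. The authors' TensorFlow implementation may also differ in detail.

```python
import numpy as np


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_true_c * log(y_pred_c).

    y_true: one-hot labels, shape (..., num_classes)
    y_pred: predicted class probabilities, same shape
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    w = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(w * y_true * np.log(y_pred), axis=-1)
    return float(per_pixel.mean())


# Hypothetical weights: rarer lesion classes count more than background.
weights = [1.0, 5.0, 10.0]
y_true = np.eye(3)[[0, 1, 2]]        # three pixels, one per class
y_pred = np.full((3, 3), 1.0 / 3.0)  # uniform (uninformative) prediction
loss = weighted_categorical_crossentropy(y_true, y_pred, weights)
```

Raising the weight of a class increases the penalty for missing it, which is the usual motivation when lesion pixels are far rarer than background.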
### Hardware

- **GPU**: NVIDIA RTX 3060 (12 GB VRAM)
- **Training Time**: 2-3 hours per model
- **Inference Time**: ~35-40 ms per image
## Model Performance

### Dice Coefficient (Primary Metric)

| Model | Scenario 1 (Binary) | Scenario 2 (3-Class) | Δ Dice (absolute) | p-value | Cohen's d |
|-------|-----------|-----------|---------------|---------|-----------|
| U-Net | 0.497±0.145 | **0.768±0.124** | **+0.271** | <0.0001 | 0.564 |
| Attention U-Net | 0.486±0.157 | 0.740±0.133 | +0.253 | <0.0001 | 0.442 |
| TransUNet | 0.510±0.116 | 0.700±0.097 | +0.190 | <0.0001 | 0.478 |
| DeepLabV3Plus | 0.374±0.110 | 0.586±0.092 | +0.212 | <0.0001 | 0.565 |
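
For reference, the Dice coefficient between a predicted mask A and a ground-truth mask B is 2|A ∩ B| / (|A| + |B|). The sketch below computes it for binary masks with NumPy. Returning 1.0 when both masks are empty is a convention chosen here, not something the card specifies.

```python
import numpy as np


def dice_coefficient(pred_mask, true_mask):
    """Dice overlap between two binary masks of the same shape."""
    pred = np.asarray(pred_mask, dtype=bool)
    true = np.asarray(true_mask, dtype=bool)
    denom = pred.sum() + true.sum()
    if denom == 0:
        return 1.0  # assumed convention: two empty masks agree perfectly
    return float(2.0 * np.logical_and(pred, true).sum() / denom)


pred = np.array([[1, 1, 0], [0, 1, 0]])
true = np.array([[1, 0, 0], [0, 1, 1]])
score = dice_coefficient(pred, true)  # 2 * 2 / (3 + 3)
```

To score only the pathological class from a three-class prediction, compare `predicted_classes == 2` against the same class in the reference labels.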
### Additional Metrics

- **Hausdorff Distance**: 27.4 mm (U-Net 3-class) vs 29.8 mm (binary)
- **Precision**: Significant improvement in pathological lesion detection
- **False Positive Reduction**: Marked decrease in periventricular regions
- **Clinical Feasibility**: 1.5 s total processing time per case (40 slices)

### Statistical Validation

- Paired t-tests confirm significant improvements (all p < 0.0001)
- Effect sizes (Cohen's d) range from 0.442 to 0.565, close to the conventional medium threshold of 0.5
- 95% confidence intervals reported for all metrics
- Wilcoxon signed-rank test for non-parametric validation
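
The paired comparison can be sketched with the standard library alone. The card does not say which variant of Cohen's d was reported, so the code below uses the paired form d_z (mean of the per-case differences divided by their standard deviation), which is an assumption. The Dice scores are made up for illustration.

```python
import math
import statistics


def paired_effect(scores_a, scores_b):
    """Return (mean difference, paired t statistic, Cohen's d_z) for b - a."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equal-length samples with at least 2 pairs")
    diffs = [b - a for a, b in zip(scores_a, scores_b)]
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)        # sample SD of the differences
    d_z = mean_diff / sd_diff                # paired effect size
    t_stat = d_z * math.sqrt(len(diffs))     # t with n - 1 degrees of freedom
    return mean_diff, t_stat, d_z


# Made-up per-case Dice scores, for illustration only.
binary = [0.40, 0.55, 0.50, 0.45, 0.60]
three_class = [0.70, 0.75, 0.80, 0.65, 0.85]
mean_diff, t_stat, d_z = paired_effect(binary, three_class)
```

Turning the t statistic into a p-value needs the t distribution. `scipy.stats.ttest_rel` and `scipy.stats.wilcoxon` cover the paired t-test and the Wilcoxon signed-rank test if SciPy is available.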
## Use Cases

### Clinical Applications

- **MS Lesion Quantification**: Accurate measurement of disease burden
- **Differential Diagnosis**: Distinguish pathological from normal aging
- **Longitudinal Monitoring**: Track disease progression over time
- **Treatment Response**: Evaluate therapeutic efficacy
- **Radiological Reporting**: Reduce false positive alerts

### Research Applications

- **Baseline Comparisons**: Standardized evaluation framework
- **Method Development**: Foundation for advanced segmentation approaches
- **Multi-center Studies**: Protocol for broader validation
- **Reproducible Research**: Complete implementation available

## Limitations

- **Single Modality**: Trained on FLAIR MRI only
- **Scanner Specificity**: Primarily 1.5T TOSHIBA data
- **Disease Focus**: Optimized for MS patients
- **2D Segmentation**: Slice-by-slice processing (no 3D context)
- **Resolution**: Fixed 256×256 input size
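
Because the models work on 2D slices, a 3D scan has to be segmented one slice at a time and the label maps stacked afterwards. The sketch below shows that loop with a generic `predict_fn` callable standing in for `model.predict`. The `segment_volume` and `dummy_predict` names and the `(slices, H, W)` axis order are assumptions made for the example.

```python
import numpy as np


def segment_volume(volume, predict_fn):
    """Apply a 2D predictor to each slice of a (slices, H, W) volume.

    predict_fn takes a (1, H, W, 1) array and returns per-class
    probabilities of shape (1, H, W, num_classes).
    """
    labels = np.empty(volume.shape, dtype=np.uint8)
    for i, slice_2d in enumerate(volume):
        probs = predict_fn(slice_2d[np.newaxis, ..., np.newaxis])
        labels[i] = np.argmax(probs[0], axis=-1)
    return labels


def dummy_predict(batch):
    """Stand-in for a model: class 2 where intensity > 0.5, else class 0."""
    probs = np.zeros(batch.shape[:3] + (3,), dtype=np.float32)
    bright = batch[..., 0] > 0.5
    probs[..., 0] = ~bright
    probs[..., 2] = bright
    return probs


volume = np.zeros((4, 256, 256), dtype=np.float32)
volume[1, 100:110, 100:110] = 1.0  # synthetic bright patch in slice 1
labels = segment_volume(volume, dummy_predict)
```

Each slice is predicted independently, so nothing in this loop enforces consistency between neighbouring slices. That is the missing 3D context noted in the limitations.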
## Model Card

### Intended Use

- **Primary**: Automated WMH segmentation for research and clinical decision support
- **Users**: Radiologists, neurologists, researchers, AI developers
- **Out-of-scope**: Not FDA/CE approved; not for standalone clinical diagnosis

### Ethical Considerations

- **Privacy**: All data anonymized per HIPAA/GDPR standards
- **Bias**: Limited scanner/protocol diversity may affect generalization
- **Clinical Validation**: Requires expert review before clinical use
- **Transparency**: Complete methodology and code openly available

### Model Card Authors

Mahdi Bashiri Bawil, Mousa Shamsi, Ali Fahmi Jafargholkhanloo, Abolhassan Shakeri Bavil
## Citation

```bibtex
@article{bawil2025wmh,
  title={Incorporating Normal Periventricular Changes for Enhanced Pathological
         White Matter Hyperintensity Segmentation: On Multi-Class Deep Learning Approaches},
  author={Bawil, Mahdi Bashiri and Shamsi, Mousa and Jafargholkhanloo, Ali Fahmi and
          Bavil, Abolhassan Shakeri},
  year={2025},
  note={Models: https://huggingface.co/Bawil/wmh_leverage_normal_abnormal_segmentation}
}
```
## License

MIT License - See [LICENSE](https://github.com/Mahdi-Bashiri/wmh-normal-abnormal-segmentation/blob/main/LICENSE)

## Additional Resources

- **📄 Paper**: Under review
- **💻 GitHub Repository**: [Mahdi-Bashiri/wmh-normal-abnormal-segmentation](https://github.com/Mahdi-Bashiri/wmh-normal-abnormal-segmentation)
- **📧 Contact**: m_bashiri99@sut.ac.ir
- **🏥 Institution**: Sahand University of Technology & Tabriz University of Medical Sciences

## Acknowledgments

- **Golgasht Medical Imaging Center**, Tabriz, Iran, for providing clinical data
- Expert neuroradiologists for manual annotations
- Ethics Committee approval: IR.TBZMED.REC.1402.902

---

**Keywords**: white matter hyperintensities, FLAIR MRI, medical imaging, deep learning, image segmentation, multiple sclerosis, U-Net, attention mechanisms, transformers, clinical AI