# ClipSegMultiClass

Multiclass semantic segmentation with CLIP and CLIPSeg: a fine-tuned version of [`CIDAS/clipseg-rd64-refined`](https://huggingface.co/CIDAS/clipseg-rd64-refined) that segments multiple classes in a single forward pass.
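
The model itself is not needed to show how several classes can share one pass. If a network produces one score map per class (shape `[C, H, W]`), a per-pixel argmax collapses them into a single label map of class indices. The NumPy sketch below assumes that decoding scheme. It is not taken from the repository, the scores are hand-made stand-ins for real logits, and the helper name `logits_to_label_map` is hypothetical. The class order follows the model card.

```python
import numpy as np

# Class order from the model card; index 0 is background.
CLASSES = ["background", "Pig", "Horse", "Sheep"]


def logits_to_label_map(scores: np.ndarray) -> np.ndarray:
    """Hypothetical helper: reduce [C, H, W] per-class scores to an [H, W] map of class indices."""
    if scores.ndim != 3 or scores.shape[0] != len(CLASSES):
        raise ValueError(f"expected shape ({len(CLASSES)}, H, W), got {scores.shape}")
    # argmax picks the highest-scoring class per pixel; ties go to the lowest index (background).
    return scores.argmax(axis=0).astype(np.uint8)


# Hand-made stand-in scores for a 2x2 image (not real model output).
scores = np.zeros((len(CLASSES), 2, 2), dtype=np.float32)
scores[1, 0, 0] = 3.0  # "Pig" wins at pixel (0, 0)
scores[3, 1, 1] = 2.0  # "Sheep" wins at pixel (1, 1)

label_map = logits_to_label_map(scores)
print(label_map.tolist())                       # [[1, 0], [0, 3]]
print([CLASSES[i] for i in label_map.ravel()])  # ['Pig', 'background', 'background', 'Sheep']
```

NumPy's argmax breaks ties toward the lowest index. Keeping background at index 0 therefore makes it the default wherever no other class scores higher.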

---

## Model

- **Name:** [`BioMike/clipsegmulticlass_v1`](https://huggingface.co/BioMike/clipsegmulticlass_v1)
- **Repository:** [github.com/BioMikeUkr/clipsegmulticlass](https://github.com/BioMikeUkr/clipsegmulticlass)
- **Base:** `CIDAS/clipseg-rd64-refined`
- **Classes:** `["background", "Pig", "Horse", "Sheep"]`
- **Image size:** 352×352
- **Trained on:** OpenImages segmentation subset (custom fruit/animal dataset)
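
The model works at a fixed 352×352 resolution, so a label map at that size must be resized to the photo's dimensions before it can be overlaid. This card does not say whether `predict` already does this, since the usage example only documents the shape as `[1, H, W]`. The sketch below is therefore a helper built on that assumption, with a hypothetical name. It uses nearest-neighbour resampling, which keeps every pixel a valid class index. Bilinear or bicubic resampling would blend neighbouring indices into meaningless labels.

```python
import numpy as np
from PIL import Image

MODEL_SIZE = (352, 352)  # (width, height) from the model card


def resize_label_map(label_map: np.ndarray, size: tuple[int, int]) -> np.ndarray:
    """Hypothetical helper: resize an [H, W] label map to size=(width, height) without mixing classes."""
    img = Image.fromarray(label_map.astype(np.uint8))  # 2-D uint8 array -> mode "L"
    return np.asarray(img.resize(size, resample=Image.NEAREST))


# Made-up 352x352 mask: top half labelled 1 ("Pig"), bottom half 0 (background).
mask_352 = np.zeros((MODEL_SIZE[1], MODEL_SIZE[0]), dtype=np.uint8)
mask_352[:176, :] = 1

original_size = (640, 480)  # (width, height) of a hypothetical source photo
mask_full = resize_label_map(mask_352, original_size)
print(mask_full.shape)       # (480, 640)
print(np.unique(mask_full))  # [0 1]
```

PIL sizes are `(width, height)`, while NumPy shapes are `(height, width)`. That is why a 640×480 target comes back as a `(480, 640)` array.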

---

## Evaluation

| Model                        | Precision | Recall | F1 Score | Accuracy |
|------------------------------|-----------|--------|----------|----------|
| CIDAS/clipseg-rd64-refined   | 0.5239    | 0.2114 | 0.2882   | 0.2665   |
| BioMike/clipsegmulticlass_v1 | 0.7460    | 0.5035 | 0.6009   | 0.6763   |
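
The card does not say how these scores are aggregated. One common reading, shown here only as an assumption, computes pixel-level precision, recall and F1 per class and macro-averages them, with accuracy as the fraction of correctly labelled pixels. The helper name `segmentation_scores` and the toy labels are hypothetical, not the card's evaluation data.

```python
import numpy as np


def segmentation_scores(pred: np.ndarray, target: np.ndarray, num_classes: int) -> dict:
    """Hypothetical helper: macro-averaged pixel precision/recall/F1 and overall pixel accuracy."""
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = int(np.sum((pred == c) & (target == c)))
        fp = int(np.sum((pred == c) & (target != c)))
        fn = int(np.sum((pred != c) & (target == c)))
        if tp + fp + fn == 0:
            continue  # class absent from both prediction and ground truth: skip it
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
    return {
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),  # mean of per-class F1, not F1 of the mean P and R
        "accuracy": float(np.mean(pred == target)),
    }


# Toy 2x2 example with classes 0..3 (made-up labels).
pred = np.array([[0, 1], [1, 3]])
target = np.array([[0, 1], [2, 3]])
scores = segmentation_scores(pred, target, num_classes=4)
print({k: round(v, 4) for k, v in scores.items()})
# {'precision': 0.625, 'recall': 0.75, 'f1': 0.6667, 'accuracy': 0.75}
```

Under macro-averaging, the reported F1 is generally not the harmonic mean of the reported precision and recall, because F1 is averaged per class.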

---

## Demo

Try it online: [Hugging Face Space](https://huggingface.co/spaces/BioMike/clipsegmulticlass)

---

## Usage

```python
from PIL import Image
import torch
import matplotlib.pyplot as plt
import numpy as np

# model.py and config.py from the repository linked above
from model import ClipSegMultiClassModel
from config import ClipSegMultiClassConfig

# Load model (falls back to CPU when no GPU is available)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ClipSegMultiClassModel.from_pretrained("trained_clipseg_multiclass").to(device).eval()
config = model.config  # contains label2color

# Load image
image = Image.open("pigs.jpg").convert("RGB")

# Run inference (no gradients needed)
with torch.no_grad():
    mask = model.predict(image)  # shape: [1, H, W]


# Visualize
def visualize_mask(mask_tensor: torch.Tensor, label2color: dict) -> np.ndarray:
    """Convert a [1, H, W] or [H, W] class-index mask into an RGB image."""
    if mask_tensor.dim() == 3:
        mask_tensor = mask_tensor.squeeze(0)

    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)  # [H, W]
    h, w = mask_np.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)

    for class_idx, color in label2color.items():
        # Keys may become strings when the config is saved to and loaded from JSON
        color_mask[mask_np == int(class_idx)] = color

    return color_mask


color_mask = visualize_mask(mask, config.label2color)

plt.imshow(color_mask)
plt.axis("off")
plt.title("Predicted Segmentation Mask")
plt.show()
```