Image Segmentation
Transformers
English
clipseg
segmentation
construction
drywall
quality-assurance
text-conditioned
binary-mask
Instructions to use youngPhilosopher/drywall-qa-clipseg with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use youngPhilosopher/drywall-qa-clipseg with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="youngPhilosopher/drywall-qa-clipseg")

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("youngPhilosopher/drywall-qa-clipseg", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Evaluate trained CLIPSeg model and generate prediction masks + visuals.""" | |
| import json | |
| import time | |
| from pathlib import Path | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from src.data.dataset import DrywallSegDataset, collate_fn | |
| from src.model.clipseg_wrapper import load_model_and_processor | |
| from src.train import compute_metrics, get_device | |
| PROJECT_ROOT = Path(__file__).resolve().parents[1] | |
def _synchronize(device) -> None:
    """Wait for queued accelerator work so wall-clock timings are meaningful.

    CUDA and MPS launch kernels asynchronously; without a sync, a timer around
    the forward pass mostly measures launch overhead. No-op on other devices.
    """
    resolved = torch.device(device)
    if resolved.type == "cuda":
        torch.cuda.synchronize(resolved)
    elif resolved.type == "mps" and hasattr(torch, "mps"):
        torch.mps.synchronize()


def evaluate(config_path: str | None = None):
    """Evaluate the best checkpoint on the test split and write all artifacts.

    Loads the YAML config, restores ``outputs/checkpoints/best_model.pt``,
    runs the test split through the model, and:

    * saves one binary PNG mask per sample (resized to the original image
      size) under ``outputs/masks``;
    * writes aggregated metrics to ``outputs/logs/test_results.json``;
    * prints a summary and renders a comparison figure under
      ``reports/figures``.

    Args:
        config_path: Path to the YAML config. Defaults to
            ``configs/train_config.yaml`` under the project root.

    Returns:
        A dict with ``per_class`` (mIoU/Dice/sample count per dataset name),
        ``overall`` (mIoU/Dice/total_samples) and ``runtime``
        (avg_inference_ms, model_size_mb).
    """
    config_path = config_path or str(PROJECT_ROOT / "configs" / "train_config.yaml")
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    device = get_device()
    threshold = config["evaluation"]["threshold"]
    max_visuals = config["evaluation"]["num_visual_examples"]

    # Load model with best checkpoint
    model, processor = load_model_and_processor(
        config["model"]["name"], config["model"]["freeze_backbone"]
    )
    ckpt_path = PROJECT_ROOT / "outputs" / "checkpoints" / "best_model.pt"
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
    model = model.to(device)
    model.eval()

    # Model size (parameter storage only)
    model_size_mb = sum(p.numel() * p.element_size() for p in model.parameters()) / (1024 * 1024)

    # Test data
    splits_dir = PROJECT_ROOT / "data" / "splits"
    test_ds = DrywallSegDataset(str(splits_dir / "test.json"), processor, config["data"]["image_size"])
    test_loader = DataLoader(
        test_ds,
        batch_size=config["training"]["batch_size"],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
    )

    # Run evaluation
    masks_dir = PROJECT_ROOT / "outputs" / "masks"
    masks_dir.mkdir(parents=True, exist_ok=True)

    all_metrics = {"taping": {"miou": [], "dice": []}, "cracks": {"miou": [], "dice": []}}
    visual_examples = []  # Collect for visualization
    total_inference_s = 0.0
    total_samples = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            pixel_values = batch["pixel_values"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            batch_size = pixel_values.size(0)

            # Synchronize on both sides so the timer covers the actual
            # forward pass rather than just the asynchronous kernel launch.
            _synchronize(device)
            t0 = time.perf_counter()
            outputs = model(pixel_values=pixel_values, input_ids=input_ids, attention_mask=attention_mask)
            _synchronize(device)
            total_inference_s += time.perf_counter() - t0

            logits = outputs.logits
            # Each preds[i] below becomes an "L" image, so logits must be
            # (batch, H, W). A 2-D tensor means the batch axis was dropped
            # for a single-image batch; restore it so indexing stays per-sample.
            if logits.dim() == 2:
                logits = logits.unsqueeze(0)
            preds = (torch.sigmoid(logits) > threshold).cpu().numpy().astype(np.uint8)

            for i in range(batch_size):
                ds_name = batch["dataset"][i]

                # Score each sample on its own: a batch-level score would be
                # attributed to every dataset present in a mixed batch and
                # corrupt the per-class numbers.
                # NOTE(review): assumes compute_metrics accepts a batch of one
                # (leading batch axis kept by the i:i+1 slice) - confirm.
                sample_metrics = compute_metrics(logits[i:i + 1], labels[i:i + 1], threshold)
                bucket = all_metrics.setdefault(ds_name, {"miou": [], "dice": []})
                bucket["miou"].append(float(sample_metrics["miou"]))
                bucket["dice"].append(float(sample_metrics["dice"]))

                # Save prediction mask at original resolution. int() guards
                # against the collate function handing back tensor scalars.
                orig_w = int(batch["orig_width"][i])
                orig_h = int(batch["orig_height"][i])
                pred_mask = Image.fromarray(preds[i] * 255, mode="L")
                pred_mask = pred_mask.resize((orig_w, orig_h), Image.NEAREST)
                prompt_slug = batch["prompt"][i].replace(" ", "_")
                img_stem = Path(batch["image_path"][i]).stem
                mask_filename = f"{img_stem}__{prompt_slug}.png"
                pred_mask.save(masks_dir / mask_filename)
                total_samples += 1

                # Collect visual examples
                if len(visual_examples) < max_visuals:
                    visual_examples.append({
                        "image_path": batch["image_path"][i],
                        "mask_path": batch["mask_path"][i],
                        "pred_mask": preds[i],
                        "prompt": batch["prompt"][i],
                        "dataset": ds_name,
                    })

    # Aggregate metrics
    results = {"per_class": {}, "overall": {}}
    all_miou, all_dice = [], []
    for ds_name, m in all_metrics.items():
        if not m["miou"]:
            continue
        results["per_class"][ds_name] = {
            "miou": round(float(np.mean(m["miou"])), 4),
            "dice": round(float(np.mean(m["dice"])), 4),
            "samples": len(m["miou"]),
        }
        all_miou.extend(m["miou"])
        all_dice.extend(m["dice"])

    results["overall"] = {
        "miou": round(float(np.mean(all_miou)), 4) if all_miou else 0,
        "dice": round(float(np.mean(all_dice)), 4) if all_dice else 0,
        "total_samples": total_samples,
    }

    # Total time over total images, so a short final batch is not
    # over-weighted; 0.0 for an empty test set keeps the JSON free of NaN.
    avg_inference_s = total_inference_s / total_samples if total_samples else 0.0
    results["runtime"] = {
        "avg_inference_ms": round(avg_inference_s * 1000, 1),
        "model_size_mb": round(model_size_mb, 1),
    }

    # Save results
    log_dir = PROJECT_ROOT / "outputs" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    with open(log_dir / "test_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"\n{'='*60}")
    print("Test Results")
    print(f"{'='*60}")
    for ds_name, m in results["per_class"].items():
        print(f" {ds_name:>10s}: mIoU={m['miou']:.4f} Dice={m['dice']:.4f} (n={m['samples']})")
    print(f" {'overall':>10s}: mIoU={results['overall']['miou']:.4f} Dice={results['overall']['dice']:.4f}")
    print(f" Avg inference: {results['runtime']['avg_inference_ms']:.1f} ms/image")
    print(f" Model size: {results['runtime']['model_size_mb']:.1f} MB")

    # Generate visual comparison figures
    _generate_visuals(visual_examples, PROJECT_ROOT / "reports" / "figures")
    return results
def _generate_visuals(examples: list[dict], output_dir: Path):
    """Render an "original | ground truth | prediction" grid, one row per example.

    Args:
        examples: Dicts with keys ``image_path``, ``mask_path``, ``pred_mask``
            (array with 0/1 values, scaled by 255 for display), ``prompt``
            and ``dataset``.
        output_dir: Directory for ``visual_comparison.png``; created if
            missing, even when ``examples`` is empty.

    Returns:
        None. Nothing is drawn or saved when ``examples`` is empty.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    if not examples:
        return

    n_rows = len(examples)
    # squeeze=False always yields a 2-D axes grid, so a single example needs
    # no special-casing.
    fig, axes = plt.subplots(n_rows, 3, figsize=(12, 4 * n_rows), squeeze=False)
    gray = {"cmap": "gray", "vmin": 0, "vmax": 255}
    out_path = output_dir / "visual_comparison.png"

    try:
        for row, ex in zip(axes, examples):
            # convert() returns an in-memory image, so the file handles can be
            # released immediately instead of waiting for garbage collection.
            with Image.open(ex["image_path"]) as src:
                img = src.convert("RGB")
            with Image.open(ex["mask_path"]) as src:
                gt = src.convert("L")
            pred = Image.fromarray(ex["pred_mask"] * 255, mode="L")

            panels = (
                (img, f"Original ({ex['dataset']})", {}),
                (gt, "Ground Truth", gray),
                (pred, f"Prediction: \"{ex['prompt']}\"", gray),
            )
            for ax, (image, title, imshow_kwargs) in zip(row, panels):
                ax.imshow(image, **imshow_kwargs)
                ax.set_title(title)
                ax.axis("off")

        fig.tight_layout()
        fig.savefig(out_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure (not whichever is "current") and do so
        # even if rendering fails, so figures do not accumulate in memory.
        plt.close(fig)

    print(f"Saved visual comparison to {out_path}")
# Script entry point: run the evaluation with the default config path when
# this file is executed directly rather than imported.
if __name__ == "__main__":
    evaluate()