| | import torch |
| | import torchvision |
| | from PIL import Image |
| | import numpy as np |
| | import matplotlib.pyplot as plt |
| | import gradio as gr |
| |
|
| | |
# Pretrained Mask R-CNN (ResNet-50 FPN backbone, COCO weights); downloaded on
# first use. eval() disables dropout/batch-norm updates for inference.
# NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
# favor of `weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1` — confirm the
# installed torchvision version before migrating.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model.eval()
| |
|
| | |
# COCO class names indexed by the label ids the TorchVision detection models
# emit. The 'N/A' placeholders fill ids that were removed from the original
# 91-class COCO taxonomy — do NOT delete them, they keep list indices aligned
# with the model's label ids. Index 0 is the background class.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella',
    'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass', 'cup', 'fork',
    'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
    'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'N/A', 'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
    'sink', 'refrigerator', 'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush'
]
| |
|
| | |
def segment_objects(image, threshold=0.5):
    """Run Mask R-CNN on a PIL image and render masks, boxes, and labels.

    Args:
        image: input PIL.Image in any mode — converted to RGB internally, so
            grayscale and RGBA uploads no longer crash the channel loop.
        threshold: minimum detection confidence score to keep (0.0-1.0).

    Returns:
        str: path of the saved PNG with blended mask overlays, bounding
        boxes, and "label: score" tags.
    """
    # Force 3-channel RGB: the blending below indexes three color channels,
    # which fails for 2-D grayscale and mis-blends 4-channel RGBA arrays.
    image = image.convert("RGB")

    img_tensor = torchvision.transforms.ToTensor()(image).unsqueeze(0)

    # Inference only — no autograd bookkeeping needed.
    with torch.no_grad():
        output = model(img_tensor)[0]

    masks = output['masks']
    boxes = output['boxes']
    labels = output['labels']
    scores = output['scores']

    image_np = np.array(image).copy()
    fig, ax = plt.subplots(1, figsize=(10, 10))

    for i in range(len(masks)):
        if scores[i] < threshold:
            continue

        # Soft per-pixel mask -> boolean mask at the conventional 0.5 cutoff.
        mask = masks[i, 0].cpu().numpy() > 0.5

        # Random per-instance color, blended 50/50 into just the masked
        # pixels (vectorized; avoids building a full-size colored image
        # per instance as the per-channel loop did).
        color = np.random.rand(3)
        overlay = (color * 255).astype(np.uint8)
        image_np[mask] = (0.5 * image_np[mask] + 0.5 * overlay).astype(np.uint8)

        # Bounding box and label tag in the same instance color.
        x1, y1, x2, y2 = boxes[i].cpu().numpy()
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                   fill=False, color=color, linewidth=2))
        label = COCO_INSTANCE_CATEGORY_NAMES[labels[i].item()]
        ax.text(x1, y1, f"{label}: {scores[i]:.2f}",
                bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10)

    # Draw once, after all masks are blended in. (The original also drew the
    # unblended image before the loop, which was fully painted over here.)
    ax.imshow(image_np)
    ax.axis('off')
    output_path = "output_maskrcnn_with_masks.png"
    plt.savefig(output_path, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    return output_path
| |
|
| | |
# Gradio front end: an image upload plus a confidence slider feeding
# segment_objects; the rendered PNG path is displayed as the output image.
_input_widgets = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Slider(0.0, 1.0, value=0.5, step=0.05, label="Confidence Threshold"),
]

interface = gr.Interface(
    fn=segment_objects,
    inputs=_input_widgets,
    outputs=gr.Image(type="filepath", label="Segmented Output"),
    title="Mask R-CNN Instance Segmentation",
    description="Upload an image to detect and segment objects using a pretrained Mask R-CNN model (TorchVision).",
)
| |
|
# Start the Gradio server only when run as a script (not on import).
# debug=True surfaces tracebacks from segment_objects in the console.
if __name__ == "__main__":
    interface.launch(debug=True)
| |
|