| | from typing import Dict, List, Any |
| | from transformers import pipeline,CLIPSegProcessor, CLIPSegForImageSegmentation |
| | from PIL import Image |
| | import torch |
| | import base64 |
| | import io |
| | import numpy as np |
| |
|
class EndpointHandler():
    """Inference endpoint combining CLIPSeg text-prompted segmentation with
    Depth-Anything depth estimation.

    Expected request payload (see ``__call__``):
        {"inputs": {"image": <base64-encoded image>, "text": <prompt or list of prompts>}}
    """

    def __init__(self, path=""):
        # Prefer GPU when available; everything below follows self.device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        self.model = CLIPSegForImageSegmentation.from_pretrained(
            "CIDAS/clipseg-rd64-refined"
        ).to(self.device)
        # Run the depth pipeline on the same device as the segmentation model
        # (previously it silently stayed on CPU even when a GPU was present).
        self.depth_pipe = pipeline(
            "depth-estimation",
            model="depth-anything/Depth-Anything-V2-Small-hf",
            device=self.device,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run CLIPSeg segmentation on the input image, once per text prompt.

        data args:
            inputs (dict): {"image": base64-encoded image,
                            "text": a prompt string or a list of prompt strings}
        Return:
            A list with one ``{"seg_image": <base64 PNG>}`` entry per prompt,
            or ``[{"error": <message>}]`` on failure.
        """
        if "inputs" not in data:
            return [{"error": "Missing 'inputs' key"}]

        inputs_data = data["inputs"]
        if "image" not in inputs_data or "text" not in inputs_data:
            return [{"error": "Missing 'image' or 'text' key in input data"}]

        try:
            image = self.decode_image(inputs_data["image"])
            prompts = inputs_data["text"]
            # Accept a bare string as a single prompt; without this,
            # [image] * len(prompts) would replicate the image per CHARACTER.
            if isinstance(prompts, str):
                prompts = [prompts]

            inputs = self.processor(
                text=prompts,
                images=[image] * len(prompts),
                padding="max_length",
                return_tensors="pt",
            ).to(self.device)  # BUG FIX: was hard-coded "cuda"; crashed on CPU-only hosts

            with torch.no_grad():
                outputs = self.model(**inputs)

            # CLIPSeg logits carry one (H, W) map per prompt. Normalize the
            # shape to (num_prompts, H, W) so multi-prompt requests no longer
            # crash Image.fromarray on a 3-D array.
            logits = outputs.logits.cpu().numpy()
            if logits.ndim == 2:  # single-prompt output collapsed to (H, W)
                logits = logits[np.newaxis, ...]
            elif logits.ndim > 3:
                logits = logits.reshape((-1,) + logits.shape[-2:])

            results = []
            for mask in logits:
                # Min-max normalize each mask independently to the 0..255
                # displayable range (epsilon guards a constant mask).
                mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-6)
                mask = (mask * 255).astype(np.uint8)
                results.append({"seg_image": self.encode_image(Image.fromarray(mask))})
            return results

        except Exception as e:
            # Endpoint boundary: report the failure to the caller instead of raising.
            return [{"error": str(e)}]

    def decode_image(self, image_data: str) -> Image.Image:
        """Decodes a base64-encoded image into an RGB PIL image."""
        image_bytes = base64.b64decode(image_data)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def encode_image(self, image: Image.Image) -> str:
        """Encodes a PIL image to a base64 PNG string."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def process_depth(self, image):
        """Estimate depth for ``image`` and return an 8-bit grayscale PIL image.

        Accepts either a PIL image or a uint8-convertible numpy array.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image.astype("uint8"))
        output = self.depth_pipe(image)
        depth_map = np.array(output["depth"])
        # Min-max normalize to the displayable 0..255 range
        # (epsilon guards a perfectly flat depth map).
        depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min() + 1e-6)
        depth_map = (depth_map * 255).astype(np.uint8)
        return Image.fromarray(depth_map)
| |
|