| import torch |
| import torchvision.transforms.v2 as T |
| import torch.nn.functional as F |
| from .utils import expand_mask |
|
|
| class LoadCLIPSegModels: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": {}, |
| } |
|
|
| RETURN_TYPES = ("CLIP_SEG",) |
| FUNCTION = "execute" |
| CATEGORY = "essentials/segmentation" |
|
|
| def execute(self): |
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation |
| processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined") |
| model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined") |
|
|
| return ((processor, model),) |
|
|
| class ApplyCLIPSeg: |
| @classmethod |
| def INPUT_TYPES(s): |
| return { |
| "required": { |
| "clip_seg": ("CLIP_SEG",), |
| "image": ("IMAGE",), |
| "prompt": ("STRING", { "multiline": False, "default": "" }), |
| "threshold": ("FLOAT", { "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.05 }), |
| "smooth": ("INT", { "default": 9, "min": 0, "max": 32, "step": 1 }), |
| "dilate": ("INT", { "default": 0, "min": -32, "max": 32, "step": 1 }), |
| "blur": ("INT", { "default": 0, "min": 0, "max": 64, "step": 1 }), |
| }, |
| } |
|
|
| RETURN_TYPES = ("MASK",) |
| FUNCTION = "execute" |
| CATEGORY = "essentials/segmentation" |
|
|
| def execute(self, image, clip_seg, prompt, threshold, smooth, dilate, blur): |
| processor, model = clip_seg |
|
|
| imagenp = image.mul(255).clamp(0, 255).byte().cpu().numpy() |
|
|
| outputs = [] |
| for i in imagenp: |
| inputs = processor(text=prompt, images=[i], return_tensors="pt") |
| out = model(**inputs) |
| out = out.logits.unsqueeze(1) |
| out = torch.sigmoid(out[0][0]) |
| out = (out > threshold) |
| outputs.append(out) |
|
|
| del imagenp |
|
|
| outputs = torch.stack(outputs, dim=0) |
|
|
| if smooth > 0: |
| if smooth % 2 == 0: |
| smooth += 1 |
| outputs = T.functional.gaussian_blur(outputs, smooth) |
|
|
| outputs = outputs.float() |
|
|
| if dilate != 0: |
| outputs = expand_mask(outputs, dilate, True) |
|
|
| if blur > 0: |
| if blur % 2 == 0: |
| blur += 1 |
| outputs = T.functional.gaussian_blur(outputs, blur) |
| |
| |
| outputs = F.interpolate(outputs.unsqueeze(1), size=(image.shape[1], image.shape[2]), mode='bicubic').squeeze(1) |
|
|
| return (outputs,) |
|
|
| SEG_CLASS_MAPPINGS = { |
| "ApplyCLIPSeg+": ApplyCLIPSeg, |
| "LoadCLIPSegModels+": LoadCLIPSegModels, |
| } |
|
|
| SEG_NAME_MAPPINGS = { |
| "ApplyCLIPSeg+": "🔧 Apply CLIPSeg", |
| "LoadCLIPSegModels+": "🔧 Load CLIPSeg Models", |
| } |