# Requires torch, diffusers, Pillow, numpy, and the segment-anything package
# (https://github.com/facebookresearch/segment-anything); xformers is optional.
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from diffusers import StableDiffusionInpaintPipeline, EulerDiscreteScheduler
from PIL import Image
import numpy as np

class SegmentAnythingModel:
    def __init__(self, sam_checkpoint, model_type, device):
        self.sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        self.sam.to(device=device)
        # Automatic mask generation samples a 32x32 point grid and keeps only
        # masks that clear the IoU and stability thresholds below.
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.99,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100
        )
        self.target_size = (512, 512)

    def preprocess_image(self, image):
        """Load the image if needed, convert to RGB, and resize to 512x512."""
        if isinstance(image, str):
            image = Image.open(image)
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # SAM expects an HxWx3 uint8 array, so normalize the mode before
        # resizing (grayscale or RGBA inputs would otherwise slip through).
        image = image.convert("RGB")
        image = image.resize(self.target_size, Image.Resampling.LANCZOS)
        return np.array(image)

    def generate_masks(self, image):
        """Run automatic mask generation; returns a list of mask dicts."""
        processed_image = self.preprocess_image(image)
        return self.mask_generator.generate(processed_image)

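
# A minimal usage sketch for the class above. The checkpoint filename and
# model type here are assumptions -- point them at whichever SAM weights
# you have downloaded:
#
#   sam = SegmentAnythingModel("sam_vit_h_4b8939.pth", "vit_h", "cuda")
#   masks = sam.generate_masks("photo.jpg")
#   # Each entry is a dict; masks[0]["segmentation"] is a boolean HxW array
#   # and masks[0]["area"] is its pixel count.
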
class StableDiffusionInpaintingPipeline:
    def __init__(self, model_dir):
        # model_dir can be a local path or a Hugging Face repo id holding
        # inpainting weights (e.g. "runwayml/stable-diffusion-inpainting").
        self.scheduler = EulerDiscreteScheduler.from_pretrained(model_dir, subfolder="scheduler")

        # fp16 weights halve memory use but require a CUDA device; newer
        # diffusers releases prefer variant="fp16" over revision="fp16".
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            model_dir,
            scheduler=self.scheduler,
            revision="fp16",
            torch_dtype=torch.float16
        )
        self.pipe = self.pipe.to("cuda")
        try:
            self.pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            # xformers is optional; fall back to standard attention if absent.
            pass
        self.target_size = (512, 512)

    def preprocess_image(self, image):
        """Ensure the image is a PIL image at the target size."""
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        return image.resize(self.target_size, Image.Resampling.LANCZOS)

    def inpaint(self, prompt, image, mask_image, guidance_scale=10, num_inference_steps=60, generator=None):
        """
        Args:
            prompt (str): Text prompt describing what to paint into the masked region
            image (PIL.Image or np.ndarray): The original image
            mask_image (PIL.Image or np.ndarray): Inpainting mask; white pixels are
                repainted, black pixels are preserved
            guidance_scale (float): Higher values push the output closer to the prompt
            num_inference_steps (int): Number of denoising steps
            generator (torch.Generator): Generator for reproducibility
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if isinstance(mask_image, np.ndarray):
            mask_image = Image.fromarray(mask_image)

        # LANCZOS for the photo, NEAREST for the mask so it stays binary.
        image = image.resize(self.target_size, Image.Resampling.LANCZOS)
        mask_image = mask_image.resize(self.target_size, Image.Resampling.NEAREST)

        output = self.pipe(
            prompt=prompt,
            image=image,
            mask_image=mask_image,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            height=512,
            width=512
        )
        return output.images[0]
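
# An end-to-end sketch tying the two classes together. The paths, checkpoint
# names, prompt, and seed are assumptions for illustration -- substitute your
# own. It repaints the largest region SAM finds.
if __name__ == "__main__":
    device = "cuda"  # the pipeline class above is hardcoded to CUDA
    sam = SegmentAnythingModel("sam_vit_h_4b8939.pth", "vit_h", device)
    inpainter = StableDiffusionInpaintingPipeline("runwayml/stable-diffusion-inpainting")

    # Preprocess once and reuse the 512x512 RGB array for both stages.
    source = sam.preprocess_image("photo.jpg")
    masks = sam.generate_masks(source)

    # SAM masks are boolean HxW arrays; convert the largest one into the
    # white-on-black uint8 mask the inpainting pipeline expects.
    largest = max(masks, key=lambda m: m["area"])
    mask_image = Image.fromarray(largest["segmentation"].astype(np.uint8) * 255)

    result = inpainter.inpaint(
        prompt="a vase of flowers on a wooden table",
        image=source,
        mask_image=mask_image,
        generator=torch.Generator(device=device).manual_seed(42),
    )
    result.save("inpainted.png")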