| | import os |
| | from typing import Optional, List |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from torch import autocast |
| | from diffusers import PNDMScheduler, LMSDiscreteScheduler |
| | from PIL import Image |
| | from cog import BasePredictor, Input, Path |
| |
|
| | from image_to_image import ( |
| | StableDiffusionImg2ImgPipeline, |
| | preprocess_init_image, |
| | preprocess_mask, |
| | ) |
| |
|
def patch_conv(**patch):
    """Force keyword arguments onto every ``torch.nn.Conv2d`` built afterwards.

    Replaces ``torch.nn.Conv2d.__init__`` with a wrapper that merges ``patch``
    into the keyword arguments of each constructor call and then delegates to
    the previously installed ``__init__``. The replacement is process-wide and
    is never undone by this function.

    Values in ``patch`` override a keyword of the same name supplied by the
    caller, so constructing ``Conv2d(..., padding_mode='zeros')`` after
    ``patch_conv(padding_mode='circular')`` yields a circular-padded layer
    instead of failing on a duplicated keyword. A patched argument passed
    *positionally* by the caller still conflicts and raises ``TypeError``.

    Args:
        **patch: Keyword arguments to apply to every ``Conv2d`` constructor
            call, e.g. ``padding_mode='circular'``.
    """
    cls = torch.nn.Conv2d
    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        # Merge rather than splat both mappings: ``f(**kwargs, **patch)``
        # raises TypeError whenever a key is present in both.
        kwargs.update(patch)
        return original_init(self, *args, **kwargs)

    cls.__init__ = __init__
| |
|
# Runs at import time: every Conv2d constructed after this point receives
# padding_mode="circular" through the patched constructor.
patch_conv(padding_mode="circular")


# Directory passed as `cache_dir` when the pipeline weights are loaded.
MODEL_CACHE = "diffusers-cache"
| |
|
| |
|
class Predictor(BasePredictor):
    """Cog predictor that runs a Stable Diffusion pipeline on CUDA.

    ``setup`` loads the pipeline once; ``predict`` generates images from a
    prompt, optionally starting from an initial image and an inpainting mask.
    """

    # Noise-schedule arguments shared by every scheduler built in this class.
    _SCHEDULER_KWARGS = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
    }

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient.

        Loads the fp16 revision of ``CompVis/stable-diffusion-v1-4`` from
        ``MODEL_CACHE`` only (``local_files_only=True``, so nothing is
        downloaded) and moves the pipeline to the ``cuda`` device.
        """
        print("Loading pipeline...")
        scheduler = PNDMScheduler(**self._SCHEDULER_KWARGS)
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            scheduler=scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
            cache_dir=MODEL_CACHE,
            local_files_only=True,
        ).to("cuda")

    @staticmethod
    def _open_rgb(path):
        """Return an RGB copy of the image at ``path`` and close the file."""
        # ``convert`` returns an independent image, so the handle opened by
        # ``Image.open`` can be released as soon as the copy exists.
        with Image.open(path) as image:
            return image.convert("RGB")

    @torch.inference_mode()
    @torch.cuda.amp.autocast()
    def predict(
        self,
        prompt: str = Input(description="Input prompt", default=""),
        width: int = Input(
            description="Width of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        height: int = Input(
            description="Height of output image. Maximum size is 1024x768 or 768x1024 because of memory limits",
            choices=[128, 256, 512, 768, 1024],
            default=512,
        ),
        init_image: Path = Input(
            description="Inital image to generate variations of. Will be resized to the specified width and height",
            default=None,
        ),
        mask: Path = Input(
            description="Black and white image to use as mask for inpainting over init_image. Black pixels are inpainted and white pixels are preserved. Experimental feature, tends to work better with prompt strength of 0.5-0.7",
            default=None,
        ),
        prompt_strength: float = Input(
            description="Prompt strength when using init image. 1.0 corresponds to full destruction of information in init image",
            default=0.8,
        ),
        num_outputs: int = Input(
            description="Number of images to output", choices=[1, 4], default=1
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        guidance_scale: float = Input(
            description="Scale for classifier-free guidance", ge=1, le=20, default=7.5
        ),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> List[Path]:
        """Run a single prediction on the model.

        Returns:
            Paths of the generated images, saved as ``/tmp/out-<i>.png`` in
            the order the pipeline returned them.

        Raises:
            ValueError: If both ``width`` and ``height`` are 1024.
            Exception: If the pipeline flags any output as NSFW.
        """
        if seed is None:
            # Two random bytes: the generated seed is in the range 0..65535.
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")

        if width == height == 1024:
            raise ValueError(
                "Maximum size is 1024x768 or 768x1024 pixels, because of memory limits. Please select a lower width or height."
            )

        if init_image:
            init_image = preprocess_init_image(
                self._open_rgb(init_image), width, height
            ).to("cuda")
            # An initial image is present: sample with PNDM.
            scheduler = PNDMScheduler(**self._SCHEDULER_KWARGS)
        else:
            # No initial image: sample with LMS.
            scheduler = LMSDiscreteScheduler(**self._SCHEDULER_KWARGS)

        self.pipe.scheduler = scheduler

        # NOTE(review): a mask supplied without init_image is forwarded
        # unchanged; confirm how the pipeline handles that combination.
        if mask:
            mask = preprocess_mask(self._open_rgb(mask), width, height).to("cuda")

        generator = torch.Generator("cuda").manual_seed(seed)
        output = self.pipe(
            prompt=[prompt] * num_outputs if prompt is not None else None,
            init_image=init_image,
            mask=mask,
            width=width,
            height=height,
            prompt_strength=prompt_strength,
            guidance_scale=guidance_scale,
            generator=generator,
            num_inference_steps=num_inference_steps,
        )
        if any(output["nsfw_content_detected"]):
            raise Exception("NSFW content detected, please try a different prompt")

        output_paths = []
        for i, sample in enumerate(output["sample"]):
            output_path = f"/tmp/out-{i}.png"
            sample.save(output_path)
            output_paths.append(Path(output_path))

        return output_paths
| |
|